| 1 | //===- TosaValidation.cpp ------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Validate that TOSA dialect input matches the specification for the given
// requirements.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Tosa/IR/TargetEnv.h" |
| 15 | #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" |
| 16 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| 17 | #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" |
| 18 | |
| 19 | #include <string> |
| 20 | |
| 21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 22 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 23 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
| 24 | #include "mlir/IR/Builders.h" |
| 25 | #include "mlir/IR/BuiltinOps.h" |
| 26 | #include "mlir/IR/Matchers.h" |
| 27 | #include "mlir/IR/TypeUtilities.h" |
| 28 | #include "mlir/Pass/Pass.h" |
| 29 | #include "mlir/Transforms/DialectConversion.h" |
| 30 | #include "llvm/ADT/StringExtras.h" |
| 31 | |
| 32 | namespace mlir { |
| 33 | namespace tosa { |
| 34 | #define GEN_PASS_DEF_TOSAVALIDATION |
| 35 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
| 36 | } // namespace tosa |
| 37 | } // namespace mlir |
| 38 | |
| 39 | using namespace mlir; |
| 40 | using namespace mlir::tosa; |
| 41 | |
| 42 | namespace { |
| 43 | |
| 44 | static LogicalResult |
| 45 | checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) { |
| 46 | for (const auto index : operandIndices) { |
| 47 | Attribute attr; |
| 48 | if (!matchPattern(op->getOperand(idx: index), m_Constant(&attr))) { |
| 49 | return op->emitOpError(message: "expected compile time resolvable constant, but " |
| 50 | "got variable value for operand #" ) |
| 51 | << index; |
| 52 | } |
| 53 | } |
| 54 | return success(); |
| 55 | } |
| 56 | |
| 57 | static LogicalResult checkConstantOperandMul(Operation *op, |
| 58 | const TargetEnv &env) { |
| 59 | if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) { |
| 60 | // Check 'shift' |
| 61 | return checkConstantOperands(op, operandIndices: {2}); |
| 62 | } |
| 63 | return success(); |
| 64 | } |
| 65 | |
| 66 | static LogicalResult checkConstantOperandTable(Operation *op, |
| 67 | const TargetEnv &env) { |
| 68 | if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) { |
| 69 | // Check 'table' |
| 70 | return checkConstantOperands(op, operandIndices: {1}); |
| 71 | } |
| 72 | return success(); |
| 73 | } |
| 74 | |
| 75 | static LogicalResult checkConstantOperandPad(Operation *op, |
| 76 | const TargetEnv &env) { |
| 77 | if (auto padOp = dyn_cast<tosa::PadOp>(op)) { |
| 78 | // Assume this op is zero-padding if padConst is not presented |
| 79 | if (!env.allows(Extension::dynamic) && padOp.getPadConst()) |
| 80 | // Check 'pad_const' |
| 81 | // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type |
| 82 | return checkConstantOperands(op, operandIndices: {2}); |
| 83 | } |
| 84 | return success(); |
| 85 | } |
| 86 | |
| 87 | static LogicalResult checkConstantOperandRescale(Operation *op, |
| 88 | const TargetEnv &env) { |
| 89 | if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) { |
| 90 | // Check 'multiplier', 'shift', 'input_zp' and 'output_zp' |
| 91 | return checkConstantOperands(op, {1, 2, 3, 4}); |
| 92 | } |
| 93 | return success(); |
| 94 | } |
| 95 | |
| 96 | template <typename T> |
| 97 | static LogicalResult checkConstantOperandConvOps(Operation *op, |
| 98 | const TargetEnv &env) { |
| 99 | if (!env.allows(Extension::dynamic) && isa<T>(op)) { |
| 100 | // Check 'input_zp' and 'weight_zp' |
| 101 | return checkConstantOperands(op, {3, 4}); |
| 102 | } |
| 103 | return success(); |
| 104 | } |
| 105 | |
| 106 | static LogicalResult checkConstantOperandMatMul(Operation *op, |
| 107 | const TargetEnv &env) { |
| 108 | if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) { |
| 109 | // Check 'A_zp' and 'B_zp' |
| 110 | return checkConstantOperands(op, {2, 3}); |
| 111 | } |
| 112 | return success(); |
| 113 | } |
| 114 | |
| 115 | static LogicalResult checkConstantOperandAvgPool2d(Operation *op, |
| 116 | const TargetEnv &env) { |
| 117 | if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) { |
| 118 | // Check 'input_zp' and 'output_zp' |
| 119 | return checkConstantOperands(op, {1, 2}); |
| 120 | } |
| 121 | return success(); |
| 122 | } |
| 123 | |
| 124 | static LogicalResult checkConstantOperandNegate(Operation *op, |
| 125 | const TargetEnv &env) { |
| 126 | if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) { |
| 127 | // Check 'input1_zp' and 'output_zp' |
| 128 | return checkConstantOperands(op, {1, 2}); |
| 129 | } |
| 130 | return success(); |
| 131 | } |
| 132 | |
/// Upper bounds that define a TOSA level. Each field is the maximum value the
/// validation pass accepts for the corresponding quantity.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;
  int32_t MAX_LOG2_SIZE = 0;
  int32_t MAX_NESTING = 0;
  int32_t MAX_TENSOR_LIST_SIZE = 0;

  /// Field-wise equality. Declared `constexpr` and `const` so that constant
  /// level tables can appear on either side of the comparison. A non-const
  /// member operator== also triggers reversed-candidate ambiguity diagnostics
  /// in C++20.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }

  /// Negation of operator==. C++17 does not synthesize it.
  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};
| 150 | |
| 151 | static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {.MAX_RANK: 6, .MAX_KERNEL: 8192, .MAX_STRIDE: 8192, .MAX_SCALE: 256, .MAX_LOG2_SIZE: 31, .MAX_NESTING: 6, .MAX_TENSOR_LIST_SIZE: 64}; |
| 152 | static constexpr TosaLevel TOSA_LEVEL_NONE = {.MAX_RANK: 32, .MAX_KERNEL: 2147483647, .MAX_STRIDE: 2147483647, .MAX_SCALE: 2048, |
| 153 | .MAX_LOG2_SIZE: 63, .MAX_NESTING: 256, .MAX_TENSOR_LIST_SIZE: 256}; |
| 154 | |
| 155 | //===----------------------------------------------------------------------===// |
| 156 | // TOSA Validation Pass. |
| 157 | //===----------------------------------------------------------------------===// |
| 158 | |
| 159 | struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { |
| 160 | public: |
| 161 | explicit TosaValidation() { populateConstantOperandChecks(); } |
| 162 | |
| 163 | explicit TosaValidation(const TosaValidationOptions &options) |
| 164 | : TosaValidation() { |
| 165 | this->profile = options.profile; |
| 166 | this->extension = options.extension; |
| 167 | this->strictOpSpecAlignment = options.strictOpSpecAlignment; |
| 168 | this->allowInvalidOpDatatypeCombinations = |
| 169 | options.allowInvalidOpDatatypeCombinations; |
| 170 | this->level = options.level; |
| 171 | } |
| 172 | void runOnOperation() final; |
| 173 | |
| 174 | LogicalResult applyConstantOperandCheck(Operation *op) { |
| 175 | for (auto &checker : constCheckers) { |
| 176 | if (failed(checker(op, targetEnv))) |
| 177 | return failure(); |
| 178 | } |
| 179 | return success(); |
| 180 | } |
| 181 | |
| 182 | LogicalResult applyLevelCheck(Operation *op); |
| 183 | LogicalResult applyAttributeCheck(Operation *op); |
| 184 | |
| 185 | // check variable read/write data types against variable declarations |
| 186 | LogicalResult applyVariableCheck(Operation *op); |
| 187 | |
| 188 | // check error if conditions |
| 189 | LogicalResult applyErrorIfCheck(Operation *op); |
| 190 | |
| 191 | private: |
| 192 | void populateConstantOperandChecks() { |
| 193 | constCheckers.emplace_back(checkConstantOperandMul); |
| 194 | constCheckers.emplace_back(checkConstantOperandTable); |
| 195 | constCheckers.emplace_back(checkConstantOperandPad); |
| 196 | constCheckers.emplace_back(checkConstantOperandRescale); |
| 197 | constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>); |
| 198 | constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>); |
| 199 | constCheckers.emplace_back( |
| 200 | checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>); |
| 201 | constCheckers.emplace_back( |
| 202 | checkConstantOperandConvOps<tosa::TransposeConv2DOp>); |
| 203 | constCheckers.emplace_back(checkConstantOperandMatMul); |
| 204 | constCheckers.emplace_back(checkConstantOperandAvgPool2d); |
| 205 | constCheckers.emplace_back(checkConstantOperandNegate); |
| 206 | } |
| 207 | |
| 208 | bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { |
| 209 | if (v > tosaLevel.MAX_KERNEL) { |
| 210 | op->emitOpError() << "failed level check: " << checkDesc; |
| 211 | return false; |
| 212 | } |
| 213 | return true; |
| 214 | } |
| 215 | |
| 216 | bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { |
| 217 | if (v > tosaLevel.MAX_STRIDE) { |
| 218 | op->emitOpError() << "failed level check: " << checkDesc; |
| 219 | return false; |
| 220 | } |
| 221 | return true; |
| 222 | } |
| 223 | |
| 224 | bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { |
| 225 | if (v > tosaLevel.MAX_SCALE) { |
| 226 | op->emitOpError() << "failed level check: " << checkDesc; |
| 227 | return false; |
| 228 | } |
| 229 | return true; |
| 230 | } |
| 231 | |
| 232 | bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { |
| 233 | if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) { |
| 234 | op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " |
| 235 | << checkDesc; |
| 236 | return false; |
| 237 | } |
| 238 | return true; |
| 239 | } |
| 240 | |
| 241 | // Perform the Level Rank check on the tensor type. |
| 242 | bool levelCheckRank(Operation *op, const Type typeToCheck, |
| 243 | const StringRef operandOrResult, int32_t highest_rank) { |
| 244 | if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) { |
| 245 | if (!type.hasRank()) { |
| 246 | op->emitOpError() << "failed level check: unranked tensor" ; |
| 247 | return false; |
| 248 | } |
| 249 | if (type.getRank() > highest_rank) { |
| 250 | op->emitOpError() << "failed level check: " << operandOrResult |
| 251 | << " rank(shape) <= MAX_RANK" ; |
| 252 | return false; |
| 253 | } |
| 254 | } |
| 255 | return true; |
| 256 | } |
| 257 | |
| 258 | // Perform the Level Rank check on the tensor value. |
| 259 | bool levelCheckRank(Operation *op, const Value &v, |
| 260 | const StringRef operandOrResult, int32_t highest_rank) { |
| 261 | return levelCheckRank(op, typeToCheck: v.getType(), operandOrResult, highest_rank); |
| 262 | } |
| 263 | |
| 264 | // Perform the Level tensor size check on the tensor type. |
| 265 | bool levelCheckSize(Operation *op, const Type &typeToCheck, |
| 266 | const StringRef operandOrResult); |
| 267 | |
| 268 | // Perform the Level tensor size check on the tensor value. |
| 269 | bool levelCheckSize(Operation *op, const Value &v, |
| 270 | const StringRef operandOrResult) { |
| 271 | return levelCheckSize(op, typeToCheck: v.getType(), operandOrResult); |
| 272 | } |
| 273 | |
| 274 | // Level check sizes of all operands and results of the operation. |
| 275 | template <typename T> |
| 276 | bool levelCheckSizes(T tosaOp) { |
| 277 | auto op = tosaOp.getOperation(); |
| 278 | for (auto v : op->getOperands()) { |
| 279 | if (!levelCheckSize(op, v, "operand" )) |
| 280 | return false; |
| 281 | } |
| 282 | |
| 283 | for (auto v : op->getResults()) { |
| 284 | if (!levelCheckSize(op, v, "result" )) |
| 285 | return false; |
| 286 | } |
| 287 | return true; |
| 288 | } |
| 289 | |
| 290 | // Level check ranks of all operands, attribute and results of the operation. |
| 291 | template <typename T> |
| 292 | bool levelCheckRanks(T tosaOp) { |
| 293 | auto op = tosaOp.getOperation(); |
| 294 | for (auto v : op->getOperands()) { |
| 295 | if (!levelCheckRank(op, v, "operand" , tosaLevel.MAX_RANK)) |
| 296 | return false; |
| 297 | } |
| 298 | |
| 299 | for (auto v : op->getResults()) { |
| 300 | if (!levelCheckRank(op, v, "result" , tosaLevel.MAX_RANK)) |
| 301 | return false; |
| 302 | } |
| 303 | return true; |
| 304 | } |
| 305 | |
| 306 | // Level check ranks and sizes. |
| 307 | bool levelCheckRanksAndSizes(Operation *op); |
| 308 | |
| 309 | // Pool Op: level check kernel/stride/pad values |
| 310 | template <typename T> |
| 311 | bool levelCheckPool(Operation *op) { |
| 312 | if (auto poolOp = dyn_cast<T>(op)) { |
| 313 | for (auto k : poolOp.getKernel()) { |
| 314 | if (!levelCheckKernel(op, v: k, checkDesc: "kernel <= MAX_KERNEL" )) { |
| 315 | return false; |
| 316 | } |
| 317 | } |
| 318 | for (auto s : poolOp.getStride()) { |
| 319 | if (!levelCheckStride(op, v: s, checkDesc: "stride <= MAX_STRIDE" )) { |
| 320 | return false; |
| 321 | } |
| 322 | } |
| 323 | for (auto p : poolOp.getPad()) { |
| 324 | if (!levelCheckKernel(op, v: p, checkDesc: "pad <= MAX_KERNEL" )) { |
| 325 | return false; |
| 326 | } |
| 327 | } |
| 328 | } |
| 329 | return true; |
| 330 | } |
| 331 | |
| 332 | // Conv Op: level check dilation/stride/pad values |
| 333 | template <typename T> |
| 334 | bool levelCheckConv(Operation *op) { |
| 335 | if (auto convOp = dyn_cast<T>(op)) { |
| 336 | |
| 337 | for (auto k : convOp.getDilation()) { |
| 338 | if (!levelCheckKernel(op, v: k, checkDesc: "dilation <= MAX_KERNEL" )) { |
| 339 | return false; |
| 340 | } |
| 341 | } |
| 342 | for (auto p : convOp.getPad()) { |
| 343 | if (!levelCheckKernel(op, v: p, checkDesc: "pad <= MAX_KERNEL" )) { |
| 344 | return false; |
| 345 | } |
| 346 | } |
| 347 | for (auto s : convOp.getStride()) { |
| 348 | if (!levelCheckStride(op, v: s, checkDesc: "stride <= MAX_STRIDE" )) { |
| 349 | return false; |
| 350 | } |
| 351 | } |
| 352 | auto dilation = convOp.getDilation(); |
| 353 | if (ShapedType weightType = |
| 354 | dyn_cast<ShapedType>(op->getOperand(1).getType())) { |
| 355 | auto shape = weightType.getShape(); |
| 356 | if (isa<tosa::Conv2DOp>(op)) { |
| 357 | assert(shape.size() == 4); |
| 358 | assert(dilation.size() == 2); |
| 359 | if (!levelCheckKernel(op, v: dilation[0] * shape[1], |
| 360 | checkDesc: "dilation_y * KH <= MAX_KERNEL)" ) || |
| 361 | !levelCheckKernel(op, v: dilation[1] * shape[2], |
| 362 | checkDesc: "dilation_x * KW <= MAX_KERNEL)" )) |
| 363 | return false; |
| 364 | } else if (isa<tosa::Conv3DOp>(op)) { |
| 365 | assert(shape.size() == 5); |
| 366 | assert(dilation.size() == 3); |
| 367 | if (!levelCheckKernel(op, v: dilation[0] * shape[1], |
| 368 | checkDesc: "dilation_d * KD <= MAX_KERNEL)" ) || |
| 369 | !levelCheckKernel(op, v: dilation[1] * shape[2], |
| 370 | checkDesc: "dilation_y * KH <= MAX_KERNEL)" ) || |
| 371 | !levelCheckKernel(op, v: dilation[2] * shape[3], |
| 372 | checkDesc: "dilation_x * KW <= MAX_KERNEL)" )) |
| 373 | return false; |
| 374 | } else if (isa<tosa::DepthwiseConv2DOp>(op)) { |
| 375 | assert(shape.size() == 4); |
| 376 | assert(dilation.size() == 2); |
| 377 | if (!levelCheckKernel(op, v: dilation[0] * shape[0], |
| 378 | checkDesc: "dilation_y * KH <= MAX_KERNEL)" ) || |
| 379 | !levelCheckKernel(op, v: dilation[1] * shape[1], |
| 380 | checkDesc: "dilation_x * KW <= MAX_KERNEL)" )) |
| 381 | return false; |
| 382 | } |
| 383 | } |
| 384 | } |
| 385 | return true; |
| 386 | } |
| 387 | |
| 388 | // FFT op: level check H, W in input shape [N,H,W] |
| 389 | template <typename T> |
| 390 | bool levelCheckFFT(Operation *op) { |
| 391 | if (isa<T>(op)) { |
| 392 | for (auto v : op->getOperands()) { |
| 393 | if (ShapedType type = dyn_cast<ShapedType>(v.getType())) { |
| 394 | auto shape = type.getShape(); |
| 395 | assert(shape.size() == 3); |
| 396 | if (!levelCheckKernel(op, v: shape[1], checkDesc: "H <= MAX_KERNEL" ) || |
| 397 | !levelCheckKernel(op, v: shape[2], checkDesc: "W <= MAX_KERNEL" )) { |
| 398 | return false; |
| 399 | } |
| 400 | } |
| 401 | } |
| 402 | } |
| 403 | return true; |
| 404 | } |
| 405 | |
| 406 | // TransposeConv2d op: level check kH/kW, outpad, and stride |
| 407 | bool levelCheckTransposeConv2d(Operation *op) { |
| 408 | if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) { |
| 409 | if (ShapedType filterType = |
| 410 | dyn_cast<ShapedType>(transpose.getWeight().getType())) { |
| 411 | auto shape = filterType.getShape(); |
| 412 | assert(shape.size() == 4); |
| 413 | // level check kernel sizes for kH and KW |
| 414 | if (!levelCheckKernel(op, v: shape[1], checkDesc: "KH <= MAX_KERNEL" ) || |
| 415 | !levelCheckKernel(op, v: shape[2], checkDesc: "KW <= MAX_KERNEL" )) { |
| 416 | return false; |
| 417 | } |
| 418 | } |
| 419 | for (auto p : transpose.getOutPad()) { |
| 420 | if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL" )) { |
| 421 | return false; |
| 422 | } |
| 423 | } |
| 424 | for (auto s : transpose.getStride()) { |
| 425 | if (!levelCheckStride(op, s, "stride <= MAX_STRIDE" )) { |
| 426 | return false; |
| 427 | } |
| 428 | } |
| 429 | } |
| 430 | return true; |
| 431 | } |
| 432 | |
| 433 | // Resize op: level check max scales |
| 434 | bool levelCheckResize(Operation *op) { |
| 435 | if (auto resize = dyn_cast<tosa::ResizeOp>(op)) { |
| 436 | SmallVector<int64_t> scale; |
| 437 | if (!tosa::getConstShapeValues(op: resize.getScale().getDefiningOp(), |
| 438 | result_shape&: scale)) { |
| 439 | return false; |
| 440 | } |
| 441 | const int64_t scaleYN = scale[0]; |
| 442 | const int64_t scaleYD = scale[1]; |
| 443 | const int64_t scaleXN = scale[2]; |
| 444 | const int64_t scaleXD = scale[3]; |
| 445 | if (!levelCheckScale(op, v: scaleYN / scaleYD, |
| 446 | checkDesc: "scale_y_n/scale_y_d <= MAX_SCALE" ) || |
| 447 | !levelCheckScale(op, v: scaleXN / scaleXD, |
| 448 | checkDesc: "scale_x_n/scale_x_d <= MAX_SCALE" )) { |
| 449 | return false; |
| 450 | } |
| 451 | } |
| 452 | return true; |
| 453 | } |
| 454 | |
| 455 | // Recursively perform a bottom-up search to determine the maximum nesting |
| 456 | // depth, starting from a specific operation and continuing up to the function |
| 457 | // or module scope. Tosa nesting_depth starts at 0 and increments by one each |
| 458 | // time a new nested `region` is encountered. |
| 459 | static void getMaxNestedDepth(Operation *op, int32_t &depth) { |
| 460 | if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op)) |
| 461 | return; |
| 462 | |
| 463 | op = op->getParentOp(); |
| 464 | if (!op) |
| 465 | return; |
| 466 | |
| 467 | depth++; |
| 468 | getMaxNestedDepth(op, depth); |
| 469 | } |
| 470 | |
| 471 | bool levelCheckMaxNesting(Operation *op) { |
| 472 | int32_t maxNestedDepth = 0; |
| 473 | getMaxNestedDepth(op, depth&: maxNestedDepth); |
| 474 | |
| 475 | if (maxNestedDepth >= tosaLevel.MAX_NESTING) { |
| 476 | op->emitOpError() << "failed level check: " << maxNestedDepth |
| 477 | << " >= MAX_NESTING" ; |
| 478 | return false; |
| 479 | } |
| 480 | return true; |
| 481 | } |
| 482 | |
| 483 | bool levelCheckListSize(Operation *op) { |
| 484 | if (auto concat = dyn_cast<tosa::ConcatOp>(op)) { |
| 485 | return levelCheckListSize(op, concat.getInput1().size(), "input1" ); |
| 486 | } |
| 487 | if (auto custom = dyn_cast<tosa::CustomOp>(op)) { |
| 488 | if (!levelCheckListSize(op, custom.getInputList().size(), "input_list" ) || |
| 489 | !levelCheckListSize(op, custom.getOutputList().size(), |
| 490 | "output_list" )) { |
| 491 | return false; |
| 492 | } |
| 493 | } |
| 494 | if (auto condIf = dyn_cast<tosa::IfOp>(op)) { |
| 495 | if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs" ) || |
| 496 | !levelCheckListSize(op, condIf.getOutputList().size(), "outputs" )) { |
| 497 | return false; |
| 498 | } |
| 499 | } |
| 500 | if (auto w = dyn_cast<tosa::WhileOp>(op)) { |
| 501 | if (!levelCheckListSize(op, w.getInputList().size(), "inputs" ) || |
| 502 | !levelCheckListSize(op, w.getOutputList().size(), "outputs" )) { |
| 503 | return false; |
| 504 | } |
| 505 | } |
| 506 | return true; |
| 507 | } |
| 508 | |
| 509 | bool attributeCheckRescale(Operation *op) { |
| 510 | if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) { |
| 511 | if (rescale.getRoundingMode() == "DOUBLE_ROUND" && |
| 512 | !targetEnv.allows(Extension::doubleround)) { |
| 513 | op->emitOpError() |
| 514 | << "failed attribute check: rounding_mode = DOUBLE_ROUND " |
| 515 | << "requires extension [doubleround]" ; |
| 516 | return false; |
| 517 | } else if (rescale.getRoundingMode() == "INEXACT_ROUND" && |
| 518 | !targetEnv.allows(Extension::inexactround)) { |
| 519 | op->emitOpError() |
| 520 | << "failed attribute check: rounding_mode = INEXACT_ROUND " |
| 521 | << "requires extension [inexactround]" ; |
| 522 | return false; |
| 523 | } |
| 524 | } |
| 525 | return true; |
| 526 | } |
| 527 | |
| 528 | // configure profile and level values from pass options profileName and |
| 529 | // levelName |
| 530 | void configLevelAndProfile() { |
| 531 | tosaLevel = TOSA_LEVEL_NONE; |
| 532 | if (level == TosaLevelEnum::EightK) { |
| 533 | tosaLevel = TOSA_LEVEL_EIGHTK; |
| 534 | } |
| 535 | |
| 536 | if (!profile.empty()) { |
| 537 | for (std::string &prof : profile) { |
| 538 | auto profSymbol = symbolizeProfile(prof); |
| 539 | if (profSymbol) { |
| 540 | targetEnv.addProfile(profSymbol.value()); |
| 541 | } else { |
| 542 | llvm::errs() << "unknown TOSA profile name passed in: " << prof |
| 543 | << ", supported profiles are `pro_int` and `pro_fp`\n" ; |
| 544 | return signalPassFailure(); |
| 545 | } |
| 546 | } |
| 547 | } |
| 548 | |
| 549 | if (!extension.empty()) { |
| 550 | for (std::string &ext : extension) { |
| 551 | auto extSymbol = symbolizeExtension(ext); |
| 552 | if (extSymbol) { |
| 553 | targetEnv.addExtension(extSymbol.value()); |
| 554 | } else { |
| 555 | llvm::errs() << "unknown TOSA extension name passed in: " << ext |
| 556 | << ", supported extension are int16, int4, bf16, " |
| 557 | << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " |
| 558 | << "doubleround, inexactround and dynamic\n" ; |
| 559 | return signalPassFailure(); |
| 560 | } |
| 561 | } |
| 562 | } |
| 563 | } |
| 564 | |
| 565 | bool CheckVariable(Operation *op); |
| 566 | bool CheckVariableReadOrWrite(Operation *op); |
| 567 | bool isValidElementType(Type type, const bool allowUnsigned = false); |
| 568 | |
| 569 | SmallVector< |
| 570 | std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>> |
| 571 | constCheckers; |
| 572 | TosaLevel tosaLevel; |
| 573 | DenseMap<StringAttr, mlir::Type> variablesMap; |
| 574 | TosaProfileCompliance profileComp; |
| 575 | tosa::TargetEnv targetEnv; |
| 576 | }; |
| 577 | |
| 578 | template <> |
| 579 | bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { |
| 580 | auto op = tosaOp.getOperation(); |
| 581 | if (!levelCheckRank(op, tosaOp.getInput(), "operand" , tosaLevel.MAX_RANK)) |
| 582 | return false; |
| 583 | |
| 584 | // rank(output) = rank(input) - 1 |
| 585 | if (!levelCheckRank(op, tosaOp.getOutput(), "result" , tosaLevel.MAX_RANK - 1)) |
| 586 | return false; |
| 587 | |
| 588 | return true; |
| 589 | } |
| 590 | |
| 591 | template <> |
| 592 | bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { |
| 593 | auto op = tosaOp.getOperation(); |
| 594 | |
| 595 | // Only the condition input has rank limitation. |
| 596 | if (!levelCheckRank(op, tosaOp.getCondition(), "operand" , tosaLevel.MAX_RANK)) |
| 597 | return false; |
| 598 | |
| 599 | return true; |
| 600 | } |
| 601 | |
| 602 | template <> |
| 603 | bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { |
| 604 | auto op = tosaOp.getOperation(); |
| 605 | auto variableType = getVariableType(tosaOp); |
| 606 | if (!levelCheckRank(op, variableType, "variable type" , tosaLevel.MAX_RANK)) |
| 607 | return false; |
| 608 | |
| 609 | return true; |
| 610 | } |
| 611 | |
| 612 | template <> |
| 613 | bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) { |
| 614 | auto op = tosaOp.getOperation(); |
| 615 | auto variableType = getVariableType(tosaOp); |
| 616 | if (!levelCheckSize(op, variableType, "variable type" )) |
| 617 | return false; |
| 618 | |
| 619 | return true; |
| 620 | } |
| 621 | |
| 622 | bool TosaValidation::levelCheckRanksAndSizes(Operation *op) { |
| 623 | #define CHECK_RANKS_AND_SIZES(tosaOp) \ |
| 624 | if (isa<tosa::tosaOp##Op>(op)) { \ |
| 625 | if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \ |
| 626 | return false; \ |
| 627 | if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \ |
| 628 | return false; \ |
| 629 | } |
| 630 | |
| 631 | #define CHECK_SIZES(tosaOp) \ |
| 632 | if (isa<tosa::tosaOp##Op>(op)) { \ |
| 633 | if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \ |
| 634 | return false; \ |
| 635 | } |
| 636 | |
| 637 | // Tensor Operators |
| 638 | CHECK_RANKS_AND_SIZES(ArgMax); |
| 639 | // Activation Functions |
| 640 | CHECK_RANKS_AND_SIZES(Clamp); |
| 641 | CHECK_RANKS_AND_SIZES(Erf); |
| 642 | CHECK_RANKS_AND_SIZES(Sigmoid); |
| 643 | CHECK_RANKS_AND_SIZES(Tanh); |
| 644 | // Elementwise Binary Operators |
| 645 | CHECK_RANKS_AND_SIZES(Add); |
| 646 | CHECK_RANKS_AND_SIZES(ArithmeticRightShift); |
| 647 | CHECK_RANKS_AND_SIZES(BitwiseAnd); |
| 648 | CHECK_RANKS_AND_SIZES(BitwiseOr); |
| 649 | CHECK_RANKS_AND_SIZES(BitwiseXor); |
| 650 | CHECK_RANKS_AND_SIZES(IntDiv); |
| 651 | CHECK_RANKS_AND_SIZES(LogicalAnd); |
| 652 | CHECK_RANKS_AND_SIZES(LogicalLeftShift); |
| 653 | CHECK_RANKS_AND_SIZES(LogicalRightShift); |
| 654 | CHECK_RANKS_AND_SIZES(LogicalOr); |
| 655 | CHECK_RANKS_AND_SIZES(LogicalXor); |
| 656 | CHECK_RANKS_AND_SIZES(Maximum); |
| 657 | CHECK_RANKS_AND_SIZES(Minimum); |
| 658 | CHECK_RANKS_AND_SIZES(Mul); |
| 659 | CHECK_RANKS_AND_SIZES(Pow); |
| 660 | CHECK_RANKS_AND_SIZES(Sub); |
| 661 | CHECK_RANKS_AND_SIZES(Table); |
| 662 | // Elementwise Unary Operators |
| 663 | CHECK_RANKS_AND_SIZES(Abs); |
| 664 | CHECK_RANKS_AND_SIZES(BitwiseNot); |
| 665 | CHECK_RANKS_AND_SIZES(Ceil); |
| 666 | CHECK_RANKS_AND_SIZES(Clz); |
| 667 | CHECK_RANKS_AND_SIZES(Cos); |
| 668 | CHECK_RANKS_AND_SIZES(Exp); |
| 669 | CHECK_RANKS_AND_SIZES(Floor); |
| 670 | CHECK_RANKS_AND_SIZES(Log); |
| 671 | CHECK_RANKS_AND_SIZES(LogicalNot); |
| 672 | CHECK_RANKS_AND_SIZES(Negate); |
| 673 | CHECK_RANKS_AND_SIZES(Reciprocal); |
| 674 | CHECK_RANKS_AND_SIZES(Rsqrt); |
| 675 | CHECK_RANKS_AND_SIZES(Sin); |
| 676 | // Elementwise Ternary Operators |
| 677 | CHECK_RANKS_AND_SIZES(Select); |
| 678 | // Comparison Operators |
| 679 | CHECK_RANKS_AND_SIZES(Equal); |
| 680 | CHECK_RANKS_AND_SIZES(Greater); |
| 681 | CHECK_RANKS_AND_SIZES(GreaterEqual); |
| 682 | // Reduction Operators |
| 683 | CHECK_RANKS_AND_SIZES(ReduceAll); |
| 684 | CHECK_RANKS_AND_SIZES(ReduceAny); |
| 685 | CHECK_RANKS_AND_SIZES(ReduceMax); |
| 686 | CHECK_RANKS_AND_SIZES(ReduceMin); |
| 687 | CHECK_RANKS_AND_SIZES(ReduceProduct); |
| 688 | CHECK_RANKS_AND_SIZES(ReduceSum); |
| 689 | // Data Layout Operators |
| 690 | CHECK_RANKS_AND_SIZES(Concat); |
| 691 | CHECK_RANKS_AND_SIZES(Pad); |
| 692 | CHECK_RANKS_AND_SIZES(Reshape); |
| 693 | CHECK_RANKS_AND_SIZES(Reverse); |
| 694 | CHECK_RANKS_AND_SIZES(Slice); |
| 695 | CHECK_RANKS_AND_SIZES(Tile); |
| 696 | CHECK_RANKS_AND_SIZES(Transpose); |
| 697 | // Type Conversion |
| 698 | CHECK_RANKS_AND_SIZES(Cast); |
| 699 | CHECK_RANKS_AND_SIZES(Rescale); |
| 700 | // Control Flow Operators |
| 701 | CHECK_RANKS_AND_SIZES(If); |
| 702 | // Variable Operators |
| 703 | CHECK_RANKS_AND_SIZES(Variable); |
| 704 | CHECK_RANKS_AND_SIZES(VariableWrite); |
| 705 | CHECK_RANKS_AND_SIZES(VariableRead); |
| 706 | // Data Nodes |
| 707 | CHECK_RANKS_AND_SIZES(Const); |
| 708 | CHECK_RANKS_AND_SIZES(Identity); |
| 709 | |
| 710 | // For the following operators, check whether the size of each tensor |
| 711 | // operand is valid in a given Level. |
| 712 | |
| 713 | // Tensor Operators |
| 714 | CHECK_SIZES(AvgPool2d); |
| 715 | CHECK_SIZES(Conv2D); |
| 716 | CHECK_SIZES(Conv3D); |
| 717 | CHECK_SIZES(DepthwiseConv2D); |
| 718 | CHECK_SIZES(TransposeConv2D); |
| 719 | CHECK_SIZES(FFT2d); |
| 720 | CHECK_SIZES(MatMul); |
| 721 | CHECK_SIZES(MaxPool2d); |
| 722 | CHECK_SIZES(RFFT2d); |
| 723 | // Scatter/Gather Operators |
| 724 | CHECK_SIZES(Gather); |
| 725 | CHECK_SIZES(Scatter); |
| 726 | // Image Operators |
| 727 | CHECK_SIZES(Resize); |
| 728 | // Custom Operators |
| 729 | CHECK_SIZES(Custom); |
| 730 | // Control Flow Operators |
| 731 | CHECK_SIZES(While); |
| 732 | // Shape Operators |
| 733 | CHECK_SIZES(ConstShape); |
| 734 | |
| 735 | #undef CHECK_RANKS_AND_SIZES |
| 736 | #undef CHECK_SIZES |
| 737 | return true; |
| 738 | } |
| 739 | |
| 740 | // Perform the Level tensor size check on the tensor type. |
| 741 | bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, |
| 742 | const StringRef operandOrResult) { |
| 743 | if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) { |
| 744 | if (!type.hasRank()) { |
| 745 | op->emitOpError() << "failed level check: unranked tensor" ; |
| 746 | return false; |
| 747 | } |
| 748 | auto shape = type.getShape(); |
| 749 | for (auto dim : shape) { |
| 750 | if (mlir::ShapedType::isDynamic(dim)) { |
| 751 | op->emitOpError() << "failed level check: " << operandOrResult |
| 752 | << " shape dimension cannot be dynamic" ; |
| 753 | return false; |
| 754 | } |
| 755 | } |
| 756 | |
| 757 | int64_t element_bits = type.getElementTypeBitWidth(); |
| 758 | int64_t element_bytes = std::max(INT64_C(1), element_bits / 8); |
| 759 | int64_t size = element_bytes * type.getNumElements(); |
| 760 | |
| 761 | // According to 1.11. Tensor Definitions of Tosa spec, the value of |
| 762 | // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is |
| 763 | // defined in 1.7. Levels. |
| 764 | // For each tensor, the number of tensor elements multiplied by the |
| 765 | // element size in bytes must be representable as a tensor_size_t. |
| 766 | const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; |
| 767 | if (size > max_size) { |
| 768 | op->emitOpError() |
| 769 | << "failed level check: " << operandOrResult |
| 770 | << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)" ; |
| 771 | return false; |
| 772 | } |
| 773 | } |
| 774 | return true; |
| 775 | } |
| 776 | |
| 777 | LogicalResult TosaValidation::applyLevelCheck(Operation *op) { |
| 778 | if (tosaLevel == TOSA_LEVEL_NONE) { |
| 779 | // no need to do level checks |
| 780 | return success(); |
| 781 | } |
| 782 | |
| 783 | // check rank and sizes early so later checks can assume shaped operands |
| 784 | if (!levelCheckRanksAndSizes(op)) |
| 785 | return failure(); |
| 786 | |
| 787 | // additional level checks from spec 0.70 |
| 788 | if (!levelCheckPool<tosa::AvgPool2dOp>(op) || |
| 789 | !levelCheckConv<tosa::Conv2DOp>(op) || |
| 790 | !levelCheckConv<tosa::Conv3DOp>(op) || |
| 791 | !levelCheckConv<tosa::DepthwiseConv2DOp>(op) || |
| 792 | !levelCheckFFT<tosa::FFT2dOp>(op) || |
| 793 | !levelCheckPool<tosa::MaxPool2dOp>(op) || |
| 794 | !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) || |
| 795 | !levelCheckResize(op)) { |
| 796 | return failure(); |
| 797 | } |
| 798 | |
| 799 | // level check MAX_TENSOR_LIST_SIZE |
| 800 | if (!levelCheckListSize(op)) { |
| 801 | return failure(); |
| 802 | } |
| 803 | |
| 804 | if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) { |
| 805 | if (!levelCheckMaxNesting(op)) { |
| 806 | return failure(); |
| 807 | } |
| 808 | } |
| 809 | |
| 810 | return success(); |
| 811 | } |
| 812 | |
| 813 | LogicalResult TosaValidation::applyAttributeCheck(Operation *op) { |
| 814 | if (!attributeCheckRescale(op)) |
| 815 | return failure(); |
| 816 | return success(); |
| 817 | } |
| 818 | |
| 819 | inline bool CompatibleTypes(const mlir::Type &type, |
| 820 | const mlir::Type &declaredType) { |
| 821 | // for now, simply use type equality comparison |
| 822 | return type == declaredType; |
| 823 | } |
| 824 | |
| 825 | bool TosaValidation::CheckVariable(Operation *op) { |
| 826 | if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) { |
| 827 | mlir::StringAttr nameAttr = variableOp.getNameAttr(); |
| 828 | |
| 829 | if (variablesMap.count(nameAttr)) { |
| 830 | op->emitOpError() << "name has already been declared" ; |
| 831 | return false; |
| 832 | } |
| 833 | |
| 834 | auto elementType = variableOp.getType(); |
| 835 | DenseIntElementsAttr varShapeAttr = variableOp.getVarShape(); |
| 836 | SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>()); |
| 837 | RankedTensorType variableType = |
| 838 | RankedTensorType::get(ArrayRef<int64_t>(shape), elementType); |
| 839 | |
| 840 | variablesMap[nameAttr] = variableType; |
| 841 | } |
| 842 | |
| 843 | return true; |
| 844 | } |
| 845 | |
| 846 | bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { |
| 847 | if (isa<mlir::tosa::VariableReadOp>(op) || |
| 848 | isa<mlir::tosa::VariableWriteOp>(op)) { |
| 849 | mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name" )); |
| 850 | if (!variablesMap.count(nameAttr)) { |
| 851 | op->emitOpError() << "name has not been declared" ; |
| 852 | return false; |
| 853 | } |
| 854 | |
| 855 | auto varType = variablesMap[nameAttr]; |
| 856 | |
| 857 | for (auto v : op->getOperands()) { |
| 858 | auto type = v.getType(); |
| 859 | if (!CompatibleTypes(type, varType)) { |
| 860 | op->emitOpError() << "operand type does not equal variable type" ; |
| 861 | return false; |
| 862 | } |
| 863 | } |
| 864 | |
| 865 | for (auto v : op->getResults()) { |
| 866 | auto type = v.getType(); |
| 867 | if (!CompatibleTypes(type, varType)) { |
| 868 | op->emitOpError() << "result type does not equal variable type" ; |
| 869 | return false; |
| 870 | } |
| 871 | } |
| 872 | } |
| 873 | |
| 874 | return true; |
| 875 | } |
| 876 | |
| 877 | LogicalResult TosaValidation::applyVariableCheck(Operation *op) { |
| 878 | if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { |
| 879 | return failure(); |
| 880 | } |
| 881 | return success(); |
| 882 | } |
| 883 | |
| 884 | bool checkErrorIfResize(Operation *op) { |
| 885 | auto resize = dyn_cast<tosa::ResizeOp>(op); |
| 886 | if (!resize) |
| 887 | return true; |
| 888 | |
| 889 | const Value input = resize.getInput(); |
| 890 | const Value output = resize.getOutput(); |
| 891 | const RankedTensorType inputType = |
| 892 | llvm::dyn_cast<RankedTensorType>(input.getType()); |
| 893 | const RankedTensorType outputType = |
| 894 | llvm::dyn_cast<RankedTensorType>(output.getType()); |
| 895 | |
| 896 | if (!inputType || !outputType) { |
| 897 | op->emitOpError(message: "expect ranked input/output tensor" ); |
| 898 | return false; |
| 899 | } |
| 900 | |
| 901 | // Ensure the image size is supported by GPU APIs and that for integer |
| 902 | // implementations, position * stride does not overflow int32_t. |
| 903 | if (inputType.hasStaticShape() && outputType.hasStaticShape()) { |
| 904 | const SmallVector<int64_t, 4> sizes = { |
| 905 | outputType.getDimSize(1), outputType.getDimSize(2), |
| 906 | inputType.getDimSize(1), inputType.getDimSize(2)}; |
| 907 | const int64_t *maxDim = llvm::max_element(sizes); |
| 908 | if (maxDim != sizes.end() && *maxDim >= 16384) { |
| 909 | op->emitOpError(message: "expect input/output height/width dims to be < 16384, " ) |
| 910 | << "got [OH, OW, IH, IW] = " << sizes; |
| 911 | return false; |
| 912 | } |
| 913 | } |
| 914 | |
| 915 | SmallVector<int64_t> scale; |
| 916 | if (!tosa::getConstShapeValues(op: resize.getScale().getDefiningOp(), result_shape&: scale)) { |
| 917 | return false; |
| 918 | } |
| 919 | |
| 920 | const int64_t scaleYN = scale[0]; |
| 921 | const int64_t scaleYD = scale[1]; |
| 922 | const int64_t scaleXN = scale[2]; |
| 923 | const int64_t scaleXD = scale[3]; |
| 924 | |
| 925 | // Ensure scale values don't overflow int32 accumulator |
| 926 | if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) { |
| 927 | op->emitOpError(message: "expect all scale numerator values to be <= (1 << 11), " |
| 928 | "got scale_y_n=" ) |
| 929 | << scaleYN << ", scale_x_n=" << scaleXN; |
| 930 | return false; |
| 931 | } |
| 932 | |
| 933 | if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) { |
| 934 | op->emitOpError(message: "expect a downscale ratio larger than 1/16, got y=" ) |
| 935 | << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD; |
| 936 | return false; |
| 937 | } |
| 938 | |
| 939 | SmallVector<int64_t> offset; |
| 940 | SmallVector<int64_t> border; |
| 941 | if (!tosa::getConstShapeValues(op: resize.getOffset().getDefiningOp(), result_shape&: offset) || |
| 942 | !tosa::getConstShapeValues(op: resize.getBorder().getDefiningOp(), result_shape&: border)) { |
| 943 | return false; |
| 944 | } |
| 945 | |
| 946 | const int64_t offsetY = offset[0]; |
| 947 | const int64_t offsetX = offset[1]; |
| 948 | // Set a consistent lower limit of 1/16 downscale to simplify |
| 949 | // implementations |
| 950 | if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) { |
| 951 | op->emitOpError( |
| 952 | message: "expect offsetY / scaleYNumerator to be in range [-1, 16), got " ) |
| 953 | << offsetY << "/" << scaleYN; |
| 954 | return false; |
| 955 | } |
| 956 | if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) { |
| 957 | op->emitOpError( |
| 958 | message: "expect offsetX / scaleXNumerator to be in range [-1, 16), got " ) |
| 959 | << offsetX << "/" << scaleXN; |
| 960 | return false; |
| 961 | } |
| 962 | |
| 963 | const int64_t borderY = border[0]; |
| 964 | const int64_t borderX = border[1]; |
| 965 | if (borderY < -16 * scaleYN || borderY >= scaleYN) { |
| 966 | op->emitOpError( |
| 967 | message: "expect borderY / scaleYNumerator to be in range [-16, 1), got " ) |
| 968 | << borderY << "/" << scaleYN; |
| 969 | return false; |
| 970 | } |
| 971 | if (borderX < -16 * scaleXN || borderX >= scaleXN) { |
| 972 | op->emitOpError( |
| 973 | message: "expect borderX / scaleXNumerator to be in range [-16, 1), got " ) |
| 974 | << borderX << "/" << scaleXN; |
| 975 | return false; |
| 976 | } |
| 977 | |
| 978 | // The following section of code is mostly duplicated with ResizeOp::verify(). |
| 979 | // |
| 980 | // In TOSA specification, we do not support broadcast behavior. |
| 981 | // However, there is a rewrite pattern to materialize broadcast ResizeOp. |
| 982 | // It makes invalid TOSA ResizeOp into valid one. To avoid breaking |
| 983 | // existing code, we keep the rewrite pattern untouched. So, we need |
| 984 | // loose the checking in ResizeOp::verify() to support broadcast ResizeOp. |
| 985 | // |
| 986 | // Here is a strict checking to conform TOSA specification. |
| 987 | // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed. |
| 988 | auto idivCheck = [](const int64_t lhs, |
| 989 | const int64_t rhs) -> std::optional<int64_t> { |
| 990 | if (lhs % rhs != 0) |
| 991 | return std::nullopt; |
| 992 | return lhs / rhs; |
| 993 | }; |
| 994 | |
| 995 | const int64_t oh = outputType.getDimSize(1); |
| 996 | const int64_t ow = outputType.getDimSize(2); |
| 997 | const int64_t ih = inputType.getDimSize(1); |
| 998 | const int64_t iw = inputType.getDimSize(2); |
| 999 | |
| 1000 | if (ih != ShapedType::kDynamic) { |
| 1001 | const std::optional<int64_t> calculatedOutHeightMinusOne = |
| 1002 | idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD); |
| 1003 | if (!calculatedOutHeightMinusOne.has_value()) { |
| 1004 | op->emitOpError(message: "expected (input_height - 1) * scale_y_n - offset_y + " |
| 1005 | "border_y " ) |
| 1006 | << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * " |
| 1007 | << scaleYN << " - " << offsetY << " + " << borderY << ") / " |
| 1008 | << scaleYD; |
| 1009 | return false; |
| 1010 | } |
| 1011 | const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1; |
| 1012 | if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) { |
| 1013 | op->emitOpError(message: "calculated output height did not match expected: " ) |
| 1014 | << "calculated=" << calculatedOutHeight << ", expected=" << oh; |
| 1015 | return false; |
| 1016 | } |
| 1017 | } |
| 1018 | |
| 1019 | if (iw != ShapedType::kDynamic) { |
| 1020 | const std::optional<int64_t> calculatedOutWidthMinusOne = |
| 1021 | idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD); |
| 1022 | if (!calculatedOutWidthMinusOne.has_value()) { |
| 1023 | op->emitOpError(message: "expected (input_width - 1) * scale_x_n - offset_x + " |
| 1024 | "border_x " ) |
| 1025 | << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * " |
| 1026 | << scaleXN << " - " << offsetX << " + " << borderX << ") / " |
| 1027 | << scaleXD; |
| 1028 | return false; |
| 1029 | } |
| 1030 | const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1; |
| 1031 | if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) { |
| 1032 | op->emitOpError(message: "calculated output width did not match expected: " ) |
| 1033 | << "calculated=" << calculatedOutWidth << ", expected=" << ow; |
| 1034 | return false; |
| 1035 | } |
| 1036 | } |
| 1037 | |
| 1038 | return true; |
| 1039 | } |
| 1040 | |
| 1041 | bool checkErrorIfMul(Operation *op) { |
| 1042 | auto mul = dyn_cast<tosa::MulOp>(op); |
| 1043 | if (!mul) |
| 1044 | return true; |
| 1045 | |
| 1046 | // REQUIRE(0 <= shift && shift <= 63); |
| 1047 | // REQUIRE(is_same<in_t,int32_t>() || shift == 0); |
| 1048 | ElementsAttr shift_elem; |
| 1049 | if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) { |
| 1050 | return true; |
| 1051 | } |
| 1052 | int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt(); |
| 1053 | auto inputElemType = getElementTypeOrSelf(mul.getInput1()); |
| 1054 | if (inputElemType.isInteger(32)) { |
| 1055 | // 0 <= shift <= 63 for int32_t type |
| 1056 | if (shift < 0 || shift > 63) { |
| 1057 | op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: " |
| 1058 | << shift; |
| 1059 | return false; |
| 1060 | } |
| 1061 | } else { |
| 1062 | // shift must be 0 for all other types |
| 1063 | if (shift != 0) { |
| 1064 | op->emitOpError() << "requires shift = 0 for all input data types that " |
| 1065 | "are not int32_t, but got: " |
| 1066 | << shift; |
| 1067 | return false; |
| 1068 | } |
| 1069 | } |
| 1070 | |
| 1071 | return true; |
| 1072 | } |
| 1073 | |
| 1074 | bool checkErrorIfTable(Operation *op) { |
| 1075 | auto table = dyn_cast<tosa::TableOp>(op); |
| 1076 | if (!table) |
| 1077 | return true; |
| 1078 | |
| 1079 | // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513 |
| 1080 | const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType()); |
| 1081 | const int tableSize = inputElemType.isInteger(8) ? 256 : 513; |
| 1082 | |
| 1083 | const ShapeAdaptor tableShape(table.getTable().getType()); |
| 1084 | if (tableShape.hasStaticShape()) { |
| 1085 | const auto numElements = tableShape.getNumElements(); |
| 1086 | if (numElements != tableSize) { |
| 1087 | op->emitOpError() << "requires table size of " << tableSize << ", got " |
| 1088 | << numElements; |
| 1089 | return false; |
| 1090 | } |
| 1091 | } |
| 1092 | |
| 1093 | return true; |
| 1094 | } |
| 1095 | |
| 1096 | bool checkErrorIfRescale(Operation *op) { |
| 1097 | auto rescale = dyn_cast<tosa::RescaleOp>(op); |
| 1098 | if (!rescale) |
| 1099 | return true; |
| 1100 | |
| 1101 | auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType()); |
| 1102 | auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType()); |
| 1103 | if (!inputType || !outputType || !inputType.getElementType().isInteger() || |
| 1104 | !outputType.getElementType().isInteger()) |
| 1105 | return true; |
| 1106 | |
| 1107 | auto inElemType = inputType.getElementType(); |
| 1108 | auto outElemType = outputType.getElementType(); |
| 1109 | auto inWidth = inElemType.getIntOrFloatBitWidth(); |
| 1110 | auto outWidth = outElemType.getIntOrFloatBitWidth(); |
| 1111 | |
| 1112 | bool inputUnsigned = rescale.getInputUnsigned(); |
| 1113 | bool outputUnsigned = rescale.getOutputUnsigned(); |
| 1114 | |
| 1115 | bool scale32 = rescale.getScale32(); |
| 1116 | auto roundingMode = rescale.getRoundingMode(); |
| 1117 | |
| 1118 | // ERROR_IF(scale32 && is_same<in_t,i48_t>()) |
| 1119 | if (scale32 && inWidth == 48) { |
| 1120 | op->emitOpError() << "scale32 is not allowed with 48-bit input." ; |
| 1121 | return false; |
| 1122 | } |
| 1123 | |
| 1124 | // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND)) |
| 1125 | if (!scale32 && roundingMode == "DOUBLE_ROUND" ) { |
| 1126 | op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true." ; |
| 1127 | return false; |
| 1128 | } |
| 1129 | |
| 1130 | // ERROR_IF(input_unsigned && output_unsigned) |
| 1131 | if (inputUnsigned && outputUnsigned) { |
| 1132 | op->emitOpError() << "input and output cannot be both unsigned." ; |
| 1133 | return false; |
| 1134 | } |
| 1135 | |
| 1136 | // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned) |
| 1137 | if (outWidth == 32 && inputUnsigned) { |
| 1138 | op->emitOpError() << "i32 output type is not allowed with unsigned input." ; |
| 1139 | return false; |
| 1140 | } |
| 1141 | |
| 1142 | // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned) |
| 1143 | if (inWidth == 32 && outputUnsigned) { |
| 1144 | op->emitOpError() << "i32 input type is not allowed with unsigned output." ; |
| 1145 | return false; |
| 1146 | } |
| 1147 | |
| 1148 | // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned) |
| 1149 | if (inWidth == 48 && outputUnsigned) { |
| 1150 | op->emitOpError() << "i48 input type is not allowed with unsigned output." ; |
| 1151 | return false; |
| 1152 | } |
| 1153 | |
| 1154 | // ERROR_IF(is_same<in_t, i48_t> && input_unsigned) |
| 1155 | if (inWidth == 48 && inputUnsigned) { |
| 1156 | op->emitOpError() << "i48 input type cannot be unsigned." ; |
| 1157 | return false; |
| 1158 | } |
| 1159 | |
| 1160 | // ERROR_IF(is_same<in_t, i32_t> && input_unsigned) |
| 1161 | if (inWidth == 32 && inputUnsigned) { |
| 1162 | op->emitOpError() << "i32 input type cannot be unsigned." ; |
| 1163 | return false; |
| 1164 | } |
| 1165 | |
| 1166 | // ERROR_IF(is_same<out_t, i32_t> && output_unsigned) |
| 1167 | if (outWidth == 32 && outputUnsigned) { |
| 1168 | op->emitOpError() << "i32 output type cannot be unsigned." ; |
| 1169 | return false; |
| 1170 | } |
| 1171 | |
| 1172 | return true; |
| 1173 | } |
| 1174 | |
| 1175 | bool checkErrorIfPad(Operation *op) { |
| 1176 | auto pad = dyn_cast<tosa::PadOp>(op); |
| 1177 | if (!pad) |
| 1178 | return true; |
| 1179 | |
| 1180 | DenseIntElementsAttr paddingAttr; |
| 1181 | if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr))) |
| 1182 | // Pad verifier will catch this |
| 1183 | return true; |
| 1184 | |
| 1185 | for (const APInt &val : paddingAttr.getValues<APInt>()) { |
| 1186 | if (val.getSExtValue() < 0) { |
| 1187 | op->emitOpError() << "padding value must all be non-negative, got " |
| 1188 | << val.getSExtValue(); |
| 1189 | return false; |
| 1190 | } |
| 1191 | } |
| 1192 | |
| 1193 | return true; |
| 1194 | } |
| 1195 | |
| 1196 | // Returns true if the operation takes no input operands, excluding attributes. |
| 1197 | static bool isNullaryOperation(Operation *op) { |
| 1198 | if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) || |
| 1199 | isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op)) |
| 1200 | return true; |
| 1201 | return false; |
| 1202 | } |
| 1203 | |
| 1204 | bool checkErrorIfCondIf(Operation *op) { |
| 1205 | auto ifOp = dyn_cast<tosa::IfOp>(op); |
| 1206 | if (!ifOp) |
| 1207 | return true; |
| 1208 | |
| 1209 | // Whether the types and shapes of operands between the input/output list and |
| 1210 | // internal regions are validated by the operation verifier. However, with |
| 1211 | // support for the simplified form - where redundant operand notations are |
| 1212 | // omitted - is not conformant to the specification. According to the |
| 1213 | // specification, all operands passed into an operation must be explicitly |
| 1214 | // declared at each operation's structure. This code section verify that the |
| 1215 | // operation's form complies with this requirement. |
| 1216 | |
| 1217 | // Returns true if the region uses no external input operands. |
| 1218 | auto isNullaryRegion = [](Region ®ion) -> bool { |
| 1219 | bool noLiveInValue = true; |
| 1220 | region.walk([&noLiveInValue](Operation *op) { |
| 1221 | if (!isNullaryOperation(op)) { |
| 1222 | noLiveInValue = false; |
| 1223 | return WalkResult::interrupt(); |
| 1224 | } |
| 1225 | return WalkResult::advance(); |
| 1226 | }); |
| 1227 | return noLiveInValue; |
| 1228 | }; |
| 1229 | |
| 1230 | mlir::Region &thenGraph = ifOp.getThenGraph(); |
| 1231 | mlir::Region &elseGraph = ifOp.getElseGraph(); |
| 1232 | bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph); |
| 1233 | bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph); |
| 1234 | bool isInputListEmpty = ifOp.getInputList().size() == 0; |
| 1235 | |
| 1236 | if ((isInputListEmpty != isThenGraphNullaryRegion) || |
| 1237 | (isInputListEmpty != isElseGraphNullaryRegion)) { |
| 1238 | op->emitOpError() |
| 1239 | << "the current simplified form is not strictly conformant to the " |
| 1240 | "spec, please use the generic format\n" ; |
| 1241 | return false; |
| 1242 | } |
| 1243 | |
| 1244 | return true; |
| 1245 | } |
| 1246 | |
| 1247 | LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { |
| 1248 | if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || |
| 1249 | !checkErrorIfTable(op) || !checkErrorIfRescale(op) || |
| 1250 | !checkErrorIfPad(op) || !checkErrorIfCondIf(op)) |
| 1251 | return failure(); |
| 1252 | return success(); |
| 1253 | } |
| 1254 | |
| 1255 | bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { |
| 1256 | if (isa<FloatType>(type)) { |
| 1257 | return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, |
| 1258 | Float8E5M2Type>(type); |
| 1259 | } else if (auto intTy = dyn_cast<IntegerType>(type)) { |
| 1260 | if (intTy.isSignless()) { |
| 1261 | switch (intTy.getWidth()) { |
| 1262 | case 1: |
| 1263 | case 4: |
| 1264 | case 8: |
| 1265 | case 16: |
| 1266 | case 32: |
| 1267 | case 48: |
| 1268 | return true; |
| 1269 | } |
| 1270 | } else if (allowUnsigned && intTy.isUnsigned()) { |
| 1271 | switch (intTy.getWidth()) { |
| 1272 | case 8: |
| 1273 | case 16: |
| 1274 | case 32: |
| 1275 | return true; |
| 1276 | } |
| 1277 | } |
| 1278 | } else if (mlir::isa<tosa::shapeType>(type)) { |
| 1279 | return true; |
| 1280 | } |
| 1281 | return false; |
| 1282 | } |
| 1283 | |
| 1284 | void TosaValidation::runOnOperation() { |
| 1285 | configLevelAndProfile(); |
| 1286 | |
| 1287 | TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); |
| 1288 | if (!tosaDialect) |
| 1289 | return; |
| 1290 | |
| 1291 | getOperation().walk([&](Operation *op) { |
| 1292 | if (op->getDialect() != tosaDialect) |
| 1293 | return; |
| 1294 | |
| 1295 | // validate operator element types: |
| 1296 | // - rescale operator is allowed to have ui8/ui16/ui32 |
| 1297 | // operands/results |
| 1298 | // - perform valid element type check at the beginning to |
| 1299 | // protect rest of code against quantized element types |
| 1300 | const bool opIsRescale = isa<tosa::RescaleOp>(op); |
| 1301 | for (Value operand : op->getOperands()) { |
| 1302 | auto elementTy = getElementTypeOrSelf(val: operand); |
| 1303 | if (!isValidElementType(type: elementTy, allowUnsigned: opIsRescale)) { |
| 1304 | op->emitOpError() << "is not profile-aligned: element type " |
| 1305 | << elementTy << " is not legal" ; |
| 1306 | return signalPassFailure(); |
| 1307 | } |
| 1308 | } |
| 1309 | for (Type resultTy : op->getResultTypes()) { |
| 1310 | auto elementTy = getElementTypeOrSelf(resultTy); |
| 1311 | if (!isValidElementType(elementTy, opIsRescale)) { |
| 1312 | op->emitOpError() << "is not profile-aligned: element type " |
| 1313 | << elementTy << " is not legal" ; |
| 1314 | return signalPassFailure(); |
| 1315 | } |
| 1316 | } |
| 1317 | |
| 1318 | if (strictOpSpecAlignment && |
| 1319 | failed(profileComp.checkProfile(op, targetEnv))) |
| 1320 | return signalPassFailure(); |
| 1321 | |
| 1322 | if (strictOpSpecAlignment && |
| 1323 | failed(profileComp.checkExtension(op, targetEnv))) |
| 1324 | return signalPassFailure(); |
| 1325 | |
| 1326 | if (!allowInvalidOpDatatypeCombinations && |
| 1327 | failed(profileComp.checkInvalid(op))) |
| 1328 | return signalPassFailure(); |
| 1329 | |
| 1330 | // Some uses of TOSA rely on the constant operands of particular |
| 1331 | // operations. |
| 1332 | if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op))) |
| 1333 | signalPassFailure(); |
| 1334 | |
| 1335 | // do level checks |
| 1336 | if (failed(Result: applyLevelCheck(op))) |
| 1337 | signalPassFailure(); |
| 1338 | |
| 1339 | // check additional attribute restrictions |
| 1340 | if (failed(Result: applyAttributeCheck(op))) |
| 1341 | signalPassFailure(); |
| 1342 | |
| 1343 | // do variable type checks |
| 1344 | if (failed(Result: applyVariableCheck(op))) |
| 1345 | signalPassFailure(); |
| 1346 | |
| 1347 | // do error if checks |
| 1348 | if (strictOpSpecAlignment && failed(applyErrorIfCheck(op))) |
| 1349 | signalPassFailure(); |
| 1350 | }); |
| 1351 | } |
| 1352 | } // namespace |
| 1353 | |