1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Validate whether TOSA dialect input conforms to the specification for the
// given requirements.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
15#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
16#include "mlir/Dialect/Tosa/Transforms/Passes.h"
17#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
18
19#include <string>
20
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/Tosa/IR/TosaOps.h"
23#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "llvm/ADT/StringExtras.h"
31
32namespace mlir {
33namespace tosa {
34#define GEN_PASS_DEF_TOSAVALIDATION
35#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
36} // namespace tosa
37} // namespace mlir
38
39using namespace mlir;
40using namespace mlir::tosa;
41
42namespace {
43
44static LogicalResult
45checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
46 for (const auto index : operandIndices) {
47 Attribute attr;
48 if (!matchPattern(op->getOperand(idx: index), m_Constant(&attr))) {
49 return op->emitOpError(message: "expected compile time resolvable constant, but "
50 "got variable value for operand #")
51 << index;
52 }
53 }
54 return success();
55}
56
57static LogicalResult checkConstantOperandMul(Operation *op,
58 const TargetEnv &env) {
59 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60 // Check 'shift'
61 return checkConstantOperands(op, operandIndices: {2});
62 }
63 return success();
64}
65
66static LogicalResult checkConstantOperandTable(Operation *op,
67 const TargetEnv &env) {
68 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69 // Check 'table'
70 return checkConstantOperands(op, operandIndices: {1});
71 }
72 return success();
73}
74
75static LogicalResult checkConstantOperandPad(Operation *op,
76 const TargetEnv &env) {
77 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
78 // Assume this op is zero-padding if padConst is not presented
79 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
80 // Check 'pad_const'
81 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82 return checkConstantOperands(op, operandIndices: {2});
83 }
84 return success();
85}
86
87static LogicalResult checkConstantOperandRescale(Operation *op,
88 const TargetEnv &env) {
89 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91 return checkConstantOperands(op, {1, 2, 3, 4});
92 }
93 return success();
94}
95
96template <typename T>
97static LogicalResult checkConstantOperandConvOps(Operation *op,
98 const TargetEnv &env) {
99 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
100 // Check 'input_zp' and 'weight_zp'
101 return checkConstantOperands(op, {3, 4});
102 }
103 return success();
104}
105
106static LogicalResult checkConstantOperandMatMul(Operation *op,
107 const TargetEnv &env) {
108 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109 // Check 'A_zp' and 'B_zp'
110 return checkConstantOperands(op, {2, 3});
111 }
112 return success();
113}
114
115static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
116 const TargetEnv &env) {
117 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118 // Check 'input_zp' and 'output_zp'
119 return checkConstantOperands(op, {1, 2});
120 }
121 return success();
122}
123
124static LogicalResult checkConstantOperandNegate(Operation *op,
125 const TargetEnv &env) {
126 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127 // Check 'input1_zp' and 'output_zp'
128 return checkConstantOperands(op, {1, 2});
129 }
130 return success();
131}
132
// Upper bounds that a TOSA level imposes on operator parameters and tensor
// shapes. All fields default to 0; concrete level tables are defined as
// constexpr instances of this struct.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;
  int32_t MAX_LOG2_SIZE = 0;
  int32_t MAX_NESTING = 0;
  int32_t MAX_TENSOR_LIST_SIZE = 0;

  // Field-wise equality. Declared const (and constexpr) so that const and
  // constexpr level tables can appear on either side of the comparison.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }

  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};
150
151static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {.MAX_RANK: 6, .MAX_KERNEL: 8192, .MAX_STRIDE: 8192, .MAX_SCALE: 256, .MAX_LOG2_SIZE: 31, .MAX_NESTING: 6, .MAX_TENSOR_LIST_SIZE: 64};
152static constexpr TosaLevel TOSA_LEVEL_NONE = {.MAX_RANK: 32, .MAX_KERNEL: 2147483647, .MAX_STRIDE: 2147483647, .MAX_SCALE: 2048,
153 .MAX_LOG2_SIZE: 63, .MAX_NESTING: 256, .MAX_TENSOR_LIST_SIZE: 256};
154
155//===----------------------------------------------------------------------===//
156// TOSA Validation Pass.
157//===----------------------------------------------------------------------===//
158
159struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
160public:
161 explicit TosaValidation() { populateConstantOperandChecks(); }
162
163 explicit TosaValidation(const TosaValidationOptions &options)
164 : TosaValidation() {
165 this->profile = options.profile;
166 this->extension = options.extension;
167 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
168 this->allowInvalidOpDatatypeCombinations =
169 options.allowInvalidOpDatatypeCombinations;
170 this->level = options.level;
171 }
172 void runOnOperation() final;
173
174 LogicalResult applyConstantOperandCheck(Operation *op) {
175 for (auto &checker : constCheckers) {
176 if (failed(checker(op, targetEnv)))
177 return failure();
178 }
179 return success();
180 }
181
182 LogicalResult applyLevelCheck(Operation *op);
183 LogicalResult applyAttributeCheck(Operation *op);
184
185 // check variable read/write data types against variable declarations
186 LogicalResult applyVariableCheck(Operation *op);
187
188 // check error if conditions
189 LogicalResult applyErrorIfCheck(Operation *op);
190
191private:
192 void populateConstantOperandChecks() {
193 constCheckers.emplace_back(checkConstantOperandMul);
194 constCheckers.emplace_back(checkConstantOperandTable);
195 constCheckers.emplace_back(checkConstantOperandPad);
196 constCheckers.emplace_back(checkConstantOperandRescale);
197 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
198 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
199 constCheckers.emplace_back(
200 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
201 constCheckers.emplace_back(
202 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
203 constCheckers.emplace_back(checkConstantOperandMatMul);
204 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
205 constCheckers.emplace_back(checkConstantOperandNegate);
206 }
207
208 bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
209 if (v > tosaLevel.MAX_KERNEL) {
210 op->emitOpError() << "failed level check: " << checkDesc;
211 return false;
212 }
213 return true;
214 }
215
216 bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
217 if (v > tosaLevel.MAX_STRIDE) {
218 op->emitOpError() << "failed level check: " << checkDesc;
219 return false;
220 }
221 return true;
222 }
223
224 bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
225 if (v > tosaLevel.MAX_SCALE) {
226 op->emitOpError() << "failed level check: " << checkDesc;
227 return false;
228 }
229 return true;
230 }
231
232 bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
233 if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
234 op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
235 << checkDesc;
236 return false;
237 }
238 return true;
239 }
240
241 // Perform the Level Rank check on the tensor type.
242 bool levelCheckRank(Operation *op, const Type typeToCheck,
243 const StringRef operandOrResult, int32_t highest_rank) {
244 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
245 if (!type.hasRank()) {
246 op->emitOpError() << "failed level check: unranked tensor";
247 return false;
248 }
249 if (type.getRank() > highest_rank) {
250 op->emitOpError() << "failed level check: " << operandOrResult
251 << " rank(shape) <= MAX_RANK";
252 return false;
253 }
254 }
255 return true;
256 }
257
258 // Perform the Level Rank check on the tensor value.
259 bool levelCheckRank(Operation *op, const Value &v,
260 const StringRef operandOrResult, int32_t highest_rank) {
261 return levelCheckRank(op, typeToCheck: v.getType(), operandOrResult, highest_rank);
262 }
263
264 // Perform the Level tensor size check on the tensor type.
265 bool levelCheckSize(Operation *op, const Type &typeToCheck,
266 const StringRef operandOrResult);
267
268 // Perform the Level tensor size check on the tensor value.
269 bool levelCheckSize(Operation *op, const Value &v,
270 const StringRef operandOrResult) {
271 return levelCheckSize(op, typeToCheck: v.getType(), operandOrResult);
272 }
273
274 // Level check sizes of all operands and results of the operation.
275 template <typename T>
276 bool levelCheckSizes(T tosaOp) {
277 auto op = tosaOp.getOperation();
278 for (auto v : op->getOperands()) {
279 if (!levelCheckSize(op, v, "operand"))
280 return false;
281 }
282
283 for (auto v : op->getResults()) {
284 if (!levelCheckSize(op, v, "result"))
285 return false;
286 }
287 return true;
288 }
289
290 // Level check ranks of all operands, attribute and results of the operation.
291 template <typename T>
292 bool levelCheckRanks(T tosaOp) {
293 auto op = tosaOp.getOperation();
294 for (auto v : op->getOperands()) {
295 if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
296 return false;
297 }
298
299 for (auto v : op->getResults()) {
300 if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
301 return false;
302 }
303 return true;
304 }
305
306 // Level check ranks and sizes.
307 bool levelCheckRanksAndSizes(Operation *op);
308
309 // Pool Op: level check kernel/stride/pad values
310 template <typename T>
311 bool levelCheckPool(Operation *op) {
312 if (auto poolOp = dyn_cast<T>(op)) {
313 for (auto k : poolOp.getKernel()) {
314 if (!levelCheckKernel(op, v: k, checkDesc: "kernel <= MAX_KERNEL")) {
315 return false;
316 }
317 }
318 for (auto s : poolOp.getStride()) {
319 if (!levelCheckStride(op, v: s, checkDesc: "stride <= MAX_STRIDE")) {
320 return false;
321 }
322 }
323 for (auto p : poolOp.getPad()) {
324 if (!levelCheckKernel(op, v: p, checkDesc: "pad <= MAX_KERNEL")) {
325 return false;
326 }
327 }
328 }
329 return true;
330 }
331
332 // Conv Op: level check dilation/stride/pad values
333 template <typename T>
334 bool levelCheckConv(Operation *op) {
335 if (auto convOp = dyn_cast<T>(op)) {
336
337 for (auto k : convOp.getDilation()) {
338 if (!levelCheckKernel(op, v: k, checkDesc: "dilation <= MAX_KERNEL")) {
339 return false;
340 }
341 }
342 for (auto p : convOp.getPad()) {
343 if (!levelCheckKernel(op, v: p, checkDesc: "pad <= MAX_KERNEL")) {
344 return false;
345 }
346 }
347 for (auto s : convOp.getStride()) {
348 if (!levelCheckStride(op, v: s, checkDesc: "stride <= MAX_STRIDE")) {
349 return false;
350 }
351 }
352 auto dilation = convOp.getDilation();
353 if (ShapedType weightType =
354 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
355 auto shape = weightType.getShape();
356 if (isa<tosa::Conv2DOp>(op)) {
357 assert(shape.size() == 4);
358 assert(dilation.size() == 2);
359 if (!levelCheckKernel(op, v: dilation[0] * shape[1],
360 checkDesc: "dilation_y * KH <= MAX_KERNEL)") ||
361 !levelCheckKernel(op, v: dilation[1] * shape[2],
362 checkDesc: "dilation_x * KW <= MAX_KERNEL)"))
363 return false;
364 } else if (isa<tosa::Conv3DOp>(op)) {
365 assert(shape.size() == 5);
366 assert(dilation.size() == 3);
367 if (!levelCheckKernel(op, v: dilation[0] * shape[1],
368 checkDesc: "dilation_d * KD <= MAX_KERNEL)") ||
369 !levelCheckKernel(op, v: dilation[1] * shape[2],
370 checkDesc: "dilation_y * KH <= MAX_KERNEL)") ||
371 !levelCheckKernel(op, v: dilation[2] * shape[3],
372 checkDesc: "dilation_x * KW <= MAX_KERNEL)"))
373 return false;
374 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
375 assert(shape.size() == 4);
376 assert(dilation.size() == 2);
377 if (!levelCheckKernel(op, v: dilation[0] * shape[0],
378 checkDesc: "dilation_y * KH <= MAX_KERNEL)") ||
379 !levelCheckKernel(op, v: dilation[1] * shape[1],
380 checkDesc: "dilation_x * KW <= MAX_KERNEL)"))
381 return false;
382 }
383 }
384 }
385 return true;
386 }
387
388 // FFT op: level check H, W in input shape [N,H,W]
389 template <typename T>
390 bool levelCheckFFT(Operation *op) {
391 if (isa<T>(op)) {
392 for (auto v : op->getOperands()) {
393 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
394 auto shape = type.getShape();
395 assert(shape.size() == 3);
396 if (!levelCheckKernel(op, v: shape[1], checkDesc: "H <= MAX_KERNEL") ||
397 !levelCheckKernel(op, v: shape[2], checkDesc: "W <= MAX_KERNEL")) {
398 return false;
399 }
400 }
401 }
402 }
403 return true;
404 }
405
406 // TransposeConv2d op: level check kH/kW, outpad, and stride
407 bool levelCheckTransposeConv2d(Operation *op) {
408 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
409 if (ShapedType filterType =
410 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
411 auto shape = filterType.getShape();
412 assert(shape.size() == 4);
413 // level check kernel sizes for kH and KW
414 if (!levelCheckKernel(op, v: shape[1], checkDesc: "KH <= MAX_KERNEL") ||
415 !levelCheckKernel(op, v: shape[2], checkDesc: "KW <= MAX_KERNEL")) {
416 return false;
417 }
418 }
419 for (auto p : transpose.getOutPad()) {
420 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
421 return false;
422 }
423 }
424 for (auto s : transpose.getStride()) {
425 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
426 return false;
427 }
428 }
429 }
430 return true;
431 }
432
433 // Resize op: level check max scales
434 bool levelCheckResize(Operation *op) {
435 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
436 SmallVector<int64_t> scale;
437 if (!tosa::getConstShapeValues(op: resize.getScale().getDefiningOp(),
438 result_shape&: scale)) {
439 return false;
440 }
441 const int64_t scaleYN = scale[0];
442 const int64_t scaleYD = scale[1];
443 const int64_t scaleXN = scale[2];
444 const int64_t scaleXD = scale[3];
445 if (!levelCheckScale(op, v: scaleYN / scaleYD,
446 checkDesc: "scale_y_n/scale_y_d <= MAX_SCALE") ||
447 !levelCheckScale(op, v: scaleXN / scaleXD,
448 checkDesc: "scale_x_n/scale_x_d <= MAX_SCALE")) {
449 return false;
450 }
451 }
452 return true;
453 }
454
455 // Recursively perform a bottom-up search to determine the maximum nesting
456 // depth, starting from a specific operation and continuing up to the function
457 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
458 // time a new nested `region` is encountered.
459 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
460 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
461 return;
462
463 op = op->getParentOp();
464 if (!op)
465 return;
466
467 depth++;
468 getMaxNestedDepth(op, depth);
469 }
470
471 bool levelCheckMaxNesting(Operation *op) {
472 int32_t maxNestedDepth = 0;
473 getMaxNestedDepth(op, depth&: maxNestedDepth);
474
475 if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
476 op->emitOpError() << "failed level check: " << maxNestedDepth
477 << " >= MAX_NESTING";
478 return false;
479 }
480 return true;
481 }
482
483 bool levelCheckListSize(Operation *op) {
484 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
485 return levelCheckListSize(op, concat.getInput1().size(), "input1");
486 }
487 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
488 if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
489 !levelCheckListSize(op, custom.getOutputList().size(),
490 "output_list")) {
491 return false;
492 }
493 }
494 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
495 if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
496 !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
497 return false;
498 }
499 }
500 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
501 if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
502 !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
503 return false;
504 }
505 }
506 return true;
507 }
508
509 bool attributeCheckRescale(Operation *op) {
510 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
511 if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
512 !targetEnv.allows(Extension::doubleround)) {
513 op->emitOpError()
514 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
515 << "requires extension [doubleround]";
516 return false;
517 } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
518 !targetEnv.allows(Extension::inexactround)) {
519 op->emitOpError()
520 << "failed attribute check: rounding_mode = INEXACT_ROUND "
521 << "requires extension [inexactround]";
522 return false;
523 }
524 }
525 return true;
526 }
527
528 // configure profile and level values from pass options profileName and
529 // levelName
530 void configLevelAndProfile() {
531 tosaLevel = TOSA_LEVEL_NONE;
532 if (level == TosaLevelEnum::EightK) {
533 tosaLevel = TOSA_LEVEL_EIGHTK;
534 }
535
536 if (!profile.empty()) {
537 for (std::string &prof : profile) {
538 auto profSymbol = symbolizeProfile(prof);
539 if (profSymbol) {
540 targetEnv.addProfile(profSymbol.value());
541 } else {
542 llvm::errs() << "unknown TOSA profile name passed in: " << prof
543 << ", supported profiles are `pro_int` and `pro_fp`\n";
544 return signalPassFailure();
545 }
546 }
547 }
548
549 if (!extension.empty()) {
550 for (std::string &ext : extension) {
551 auto extSymbol = symbolizeExtension(ext);
552 if (extSymbol) {
553 targetEnv.addExtension(extSymbol.value());
554 } else {
555 llvm::errs() << "unknown TOSA extension name passed in: " << ext
556 << ", supported extension are int16, int4, bf16, "
557 << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
558 << "doubleround, inexactround and dynamic\n";
559 return signalPassFailure();
560 }
561 }
562 }
563 }
564
565 bool CheckVariable(Operation *op);
566 bool CheckVariableReadOrWrite(Operation *op);
567 bool isValidElementType(Type type, const bool allowUnsigned = false);
568
569 SmallVector<
570 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
571 constCheckers;
572 TosaLevel tosaLevel;
573 DenseMap<StringAttr, mlir::Type> variablesMap;
574 TosaProfileCompliance profileComp;
575 tosa::TargetEnv targetEnv;
576};
577
578template <>
579bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
580 auto op = tosaOp.getOperation();
581 if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
582 return false;
583
584 // rank(output) = rank(input) - 1
585 if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
586 return false;
587
588 return true;
589}
590
591template <>
592bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
593 auto op = tosaOp.getOperation();
594
595 // Only the condition input has rank limitation.
596 if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
597 return false;
598
599 return true;
600}
601
602template <>
603bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
604 auto op = tosaOp.getOperation();
605 auto variableType = getVariableType(tosaOp);
606 if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
607 return false;
608
609 return true;
610}
611
612template <>
613bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
614 auto op = tosaOp.getOperation();
615 auto variableType = getVariableType(tosaOp);
616 if (!levelCheckSize(op, variableType, "variable type"))
617 return false;
618
619 return true;
620}
621
622bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
623#define CHECK_RANKS_AND_SIZES(tosaOp) \
624 if (isa<tosa::tosaOp##Op>(op)) { \
625 if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \
626 return false; \
627 if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
628 return false; \
629 }
630
631#define CHECK_SIZES(tosaOp) \
632 if (isa<tosa::tosaOp##Op>(op)) { \
633 if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
634 return false; \
635 }
636
637 // Tensor Operators
638 CHECK_RANKS_AND_SIZES(ArgMax);
639 // Activation Functions
640 CHECK_RANKS_AND_SIZES(Clamp);
641 CHECK_RANKS_AND_SIZES(Erf);
642 CHECK_RANKS_AND_SIZES(Sigmoid);
643 CHECK_RANKS_AND_SIZES(Tanh);
644 // Elementwise Binary Operators
645 CHECK_RANKS_AND_SIZES(Add);
646 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
647 CHECK_RANKS_AND_SIZES(BitwiseAnd);
648 CHECK_RANKS_AND_SIZES(BitwiseOr);
649 CHECK_RANKS_AND_SIZES(BitwiseXor);
650 CHECK_RANKS_AND_SIZES(IntDiv);
651 CHECK_RANKS_AND_SIZES(LogicalAnd);
652 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
653 CHECK_RANKS_AND_SIZES(LogicalRightShift);
654 CHECK_RANKS_AND_SIZES(LogicalOr);
655 CHECK_RANKS_AND_SIZES(LogicalXor);
656 CHECK_RANKS_AND_SIZES(Maximum);
657 CHECK_RANKS_AND_SIZES(Minimum);
658 CHECK_RANKS_AND_SIZES(Mul);
659 CHECK_RANKS_AND_SIZES(Pow);
660 CHECK_RANKS_AND_SIZES(Sub);
661 CHECK_RANKS_AND_SIZES(Table);
662 // Elementwise Unary Operators
663 CHECK_RANKS_AND_SIZES(Abs);
664 CHECK_RANKS_AND_SIZES(BitwiseNot);
665 CHECK_RANKS_AND_SIZES(Ceil);
666 CHECK_RANKS_AND_SIZES(Clz);
667 CHECK_RANKS_AND_SIZES(Cos);
668 CHECK_RANKS_AND_SIZES(Exp);
669 CHECK_RANKS_AND_SIZES(Floor);
670 CHECK_RANKS_AND_SIZES(Log);
671 CHECK_RANKS_AND_SIZES(LogicalNot);
672 CHECK_RANKS_AND_SIZES(Negate);
673 CHECK_RANKS_AND_SIZES(Reciprocal);
674 CHECK_RANKS_AND_SIZES(Rsqrt);
675 CHECK_RANKS_AND_SIZES(Sin);
676 // Elementwise Ternary Operators
677 CHECK_RANKS_AND_SIZES(Select);
678 // Comparison Operators
679 CHECK_RANKS_AND_SIZES(Equal);
680 CHECK_RANKS_AND_SIZES(Greater);
681 CHECK_RANKS_AND_SIZES(GreaterEqual);
682 // Reduction Operators
683 CHECK_RANKS_AND_SIZES(ReduceAll);
684 CHECK_RANKS_AND_SIZES(ReduceAny);
685 CHECK_RANKS_AND_SIZES(ReduceMax);
686 CHECK_RANKS_AND_SIZES(ReduceMin);
687 CHECK_RANKS_AND_SIZES(ReduceProduct);
688 CHECK_RANKS_AND_SIZES(ReduceSum);
689 // Data Layout Operators
690 CHECK_RANKS_AND_SIZES(Concat);
691 CHECK_RANKS_AND_SIZES(Pad);
692 CHECK_RANKS_AND_SIZES(Reshape);
693 CHECK_RANKS_AND_SIZES(Reverse);
694 CHECK_RANKS_AND_SIZES(Slice);
695 CHECK_RANKS_AND_SIZES(Tile);
696 CHECK_RANKS_AND_SIZES(Transpose);
697 // Type Conversion
698 CHECK_RANKS_AND_SIZES(Cast);
699 CHECK_RANKS_AND_SIZES(Rescale);
700 // Control Flow Operators
701 CHECK_RANKS_AND_SIZES(If);
702 // Variable Operators
703 CHECK_RANKS_AND_SIZES(Variable);
704 CHECK_RANKS_AND_SIZES(VariableWrite);
705 CHECK_RANKS_AND_SIZES(VariableRead);
706 // Data Nodes
707 CHECK_RANKS_AND_SIZES(Const);
708 CHECK_RANKS_AND_SIZES(Identity);
709
710 // For the following operators, check whether the size of each tensor
711 // operand is valid in a given Level.
712
713 // Tensor Operators
714 CHECK_SIZES(AvgPool2d);
715 CHECK_SIZES(Conv2D);
716 CHECK_SIZES(Conv3D);
717 CHECK_SIZES(DepthwiseConv2D);
718 CHECK_SIZES(TransposeConv2D);
719 CHECK_SIZES(FFT2d);
720 CHECK_SIZES(MatMul);
721 CHECK_SIZES(MaxPool2d);
722 CHECK_SIZES(RFFT2d);
723 // Scatter/Gather Operators
724 CHECK_SIZES(Gather);
725 CHECK_SIZES(Scatter);
726 // Image Operators
727 CHECK_SIZES(Resize);
728 // Custom Operators
729 CHECK_SIZES(Custom);
730 // Control Flow Operators
731 CHECK_SIZES(While);
732 // Shape Operators
733 CHECK_SIZES(ConstShape);
734
735#undef CHECK_RANKS_AND_SIZES
736#undef CHECK_SIZES
737 return true;
738}
739
740// Perform the Level tensor size check on the tensor type.
741bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
742 const StringRef operandOrResult) {
743 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
744 if (!type.hasRank()) {
745 op->emitOpError() << "failed level check: unranked tensor";
746 return false;
747 }
748 auto shape = type.getShape();
749 for (auto dim : shape) {
750 if (mlir::ShapedType::isDynamic(dim)) {
751 op->emitOpError() << "failed level check: " << operandOrResult
752 << " shape dimension cannot be dynamic";
753 return false;
754 }
755 }
756
757 int64_t element_bits = type.getElementTypeBitWidth();
758 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
759 int64_t size = element_bytes * type.getNumElements();
760
761 // According to 1.11. Tensor Definitions of Tosa spec, the value of
762 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
763 // defined in 1.7. Levels.
764 // For each tensor, the number of tensor elements multiplied by the
765 // element size in bytes must be representable as a tensor_size_t.
766 const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
767 if (size > max_size) {
768 op->emitOpError()
769 << "failed level check: " << operandOrResult
770 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
771 return false;
772 }
773 }
774 return true;
775}
776
777LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
778 if (tosaLevel == TOSA_LEVEL_NONE) {
779 // no need to do level checks
780 return success();
781 }
782
783 // check rank and sizes early so later checks can assume shaped operands
784 if (!levelCheckRanksAndSizes(op))
785 return failure();
786
787 // additional level checks from spec 0.70
788 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
789 !levelCheckConv<tosa::Conv2DOp>(op) ||
790 !levelCheckConv<tosa::Conv3DOp>(op) ||
791 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
792 !levelCheckFFT<tosa::FFT2dOp>(op) ||
793 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
794 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
795 !levelCheckResize(op)) {
796 return failure();
797 }
798
799 // level check MAX_TENSOR_LIST_SIZE
800 if (!levelCheckListSize(op)) {
801 return failure();
802 }
803
804 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
805 if (!levelCheckMaxNesting(op)) {
806 return failure();
807 }
808 }
809
810 return success();
811}
812
813LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
814 if (!attributeCheckRescale(op))
815 return failure();
816 return success();
817}
818
819inline bool CompatibleTypes(const mlir::Type &type,
820 const mlir::Type &declaredType) {
821 // for now, simply use type equality comparison
822 return type == declaredType;
823}
824
825bool TosaValidation::CheckVariable(Operation *op) {
826 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
827 mlir::StringAttr nameAttr = variableOp.getNameAttr();
828
829 if (variablesMap.count(nameAttr)) {
830 op->emitOpError() << "name has already been declared";
831 return false;
832 }
833
834 auto elementType = variableOp.getType();
835 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
836 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
837 RankedTensorType variableType =
838 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
839
840 variablesMap[nameAttr] = variableType;
841 }
842
843 return true;
844}
845
846bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
847 if (isa<mlir::tosa::VariableReadOp>(op) ||
848 isa<mlir::tosa::VariableWriteOp>(op)) {
849 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
850 if (!variablesMap.count(nameAttr)) {
851 op->emitOpError() << "name has not been declared";
852 return false;
853 }
854
855 auto varType = variablesMap[nameAttr];
856
857 for (auto v : op->getOperands()) {
858 auto type = v.getType();
859 if (!CompatibleTypes(type, varType)) {
860 op->emitOpError() << "operand type does not equal variable type";
861 return false;
862 }
863 }
864
865 for (auto v : op->getResults()) {
866 auto type = v.getType();
867 if (!CompatibleTypes(type, varType)) {
868 op->emitOpError() << "result type does not equal variable type";
869 return false;
870 }
871 }
872 }
873
874 return true;
875}
876
877LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
878 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
879 return failure();
880 }
881 return success();
882}
883
884bool checkErrorIfResize(Operation *op) {
885 auto resize = dyn_cast<tosa::ResizeOp>(op);
886 if (!resize)
887 return true;
888
889 const Value input = resize.getInput();
890 const Value output = resize.getOutput();
891 const RankedTensorType inputType =
892 llvm::dyn_cast<RankedTensorType>(input.getType());
893 const RankedTensorType outputType =
894 llvm::dyn_cast<RankedTensorType>(output.getType());
895
896 if (!inputType || !outputType) {
897 op->emitOpError(message: "expect ranked input/output tensor");
898 return false;
899 }
900
901 // Ensure the image size is supported by GPU APIs and that for integer
902 // implementations, position * stride does not overflow int32_t.
903 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
904 const SmallVector<int64_t, 4> sizes = {
905 outputType.getDimSize(1), outputType.getDimSize(2),
906 inputType.getDimSize(1), inputType.getDimSize(2)};
907 const int64_t *maxDim = llvm::max_element(sizes);
908 if (maxDim != sizes.end() && *maxDim >= 16384) {
909 op->emitOpError(message: "expect input/output height/width dims to be < 16384, ")
910 << "got [OH, OW, IH, IW] = " << sizes;
911 return false;
912 }
913 }
914
915 SmallVector<int64_t> scale;
916 if (!tosa::getConstShapeValues(op: resize.getScale().getDefiningOp(), result_shape&: scale)) {
917 return false;
918 }
919
920 const int64_t scaleYN = scale[0];
921 const int64_t scaleYD = scale[1];
922 const int64_t scaleXN = scale[2];
923 const int64_t scaleXD = scale[3];
924
925 // Ensure scale values don't overflow int32 accumulator
926 if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
927 op->emitOpError(message: "expect all scale numerator values to be <= (1 << 11), "
928 "got scale_y_n=")
929 << scaleYN << ", scale_x_n=" << scaleXN;
930 return false;
931 }
932
933 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
934 op->emitOpError(message: "expect a downscale ratio larger than 1/16, got y=")
935 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
936 return false;
937 }
938
939 SmallVector<int64_t> offset;
940 SmallVector<int64_t> border;
941 if (!tosa::getConstShapeValues(op: resize.getOffset().getDefiningOp(), result_shape&: offset) ||
942 !tosa::getConstShapeValues(op: resize.getBorder().getDefiningOp(), result_shape&: border)) {
943 return false;
944 }
945
946 const int64_t offsetY = offset[0];
947 const int64_t offsetX = offset[1];
948 // Set a consistent lower limit of 1/16 downscale to simplify
949 // implementations
950 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
951 op->emitOpError(
952 message: "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
953 << offsetY << "/" << scaleYN;
954 return false;
955 }
956 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
957 op->emitOpError(
958 message: "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
959 << offsetX << "/" << scaleXN;
960 return false;
961 }
962
963 const int64_t borderY = border[0];
964 const int64_t borderX = border[1];
965 if (borderY < -16 * scaleYN || borderY >= scaleYN) {
966 op->emitOpError(
967 message: "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
968 << borderY << "/" << scaleYN;
969 return false;
970 }
971 if (borderX < -16 * scaleXN || borderX >= scaleXN) {
972 op->emitOpError(
973 message: "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
974 << borderX << "/" << scaleXN;
975 return false;
976 }
977
978 // The following section of code is mostly duplicated with ResizeOp::verify().
979 //
980 // In TOSA specification, we do not support broadcast behavior.
981 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
982 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
983 // existing code, we keep the rewrite pattern untouched. So, we need
984 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
985 //
986 // Here is a strict checking to conform TOSA specification.
987 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
988 auto idivCheck = [](const int64_t lhs,
989 const int64_t rhs) -> std::optional<int64_t> {
990 if (lhs % rhs != 0)
991 return std::nullopt;
992 return lhs / rhs;
993 };
994
995 const int64_t oh = outputType.getDimSize(1);
996 const int64_t ow = outputType.getDimSize(2);
997 const int64_t ih = inputType.getDimSize(1);
998 const int64_t iw = inputType.getDimSize(2);
999
1000 if (ih != ShapedType::kDynamic) {
1001 const std::optional<int64_t> calculatedOutHeightMinusOne =
1002 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1003 if (!calculatedOutHeightMinusOne.has_value()) {
1004 op->emitOpError(message: "expected (input_height - 1) * scale_y_n - offset_y + "
1005 "border_y ")
1006 << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
1007 << scaleYN << " - " << offsetY << " + " << borderY << ") / "
1008 << scaleYD;
1009 return false;
1010 }
1011 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1012 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
1013 op->emitOpError(message: "calculated output height did not match expected: ")
1014 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1015 return false;
1016 }
1017 }
1018
1019 if (iw != ShapedType::kDynamic) {
1020 const std::optional<int64_t> calculatedOutWidthMinusOne =
1021 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1022 if (!calculatedOutWidthMinusOne.has_value()) {
1023 op->emitOpError(message: "expected (input_width - 1) * scale_x_n - offset_x + "
1024 "border_x ")
1025 << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
1026 << scaleXN << " - " << offsetX << " + " << borderX << ") / "
1027 << scaleXD;
1028 return false;
1029 }
1030 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1031 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
1032 op->emitOpError(message: "calculated output width did not match expected: ")
1033 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1034 return false;
1035 }
1036 }
1037
1038 return true;
1039}
1040
1041bool checkErrorIfMul(Operation *op) {
1042 auto mul = dyn_cast<tosa::MulOp>(op);
1043 if (!mul)
1044 return true;
1045
1046 // REQUIRE(0 <= shift && shift <= 63);
1047 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1048 ElementsAttr shift_elem;
1049 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
1050 return true;
1051 }
1052 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1053 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1054 if (inputElemType.isInteger(32)) {
1055 // 0 <= shift <= 63 for int32_t type
1056 if (shift < 0 || shift > 63) {
1057 op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
1058 << shift;
1059 return false;
1060 }
1061 } else {
1062 // shift must be 0 for all other types
1063 if (shift != 0) {
1064 op->emitOpError() << "requires shift = 0 for all input data types that "
1065 "are not int32_t, but got: "
1066 << shift;
1067 return false;
1068 }
1069 }
1070
1071 return true;
1072}
1073
1074bool checkErrorIfTable(Operation *op) {
1075 auto table = dyn_cast<tosa::TableOp>(op);
1076 if (!table)
1077 return true;
1078
1079 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1080 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1081 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1082
1083 const ShapeAdaptor tableShape(table.getTable().getType());
1084 if (tableShape.hasStaticShape()) {
1085 const auto numElements = tableShape.getNumElements();
1086 if (numElements != tableSize) {
1087 op->emitOpError() << "requires table size of " << tableSize << ", got "
1088 << numElements;
1089 return false;
1090 }
1091 }
1092
1093 return true;
1094}
1095
1096bool checkErrorIfRescale(Operation *op) {
1097 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1098 if (!rescale)
1099 return true;
1100
1101 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1102 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1103 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1104 !outputType.getElementType().isInteger())
1105 return true;
1106
1107 auto inElemType = inputType.getElementType();
1108 auto outElemType = outputType.getElementType();
1109 auto inWidth = inElemType.getIntOrFloatBitWidth();
1110 auto outWidth = outElemType.getIntOrFloatBitWidth();
1111
1112 bool inputUnsigned = rescale.getInputUnsigned();
1113 bool outputUnsigned = rescale.getOutputUnsigned();
1114
1115 bool scale32 = rescale.getScale32();
1116 auto roundingMode = rescale.getRoundingMode();
1117
1118 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1119 if (scale32 && inWidth == 48) {
1120 op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1121 return false;
1122 }
1123
1124 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1125 if (!scale32 && roundingMode == "DOUBLE_ROUND") {
1126 op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
1127 return false;
1128 }
1129
1130 // ERROR_IF(input_unsigned && output_unsigned)
1131 if (inputUnsigned && outputUnsigned) {
1132 op->emitOpError() << "input and output cannot be both unsigned.";
1133 return false;
1134 }
1135
1136 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1137 if (outWidth == 32 && inputUnsigned) {
1138 op->emitOpError() << "i32 output type is not allowed with unsigned input.";
1139 return false;
1140 }
1141
1142 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1143 if (inWidth == 32 && outputUnsigned) {
1144 op->emitOpError() << "i32 input type is not allowed with unsigned output.";
1145 return false;
1146 }
1147
1148 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1149 if (inWidth == 48 && outputUnsigned) {
1150 op->emitOpError() << "i48 input type is not allowed with unsigned output.";
1151 return false;
1152 }
1153
1154 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1155 if (inWidth == 48 && inputUnsigned) {
1156 op->emitOpError() << "i48 input type cannot be unsigned.";
1157 return false;
1158 }
1159
1160 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1161 if (inWidth == 32 && inputUnsigned) {
1162 op->emitOpError() << "i32 input type cannot be unsigned.";
1163 return false;
1164 }
1165
1166 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1167 if (outWidth == 32 && outputUnsigned) {
1168 op->emitOpError() << "i32 output type cannot be unsigned.";
1169 return false;
1170 }
1171
1172 return true;
1173}
1174
1175bool checkErrorIfPad(Operation *op) {
1176 auto pad = dyn_cast<tosa::PadOp>(op);
1177 if (!pad)
1178 return true;
1179
1180 DenseIntElementsAttr paddingAttr;
1181 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1182 // Pad verifier will catch this
1183 return true;
1184
1185 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1186 if (val.getSExtValue() < 0) {
1187 op->emitOpError() << "padding value must all be non-negative, got "
1188 << val.getSExtValue();
1189 return false;
1190 }
1191 }
1192
1193 return true;
1194}
1195
1196// Returns true if the operation takes no input operands, excluding attributes.
1197static bool isNullaryOperation(Operation *op) {
1198 if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
1199 isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
1200 return true;
1201 return false;
1202}
1203
1204bool checkErrorIfCondIf(Operation *op) {
1205 auto ifOp = dyn_cast<tosa::IfOp>(op);
1206 if (!ifOp)
1207 return true;
1208
1209 // Whether the types and shapes of operands between the input/output list and
1210 // internal regions are validated by the operation verifier. However, with
1211 // support for the simplified form - where redundant operand notations are
1212 // omitted - is not conformant to the specification. According to the
1213 // specification, all operands passed into an operation must be explicitly
1214 // declared at each operation's structure. This code section verify that the
1215 // operation's form complies with this requirement.
1216
1217 // Returns true if the region uses no external input operands.
1218 auto isNullaryRegion = [](Region &region) -> bool {
1219 bool noLiveInValue = true;
1220 region.walk([&noLiveInValue](Operation *op) {
1221 if (!isNullaryOperation(op)) {
1222 noLiveInValue = false;
1223 return WalkResult::interrupt();
1224 }
1225 return WalkResult::advance();
1226 });
1227 return noLiveInValue;
1228 };
1229
1230 mlir::Region &thenGraph = ifOp.getThenGraph();
1231 mlir::Region &elseGraph = ifOp.getElseGraph();
1232 bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
1233 bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
1234 bool isInputListEmpty = ifOp.getInputList().size() == 0;
1235
1236 if ((isInputListEmpty != isThenGraphNullaryRegion) ||
1237 (isInputListEmpty != isElseGraphNullaryRegion)) {
1238 op->emitOpError()
1239 << "the current simplified form is not strictly conformant to the "
1240 "spec, please use the generic format\n";
1241 return false;
1242 }
1243
1244 return true;
1245}
1246
1247LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1248 if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1249 !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1250 !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
1251 return failure();
1252 return success();
1253}
1254
1255bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1256 if (isa<FloatType>(type)) {
1257 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1258 Float8E5M2Type>(type);
1259 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1260 if (intTy.isSignless()) {
1261 switch (intTy.getWidth()) {
1262 case 1:
1263 case 4:
1264 case 8:
1265 case 16:
1266 case 32:
1267 case 48:
1268 return true;
1269 }
1270 } else if (allowUnsigned && intTy.isUnsigned()) {
1271 switch (intTy.getWidth()) {
1272 case 8:
1273 case 16:
1274 case 32:
1275 return true;
1276 }
1277 }
1278 } else if (mlir::isa<tosa::shapeType>(type)) {
1279 return true;
1280 }
1281 return false;
1282}
1283
1284void TosaValidation::runOnOperation() {
1285 configLevelAndProfile();
1286
1287 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1288 if (!tosaDialect)
1289 return;
1290
1291 getOperation().walk([&](Operation *op) {
1292 if (op->getDialect() != tosaDialect)
1293 return;
1294
1295 // validate operator element types:
1296 // - rescale operator is allowed to have ui8/ui16/ui32
1297 // operands/results
1298 // - perform valid element type check at the beginning to
1299 // protect rest of code against quantized element types
1300 const bool opIsRescale = isa<tosa::RescaleOp>(op);
1301 for (Value operand : op->getOperands()) {
1302 auto elementTy = getElementTypeOrSelf(val: operand);
1303 if (!isValidElementType(type: elementTy, allowUnsigned: opIsRescale)) {
1304 op->emitOpError() << "is not profile-aligned: element type "
1305 << elementTy << " is not legal";
1306 return signalPassFailure();
1307 }
1308 }
1309 for (Type resultTy : op->getResultTypes()) {
1310 auto elementTy = getElementTypeOrSelf(resultTy);
1311 if (!isValidElementType(elementTy, opIsRescale)) {
1312 op->emitOpError() << "is not profile-aligned: element type "
1313 << elementTy << " is not legal";
1314 return signalPassFailure();
1315 }
1316 }
1317
1318 if (strictOpSpecAlignment &&
1319 failed(profileComp.checkProfile(op, targetEnv)))
1320 return signalPassFailure();
1321
1322 if (strictOpSpecAlignment &&
1323 failed(profileComp.checkExtension(op, targetEnv)))
1324 return signalPassFailure();
1325
1326 if (!allowInvalidOpDatatypeCombinations &&
1327 failed(profileComp.checkInvalid(op)))
1328 return signalPassFailure();
1329
1330 // Some uses of TOSA rely on the constant operands of particular
1331 // operations.
1332 if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
1333 signalPassFailure();
1334
1335 // do level checks
1336 if (failed(Result: applyLevelCheck(op)))
1337 signalPassFailure();
1338
1339 // check additional attribute restrictions
1340 if (failed(Result: applyAttributeCheck(op)))
1341 signalPassFailure();
1342
1343 // do variable type checks
1344 if (failed(Result: applyVariableCheck(op)))
1345 signalPassFailure();
1346
1347 // do error if checks
1348 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1349 signalPassFailure();
1350 });
1351}
1352} // namespace
1353

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp