1 | //===- TosaValidation.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Validate that TOSA dialect input matches the specification for the given |
10 | // requirements. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Tosa/IR/TargetEnv.h" |
15 | #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" |
16 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
17 | #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" |
18 | |
19 | #include <string> |
20 | |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
22 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
23 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
24 | #include "mlir/IR/Builders.h" |
25 | #include "mlir/IR/BuiltinOps.h" |
26 | #include "mlir/IR/Matchers.h" |
27 | #include "mlir/IR/TypeUtilities.h" |
28 | #include "mlir/Pass/Pass.h" |
29 | #include "mlir/Transforms/DialectConversion.h" |
30 | #include "llvm/ADT/StringExtras.h" |
31 | |
32 | namespace mlir { |
33 | namespace tosa { |
34 | #define GEN_PASS_DEF_TOSAVALIDATION |
35 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
36 | } // namespace tosa |
37 | } // namespace mlir |
38 | |
39 | using namespace mlir; |
40 | using namespace mlir::tosa; |
41 | |
42 | namespace { |
43 | |
44 | static LogicalResult |
45 | checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) { |
46 | for (const auto index : operandIndices) { |
47 | Attribute attr; |
48 | if (!matchPattern(op->getOperand(index), m_Constant(&attr))) { |
49 | return op->emitOpError("expected compile time resolvable constant, but " |
50 | "got variable value for operand #") |
51 | << index; |
52 | } |
53 | } |
54 | return success(); |
55 | } |
56 | |
57 | static LogicalResult checkConstantOperandMul(Operation *op, |
58 | const TargetEnv &env) { |
59 | if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) { |
60 | // Check 'shift' |
61 | return checkConstantOperands(op, {2}); |
62 | } |
63 | return success(); |
64 | } |
65 | |
66 | static LogicalResult checkConstantOperandTable(Operation *op, |
67 | const TargetEnv &env) { |
68 | if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) { |
69 | // Check 'table' |
70 | return checkConstantOperands(op, {1}); |
71 | } |
72 | return success(); |
73 | } |
74 | |
75 | static LogicalResult checkConstantOperandPad(Operation *op, |
76 | const TargetEnv &env) { |
77 | if (auto padOp = dyn_cast<tosa::PadOp>(op)) { |
78 | // Assume this op is zero-padding if padConst is not present |
79 | if (!env.allows(Extension::dynamic) && padOp.getPadConst()) |
80 | // Check 'pad_const' |
81 | // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type |
82 | return checkConstantOperands(op, {2}); |
83 | } |
84 | return success(); |
85 | } |
86 | |
87 | static LogicalResult checkConstantOperandRescale(Operation *op, |
88 | const TargetEnv &env) { |
89 | if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) { |
90 | // Check 'multiplier', 'shift', 'input_zp' and 'output_zp' |
91 | return checkConstantOperands(op, {1, 2, 3, 4}); |
92 | } |
93 | return success(); |
94 | } |
95 | |
96 | template <typename T> |
97 | static LogicalResult checkConstantOperandConvOps(Operation *op, |
98 | const TargetEnv &env) { |
99 | if (!env.allows(Extension::dynamic) && isa<T>(op)) { |
100 | // Check 'input_zp' and 'weight_zp' |
101 | return checkConstantOperands(op, {3, 4}); |
102 | } |
103 | return success(); |
104 | } |
105 | |
106 | static LogicalResult checkConstantOperandMatMul(Operation *op, |
107 | const TargetEnv &env) { |
108 | if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) { |
109 | // Check 'A_zp' and 'B_zp' |
110 | return checkConstantOperands(op, {2, 3}); |
111 | } |
112 | return success(); |
113 | } |
114 | |
115 | static LogicalResult checkConstantOperandAvgPool2d(Operation *op, |
116 | const TargetEnv &env) { |
117 | if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) { |
118 | // Check 'input_zp' and 'output_zp' |
119 | return checkConstantOperands(op, {1, 2}); |
120 | } |
121 | return success(); |
122 | } |
123 | |
124 | static LogicalResult checkConstantOperandNegate(Operation *op, |
125 | const TargetEnv &env) { |
126 | if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) { |
127 | // Check 'input1_zp' and 'output_zp' |
128 | return checkConstantOperands(op, {1, 2}); |
129 | } |
130 | return success(); |
131 | } |
132 | |
133 | struct TosaLevel { |
134 | int32_t MAX_RANK = 0; |
135 | int32_t MAX_KERNEL = 0; |
136 | int32_t MAX_STRIDE = 0; |
137 | int32_t MAX_SCALE = 0; |
138 | int32_t MAX_LOG2_SIZE = 0; |
139 | int32_t MAX_NESTING = 0; |
140 | int32_t MAX_TENSOR_LIST_SIZE = 0; |
141 | |
142 | bool operator==(const TosaLevel &rhs) { |
143 | return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && |
144 | MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && |
145 | MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && |
146 | MAX_NESTING == rhs.MAX_NESTING && |
147 | MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; |
148 | } |
149 | }; |
150 | |
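| // TOSA level limits (see section 1.7. Levels of the TOSA specification): the |
| // "8K" level bounds tensor rank, kernel/stride/scale extents, log2 tensor |
| // size, nesting depth and tensor list length; "none" effectively lifts them. |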
151 | static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; |
152 | static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, |
153 | 63, 256, 256}; |
154 | |
155 | //===----------------------------------------------------------------------===// |
156 | // TOSA Validation Pass. |
157 | //===----------------------------------------------------------------------===// |
158 | |
159 | struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { |
160 | public: |
161 | explicit TosaValidation() { populateConstantOperandChecks(); } |
162 | |
163 | explicit TosaValidation(const TosaValidationOptions &options) |
164 | : TosaValidation() { |
165 | this->profile = options.profile; |
166 | this->extension = options.extension; |
167 | this->strictOpSpecAlignment = options.strictOpSpecAlignment; |
168 | this->allowInvalidOpDatatypeCombinations = |
169 | options.allowInvalidOpDatatypeCombinations; |
170 | this->level = options.level; |
171 | } |
172 | void runOnOperation() final; |
173 | |
174 | LogicalResult applyConstantOperandCheck(Operation *op) { |
175 | for (auto &checker : constCheckers) { |
176 | if (failed(checker(op, targetEnv))) |
177 | return failure(); |
178 | } |
179 | return success(); |
180 | } |
181 | |
182 | LogicalResult applyLevelCheck(Operation *op); |
183 | LogicalResult applyAttributeCheck(Operation *op); |
184 | |
185 | // check variable read/write data types against variable declarations |
186 | LogicalResult applyVariableCheck(Operation *op); |
187 | |
188 | // check error if conditions |
189 | LogicalResult applyErrorIfCheck(Operation *op); |
190 | |
191 | private: |
192 | void populateConstantOperandChecks() { |
193 | constCheckers.emplace_back(checkConstantOperandMul); |
194 | constCheckers.emplace_back(checkConstantOperandTable); |
195 | constCheckers.emplace_back(checkConstantOperandPad); |
196 | constCheckers.emplace_back(checkConstantOperandRescale); |
197 | constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>); |
198 | constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>); |
199 | constCheckers.emplace_back( |
200 | checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>); |
201 | constCheckers.emplace_back( |
202 | checkConstantOperandConvOps<tosa::TransposeConv2DOp>); |
203 | constCheckers.emplace_back(checkConstantOperandMatMul); |
204 | constCheckers.emplace_back(checkConstantOperandAvgPool2d); |
205 | constCheckers.emplace_back(checkConstantOperandNegate); |
206 | } |
207 | |
208 | bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { |
209 | if (v > tosaLevel.MAX_KERNEL) { |
210 | op->emitOpError() << "failed level check: " << checkDesc; |
211 | return false; |
212 | } |
213 | return true; |
214 | } |
215 | |
216 | bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { |
217 | if (v > tosaLevel.MAX_STRIDE) { |
218 | op->emitOpError() << "failed level check: " << checkDesc; |
219 | return false; |
220 | } |
221 | return true; |
222 | } |
223 | |
224 | bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { |
225 | if (v > tosaLevel.MAX_SCALE) { |
226 | op->emitOpError() << "failed level check: " << checkDesc; |
227 | return false; |
228 | } |
229 | return true; |
230 | } |
231 | |
232 | bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { |
233 | if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) { |
234 | op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " |
235 | << checkDesc; |
236 | return false; |
237 | } |
238 | return true; |
239 | } |
240 | |
241 | // Perform the Level Rank check on the tensor type. |
242 | bool levelCheckRank(Operation *op, const Type typeToCheck, |
243 | const StringRef operandOrResult, int32_t highest_rank) { |
244 | if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) { |
245 | if (!type.hasRank()) { |
246 | op->emitOpError() << "failed level check: unranked tensor"; |
247 | return false; |
248 | } |
249 | if (type.getRank() > highest_rank) { |
250 | op->emitOpError() << "failed level check: " << operandOrResult |
251 | << " rank(shape) <= MAX_RANK"; |
252 | return false; |
253 | } |
254 | } |
255 | return true; |
256 | } |
257 | |
258 | // Perform the Level Rank check on the tensor value. |
259 | bool levelCheckRank(Operation *op, const Value &v, |
260 | const StringRef operandOrResult, int32_t highest_rank) { |
261 | return levelCheckRank(op, v.getType(), operandOrResult, highest_rank); |
262 | } |
263 | |
264 | // Perform the Level tensor size check on the tensor type. |
265 | bool levelCheckSize(Operation *op, const Type &typeToCheck, |
266 | const StringRef operandOrResult); |
267 | |
268 | // Perform the Level tensor size check on the tensor value. |
269 | bool levelCheckSize(Operation *op, const Value &v, |
270 | const StringRef operandOrResult) { |
271 | return levelCheckSize(op, v.getType(), operandOrResult); |
272 | } |
273 | |
274 | // Level check sizes of all operands and results of the operation. |
275 | template <typename T> |
276 | bool levelCheckSizes(T tosaOp) { |
277 | auto op = tosaOp.getOperation(); |
278 | for (auto v : op->getOperands()) { |
279 | if (!levelCheckSize(op, v, "operand")) |
280 | return false; |
281 | } |
282 | |
283 | for (auto v : op->getResults()) { |
284 | if (!levelCheckSize(op, v, "result")) |
285 | return false; |
286 | } |
287 | return true; |
288 | } |
289 | |
290 | // Level check ranks of all operands, attribute and results of the operation. |
291 | template <typename T> |
292 | bool levelCheckRanks(T tosaOp) { |
293 | auto op = tosaOp.getOperation(); |
294 | for (auto v : op->getOperands()) { |
295 | if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)) |
296 | return false; |
297 | } |
298 | |
299 | for (auto v : op->getResults()) { |
300 | if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)) |
301 | return false; |
302 | } |
303 | return true; |
304 | } |
305 | |
306 | // Level check ranks and sizes. |
307 | bool levelCheckRanksAndSizes(Operation *op); |
308 | |
309 | // Pool Op: level check kernel/stride/pad values |
310 | template <typename T> |
311 | bool levelCheckPool(Operation *op) { |
312 | if (auto poolOp = dyn_cast<T>(op)) { |
313 | for (auto k : poolOp.getKernel()) { |
314 | if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) { |
315 | return false; |
316 | } |
317 | } |
318 | for (auto s : poolOp.getStride()) { |
319 | if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { |
320 | return false; |
321 | } |
322 | } |
323 | for (auto p : poolOp.getPad()) { |
324 | if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { |
325 | return false; |
326 | } |
327 | } |
328 | } |
329 | return true; |
330 | } |
331 | |
332 | // Conv Op: level check dilation/stride/pad values |
333 | template <typename T> |
334 | bool levelCheckConv(Operation *op) { |
335 | if (auto convOp = dyn_cast<T>(op)) { |
336 | |
337 | for (auto k : convOp.getDilation()) { |
338 | if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) { |
339 | return false; |
340 | } |
341 | } |
342 | for (auto p : convOp.getPad()) { |
343 | if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { |
344 | return false; |
345 | } |
346 | } |
347 | for (auto s : convOp.getStride()) { |
348 | if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { |
349 | return false; |
350 | } |
351 | } |
352 | auto dilation = convOp.getDilation(); |
353 | if (ShapedType weightType = |
354 | dyn_cast<ShapedType>(op->getOperand(1).getType())) { |
355 | auto shape = weightType.getShape(); |
356 | if (isa<tosa::Conv2DOp>(op)) { |
357 | assert(shape.size() == 4); |
358 | assert(dilation.size() == 2); |
359 | if (!levelCheckKernel(op, dilation[0] * shape[1], |
360 | "dilation_y * KH <= MAX_KERNEL") || |
361 | !levelCheckKernel(op, dilation[1] * shape[2], |
362 | "dilation_x * KW <= MAX_KERNEL")) |
363 | return false; |
364 | } else if (isa<tosa::Conv3DOp>(op)) { |
365 | assert(shape.size() == 5); |
366 | assert(dilation.size() == 3); |
367 | if (!levelCheckKernel(op, dilation[0] * shape[1], |
368 | "dilation_d * KD <= MAX_KERNEL") || |
369 | !levelCheckKernel(op, dilation[1] * shape[2], |
370 | "dilation_y * KH <= MAX_KERNEL") || |
371 | !levelCheckKernel(op, dilation[2] * shape[3], |
372 | "dilation_x * KW <= MAX_KERNEL")) |
373 | return false; |
374 | } else if (isa<tosa::DepthwiseConv2DOp>(op)) { |
375 | assert(shape.size() == 4); |
376 | assert(dilation.size() == 2); |
377 | if (!levelCheckKernel(op, dilation[0] * shape[0], |
378 | "dilation_y * KH <= MAX_KERNEL") || |
379 | !levelCheckKernel(op, dilation[1] * shape[1], |
380 | "dilation_x * KW <= MAX_KERNEL")) |
381 | return false; |
382 | } |
383 | } |
384 | } |
385 | return true; |
386 | } |
387 | |
388 | // FFT op: level check H, W in input shape [N,H,W] |
389 | template <typename T> |
390 | bool levelCheckFFT(Operation *op) { |
391 | if (isa<T>(op)) { |
392 | for (auto v : op->getOperands()) { |
393 | if (ShapedType type = dyn_cast<ShapedType>(v.getType())) { |
394 | auto shape = type.getShape(); |
395 | assert(shape.size() == 3); |
396 | if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") || |
397 | !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) { |
398 | return false; |
399 | } |
400 | } |
401 | } |
402 | } |
403 | return true; |
404 | } |
405 | |
406 | // TransposeConv2d op: level check kH/kW, outpad, and stride |
407 | bool levelCheckTransposeConv2d(Operation *op) { |
408 | if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) { |
409 | if (ShapedType filterType = |
410 | dyn_cast<ShapedType>(transpose.getWeight().getType())) { |
411 | auto shape = filterType.getShape(); |
412 | assert(shape.size() == 4); |
413 | // level check kernel sizes for kH and KW |
414 | if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") || |
415 | !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) { |
416 | return false; |
417 | } |
418 | } |
419 | for (auto p : transpose.getOutPad()) { |
420 | if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { |
421 | return false; |
422 | } |
423 | } |
424 | for (auto s : transpose.getStride()) { |
425 | if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { |
426 | return false; |
427 | } |
428 | } |
429 | } |
430 | return true; |
431 | } |
432 | |
433 | // Resize op: level check max scales |
434 | bool levelCheckResize(Operation *op) { |
435 | if (auto resize = dyn_cast<tosa::ResizeOp>(op)) { |
436 | SmallVector<int64_t> scale; |
437 | if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), |
438 | scale)) { |
439 | return false; |
440 | } |
441 | const int64_t scaleYN = scale[0]; |
442 | const int64_t scaleYD = scale[1]; |
443 | const int64_t scaleXN = scale[2]; |
444 | const int64_t scaleXD = scale[3]; |
445 | if (!levelCheckScale(op, scaleYN / scaleYD, |
446 | "scale_y_n/scale_y_d <= MAX_SCALE") || |
447 | !levelCheckScale(op, scaleXN / scaleXD, |
448 | "scale_x_n/scale_x_d <= MAX_SCALE")) { |
449 | return false; |
450 | } |
451 | } |
452 | return true; |
453 | } |
454 | |
455 | // Recursively perform a bottom-up search to determine the maximum nesting |
456 | // depth, starting from a specific operation and continuing up to the function |
457 | // or module scope. Tosa nesting_depth starts at 0 and increments by one each |
458 | // time a new nested `region` is encountered. |
459 | static void getMaxNestedDepth(Operation *op, int32_t &depth) { |
460 | if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op)) |
461 | return; |
462 | |
463 | op = op->getParentOp(); |
464 | if (!op) |
465 | return; |
466 | |
467 | depth++; |
468 | getMaxNestedDepth(op, depth); |
469 | } |
470 | |
471 | bool levelCheckMaxNesting(Operation *op) { |
472 | int32_t maxNestedDepth = 0; |
473 | getMaxNestedDepth(op, maxNestedDepth); |
474 | |
475 | if (maxNestedDepth >= tosaLevel.MAX_NESTING) { |
476 | op->emitOpError() << "failed level check: " << maxNestedDepth |
477 | << " >= MAX_NESTING"; |
478 | return false; |
479 | } |
480 | return true; |
481 | } |
482 | |
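| // Level check the number of tensor list entries (e.g. concat inputs, custom, |
| // cond_if and while_loop input/output lists) against MAX_TENSOR_LIST_SIZE. |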
483 | bool levelCheckListSize(Operation *op) { |
484 | if (auto concat = dyn_cast<tosa::ConcatOp>(op)) { |
485 | return levelCheckListSize(op, concat.getInput1().size(), "input1"); |
486 | } |
487 | if (auto custom = dyn_cast<tosa::CustomOp>(op)) { |
488 | if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") || |
489 | !levelCheckListSize(op, custom.getOutputList().size(), |
490 | "output_list")) { |
491 | return false; |
492 | } |
493 | } |
494 | if (auto condIf = dyn_cast<tosa::IfOp>(op)) { |
495 | if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") || |
496 | !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) { |
497 | return false; |
498 | } |
499 | } |
500 | if (auto w = dyn_cast<tosa::WhileOp>(op)) { |
501 | if (!levelCheckListSize(op, w.getInputList().size(), "inputs") || |
502 | !levelCheckListSize(op, w.getOutputList().size(), "outputs")) { |
503 | return false; |
504 | } |
505 | } |
506 | return true; |
507 | } |
508 | |
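| // Check that the rescale rounding_mode attribute is permitted by the |
| // extensions enabled in the target environment. |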
509 | bool attributeCheckRescale(Operation *op) { |
510 | if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) { |
511 | if (rescale.getRoundingMode() == "DOUBLE_ROUND" && |
512 | !targetEnv.allows(Extension::doubleround)) { |
513 | op->emitOpError() |
514 | << "failed attribute check: rounding_mode = DOUBLE_ROUND " |
515 | << "requires extension [doubleround]"; |
516 | return false; |
517 | } else if (rescale.getRoundingMode() == "INEXACT_ROUND" && |
518 | !targetEnv.allows(Extension::inexactround)) { |
519 | op->emitOpError() |
520 | << "failed attribute check: rounding_mode = INEXACT_ROUND " |
521 | << "requires extension [inexactround]"; |
522 | return false; |
523 | } |
524 | } |
525 | return true; |
526 | } |
527 | |
528 | // configure profile and level values from pass options profileName and |
529 | // levelName |
530 | void configLevelAndProfile() { |
531 | tosaLevel = TOSA_LEVEL_NONE; |
532 | if (level == TosaLevelEnum::EightK) { |
533 | tosaLevel = TOSA_LEVEL_EIGHTK; |
534 | } |
535 | |
536 | if (!profile.empty()) { |
537 | for (std::string &prof : profile) { |
538 | auto profSymbol = symbolizeProfile(prof); |
539 | if (profSymbol) { |
540 | targetEnv.addProfile(profSymbol.value()); |
541 | } else { |
542 | llvm::errs() << "unknown TOSA profile name passed in: " << prof |
543 | << ", supported profiles are `pro_int` and `pro_fp`\n"; |
544 | return signalPassFailure(); |
545 | } |
546 | } |
547 | } |
548 | |
549 | if (!extension.empty()) { |
550 | for (std::string &ext : extension) { |
551 | auto extSymbol = symbolizeExtension(ext); |
552 | if (extSymbol) { |
553 | targetEnv.addExtension(extSymbol.value()); |
554 | } else { |
555 | llvm::errs() << "unknown TOSA extension name passed in: " << ext |
556 | << ", supported extensions are int16, int4, bf16, " |
557 | << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " |
558 | << "doubleround, inexactround and dynamic\n"; |
559 | return signalPassFailure(); |
560 | } |
561 | } |
562 | } |
563 | } |
564 | |
565 | bool CheckVariable(Operation *op); |
566 | bool CheckVariableReadOrWrite(Operation *op); |
567 | bool isValidElementType(Type type, const bool allowUnsigned = false); |
568 | |
569 | SmallVector< |
570 | std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>> |
571 | constCheckers; |
572 | TosaLevel tosaLevel; |
573 | DenseMap<StringAttr, mlir::Type> variablesMap; |
574 | TosaProfileCompliance profileComp; |
575 | tosa::TargetEnv targetEnv; |
576 | }; |
577 | |
578 | template <> |
579 | bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { |
580 | auto op = tosaOp.getOperation(); |
581 | if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)) |
582 | return false; |
583 | |
584 | // rank(output) = rank(input) - 1 |
585 | if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1)) |
586 | return false; |
587 | |
588 | return true; |
589 | } |
590 | |
591 | template <> |
592 | bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { |
593 | auto op = tosaOp.getOperation(); |
594 | |
595 | // Only the condition input has rank limitation. |
596 | if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK)) |
597 | return false; |
598 | |
599 | return true; |
600 | } |
601 | |
602 | template <> |
603 | bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { |
604 | auto op = tosaOp.getOperation(); |
605 | auto variableType = getVariableType(tosaOp); |
606 | if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK)) |
607 | return false; |
608 | |
609 | return true; |
610 | } |
611 | |
612 | template <> |
613 | bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) { |
614 | auto op = tosaOp.getOperation(); |
615 | auto variableType = getVariableType(tosaOp); |
616 | if (!levelCheckSize(op, variableType, "variable type")) |
617 | return false; |
618 | |
619 | return true; |
620 | } |
621 | |
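| // Dispatch the per-operator rank and size level checks: CHECK_RANKS_AND_SIZES |
| // runs both checks, CHECK_SIZES runs only the tensor size check. |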
622 | bool TosaValidation::levelCheckRanksAndSizes(Operation *op) { |
623 | #define CHECK_RANKS_AND_SIZES(tosaOp) \ |
624 | if (isa<tosa::tosaOp##Op>(op)) { \ |
625 | if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \ |
626 | return false; \ |
627 | if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \ |
628 | return false; \ |
629 | } |
630 | |
631 | #define CHECK_SIZES(tosaOp) \ |
632 | if (isa<tosa::tosaOp##Op>(op)) { \ |
633 | if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \ |
634 | return false; \ |
635 | } |
636 | |
637 | // Tensor Operators |
638 | CHECK_RANKS_AND_SIZES(ArgMax); |
639 | // Activation Functions |
640 | CHECK_RANKS_AND_SIZES(Clamp); |
641 | CHECK_RANKS_AND_SIZES(Erf); |
642 | CHECK_RANKS_AND_SIZES(Sigmoid); |
643 | CHECK_RANKS_AND_SIZES(Tanh); |
644 | // Elementwise Binary Operators |
645 | CHECK_RANKS_AND_SIZES(Add); |
646 | CHECK_RANKS_AND_SIZES(ArithmeticRightShift); |
647 | CHECK_RANKS_AND_SIZES(BitwiseAnd); |
648 | CHECK_RANKS_AND_SIZES(BitwiseOr); |
649 | CHECK_RANKS_AND_SIZES(BitwiseXor); |
650 | CHECK_RANKS_AND_SIZES(IntDiv); |
651 | CHECK_RANKS_AND_SIZES(LogicalAnd); |
652 | CHECK_RANKS_AND_SIZES(LogicalLeftShift); |
653 | CHECK_RANKS_AND_SIZES(LogicalRightShift); |
654 | CHECK_RANKS_AND_SIZES(LogicalOr); |
655 | CHECK_RANKS_AND_SIZES(LogicalXor); |
656 | CHECK_RANKS_AND_SIZES(Maximum); |
657 | CHECK_RANKS_AND_SIZES(Minimum); |
658 | CHECK_RANKS_AND_SIZES(Mul); |
659 | CHECK_RANKS_AND_SIZES(Pow); |
660 | CHECK_RANKS_AND_SIZES(Sub); |
661 | CHECK_RANKS_AND_SIZES(Table); |
662 | // Elementwise Unary Operators |
663 | CHECK_RANKS_AND_SIZES(Abs); |
664 | CHECK_RANKS_AND_SIZES(BitwiseNot); |
665 | CHECK_RANKS_AND_SIZES(Ceil); |
666 | CHECK_RANKS_AND_SIZES(Clz); |
667 | CHECK_RANKS_AND_SIZES(Cos); |
668 | CHECK_RANKS_AND_SIZES(Exp); |
669 | CHECK_RANKS_AND_SIZES(Floor); |
670 | CHECK_RANKS_AND_SIZES(Log); |
671 | CHECK_RANKS_AND_SIZES(LogicalNot); |
672 | CHECK_RANKS_AND_SIZES(Negate); |
673 | CHECK_RANKS_AND_SIZES(Reciprocal); |
674 | CHECK_RANKS_AND_SIZES(Rsqrt); |
675 | CHECK_RANKS_AND_SIZES(Sin); |
676 | // Elementwise Ternary Operators |
677 | CHECK_RANKS_AND_SIZES(Select); |
678 | // Comparison Operators |
679 | CHECK_RANKS_AND_SIZES(Equal); |
680 | CHECK_RANKS_AND_SIZES(Greater); |
681 | CHECK_RANKS_AND_SIZES(GreaterEqual); |
682 | // Reduction Operators |
683 | CHECK_RANKS_AND_SIZES(ReduceAll); |
684 | CHECK_RANKS_AND_SIZES(ReduceAny); |
685 | CHECK_RANKS_AND_SIZES(ReduceMax); |
686 | CHECK_RANKS_AND_SIZES(ReduceMin); |
687 | CHECK_RANKS_AND_SIZES(ReduceProduct); |
688 | CHECK_RANKS_AND_SIZES(ReduceSum); |
689 | // Data Layout Operators |
690 | CHECK_RANKS_AND_SIZES(Concat); |
691 | CHECK_RANKS_AND_SIZES(Pad); |
692 | CHECK_RANKS_AND_SIZES(Reshape); |
693 | CHECK_RANKS_AND_SIZES(Reverse); |
694 | CHECK_RANKS_AND_SIZES(Slice); |
695 | CHECK_RANKS_AND_SIZES(Tile); |
696 | CHECK_RANKS_AND_SIZES(Transpose); |
697 | // Type Conversion |
698 | CHECK_RANKS_AND_SIZES(Cast); |
699 | CHECK_RANKS_AND_SIZES(Rescale); |
700 | // Control Flow Operators |
701 | CHECK_RANKS_AND_SIZES(If); |
702 | // Variable Operators |
703 | CHECK_RANKS_AND_SIZES(Variable); |
704 | CHECK_RANKS_AND_SIZES(VariableWrite); |
705 | CHECK_RANKS_AND_SIZES(VariableRead); |
706 | // Data Nodes |
707 | CHECK_RANKS_AND_SIZES(Const); |
708 | CHECK_RANKS_AND_SIZES(Identity); |
709 | |
710 | // For the following operators, check whether the size of each tensor |
711 | // operand is valid in a given Level. |
712 | |
713 | // Tensor Operators |
714 | CHECK_SIZES(AvgPool2d); |
715 | CHECK_SIZES(Conv2D); |
716 | CHECK_SIZES(Conv3D); |
717 | CHECK_SIZES(DepthwiseConv2D); |
718 | CHECK_SIZES(TransposeConv2D); |
719 | CHECK_SIZES(FFT2d); |
720 | CHECK_SIZES(MatMul); |
721 | CHECK_SIZES(MaxPool2d); |
722 | CHECK_SIZES(RFFT2d); |
723 | // Scatter/Gather Operators |
724 | CHECK_SIZES(Gather); |
725 | CHECK_SIZES(Scatter); |
726 | // Image Operators |
727 | CHECK_SIZES(Resize); |
728 | // Custom Operators |
729 | CHECK_SIZES(Custom); |
730 | // Control Flow Operators |
731 | CHECK_SIZES(While); |
732 | // Shape Operators |
733 | CHECK_SIZES(ConstShape); |
734 | |
735 | #undef CHECK_RANKS_AND_SIZES |
736 | #undef CHECK_SIZES |
737 | return true; |
738 | } |
739 | |
740 | // Perform the Level tensor size check on the tensor type. |
741 | bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, |
742 | const StringRef operandOrResult) { |
743 | if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) { |
744 | if (!type.hasRank()) { |
745 | op->emitOpError() << "failed level check: unranked tensor"; |
746 | return false; |
747 | } |
748 | auto shape = type.getShape(); |
749 | for (auto dim : shape) { |
750 | if (mlir::ShapedType::isDynamic(dim)) { |
751 | op->emitOpError() << "failed level check: " << operandOrResult |
752 | << " shape dimension cannot be dynamic"; |
753 | return false; |
754 | } |
755 | } |
756 | |
757 | int64_t element_bits = type.getElementTypeBitWidth(); |
758 | int64_t element_bytes = std::max(INT64_C(1), element_bits / 8); |
759 | int64_t size = element_bytes * type.getNumElements(); |
760 | |
761 | // According to 1.11. Tensor Definitions of the TOSA spec, the maximum |
762 | // value of tensor_size_t is (1 << MAX_LOG2_SIZE) - 1, where MAX_LOG2_SIZE |
763 | // is defined in 1.7. Levels. |
764 | // For each tensor, the number of tensor elements multiplied by the |
765 | // element size in bytes must be representable as a tensor_size_t. |
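| // For example, at the 8K level MAX_LOG2_SIZE is 31, so each tensor is |
| // limited to (1 << 31) - 1 bytes. |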
766 | const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; |
767 | if (size > max_size) { |
768 | op->emitOpError() |
769 | << "failed level check: "<< operandOrResult |
770 | << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)"; |
771 | return false; |
772 | } |
773 | } |
774 | return true; |
775 | } |
776 | |
777 | LogicalResult TosaValidation::applyLevelCheck(Operation *op) { |
778 | if (tosaLevel == TOSA_LEVEL_NONE) { |
779 | // no need to do level checks |
780 | return success(); |
781 | } |
782 | |
783 | // check rank and sizes early so later checks can assume shaped operands |
784 | if (!levelCheckRanksAndSizes(op)) |
785 | return failure(); |
786 | |
787 | // additional level checks from spec 0.70 |
788 | if (!levelCheckPool<tosa::AvgPool2dOp>(op) || |
789 | !levelCheckConv<tosa::Conv2DOp>(op) || |
790 | !levelCheckConv<tosa::Conv3DOp>(op) || |
791 | !levelCheckConv<tosa::DepthwiseConv2DOp>(op) || |
792 | !levelCheckFFT<tosa::FFT2dOp>(op) || |
793 | !levelCheckPool<tosa::MaxPool2dOp>(op) || |
794 | !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) || |
795 | !levelCheckResize(op)) { |
796 | return failure(); |
797 | } |
798 | |
799 | // level check MAX_TENSOR_LIST_SIZE |
800 | if (!levelCheckListSize(op)) { |
801 | return failure(); |
802 | } |
803 | |
804 | if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) { |
805 | if (!levelCheckMaxNesting(op)) { |
806 | return failure(); |
807 | } |
808 | } |
809 | |
810 | return success(); |
811 | } |
812 | |
813 | LogicalResult TosaValidation::applyAttributeCheck(Operation *op) { |
814 | if (!attributeCheckRescale(op)) |
815 | return failure(); |
816 | return success(); |
817 | } |
818 | |
819 | inline bool CompatibleTypes(const mlir::Type &type, |
820 | const mlir::Type &declaredType) { |
821 | // for now, simply use type equality comparison |
822 | return type == declaredType; |
823 | } |
824 | |
825 | bool TosaValidation::CheckVariable(Operation *op) { |
826 | if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) { |
827 | mlir::StringAttr nameAttr = variableOp.getNameAttr(); |
828 | |
829 | if (variablesMap.count(nameAttr)) { |
830 | op->emitOpError() << "name has already been declared"; |
831 | return false; |
832 | } |
833 | |
834 | auto elementType = variableOp.getType(); |
835 | DenseIntElementsAttr varShapeAttr = variableOp.getVarShape(); |
836 | SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>()); |
837 | RankedTensorType variableType = |
838 | RankedTensorType::get(ArrayRef<int64_t>(shape), elementType); |
839 | |
840 | variablesMap[nameAttr] = variableType; |
841 | } |
842 | |
843 | return true; |
844 | } |
845 | |
846 | bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { |
847 | if (isa<mlir::tosa::VariableReadOp>(op) || |
848 | isa<mlir::tosa::VariableWriteOp>(op)) { |
849 | mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name")); |
850 | if (!variablesMap.count(nameAttr)) { |
851 | op->emitOpError() << "name has not been declared"; |
852 | return false; |
853 | } |
854 | |
855 | auto varType = variablesMap[nameAttr]; |
856 | |
857 | for (auto v : op->getOperands()) { |
858 | auto type = v.getType(); |
859 | if (!CompatibleTypes(type, varType)) { |
860 | op->emitOpError() << "operand type does not equal variable type"; |
861 | return false; |
862 | } |
863 | } |
864 | |
865 | for (auto v : op->getResults()) { |
866 | auto type = v.getType(); |
867 | if (!CompatibleTypes(type, varType)) { |
868 | op->emitOpError() << "result type does not equal variable type"; |
869 | return false; |
870 | } |
871 | } |
872 | } |
873 | |
874 | return true; |
875 | } |
876 | |
877 | LogicalResult TosaValidation::applyVariableCheck(Operation *op) { |
878 | if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { |
879 | return failure(); |
880 | } |
881 | return success(); |
882 | } |
883 | |
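| // ERROR_IF checks for tosa.resize: bound the image dimensions and the |
| // scale/offset/border ranges, and check that the output size is consistent |
| // with the scale equations. |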
884 | bool checkErrorIfResize(Operation *op) { |
885 | auto resize = dyn_cast<tosa::ResizeOp>(op); |
886 | if (!resize) |
887 | return true; |
888 | |
889 | const Value input = resize.getInput(); |
890 | const Value output = resize.getOutput(); |
891 | const RankedTensorType inputType = |
892 | llvm::dyn_cast<RankedTensorType>(input.getType()); |
893 | const RankedTensorType outputType = |
894 | llvm::dyn_cast<RankedTensorType>(output.getType()); |
895 | |
896 | if (!inputType || !outputType) { |
897 | op->emitOpError("expect ranked input/output tensor"); |
898 | return false; |
899 | } |
900 | |
901 | // Ensure the image size is supported by GPU APIs and that for integer |
902 | // implementations, position * stride does not overflow int32_t. |
903 | if (inputType.hasStaticShape() && outputType.hasStaticShape()) { |
904 | const SmallVector<int64_t, 4> sizes = { |
905 | outputType.getDimSize(1), outputType.getDimSize(2), |
906 | inputType.getDimSize(1), inputType.getDimSize(2)}; |
907 | const int64_t *maxDim = llvm::max_element(sizes); |
908 | if (maxDim != sizes.end() && *maxDim >= 16384) { |
909 | op->emitOpError("expect input/output height/width dims to be < 16384, ") |
910 | << "got [OH, OW, IH, IW] = " << sizes; |
911 | return false; |
912 | } |
913 | } |
914 | |
915 | SmallVector<int64_t> scale; |
916 | if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) { |
917 | return false; |
918 | } |
919 | |
920 | const int64_t scaleYN = scale[0]; |
921 | const int64_t scaleYD = scale[1]; |
922 | const int64_t scaleXN = scale[2]; |
923 | const int64_t scaleXD = scale[3]; |
924 | |
925 | // Ensure scale values don't overflow int32 accumulator |
926 | if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) { |
927 | op->emitOpError("expect all scale numerator values to be <= (1 << 11), " |
928 | "got scale_y_n=") |
929 | << scaleYN << ", scale_x_n=" << scaleXN; |
930 | return false; |
931 | } |
932 | |
933 | if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) { |
934 | op->emitOpError("expect a downscale ratio larger than 1/16, got y=") |
935 | << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD; |
936 | return false; |
937 | } |
938 | |
939 | SmallVector<int64_t> offset; |
940 | SmallVector<int64_t> border; |
941 | if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) || |
942 | !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) { |
943 | return false; |
944 | } |
945 | |
946 | const int64_t offsetY = offset[0]; |
947 | const int64_t offsetX = offset[1]; |
948 | // Set a consistent lower limit of 1/16 downscale to simplify |
949 | // implementations |
950 | if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) { |
951 | op->emitOpError( |
952 | message: "expect offsetY / scaleYNumerator to be in range [-1, 16), got ") |
953 | << offsetY << "/"<< scaleYN; |
954 | return false; |
955 | } |
956 | if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) { |
957 | op->emitOpError( |
958 | message: "expect offsetX / scaleXNumerator to be in range [-1, 16), got ") |
959 | << offsetX << "/"<< scaleXN; |
960 | return false; |
961 | } |
962 | |
963 | const int64_t borderY = border[0]; |
964 | const int64_t borderX = border[1]; |
965 | if (borderY < -16 * scaleYN || borderY >= scaleYN) { |
966 | op->emitOpError( |
967 | message: "expect borderY / scaleYNumerator to be in range [-16, 1), got ") |
968 | << borderY << "/"<< scaleYN; |
969 | return false; |
970 | } |
971 | if (borderX < -16 * scaleXN || borderX >= scaleXN) { |
972 | op->emitOpError( |
973 | message: "expect borderX / scaleXNumerator to be in range [-16, 1), got ") |
974 | << borderX << "/"<< scaleXN; |
975 | return false; |
976 | } |
977 | |
978 | // The following section of code is largely duplicated from ResizeOp::verify(). |
979 | // |
980 | // The TOSA specification does not support broadcast behavior. However, |
981 | // there is a rewrite pattern that materializes broadcast ResizeOps, |
982 | // turning invalid TOSA ResizeOps into valid ones. To avoid breaking |
983 | // existing code, we keep the rewrite pattern untouched, so the checking |
984 | // in ResizeOp::verify() must stay loose enough to accept broadcast ResizeOps. |
985 | // |
986 | // Here we perform strict checking to conform to the TOSA specification. |
987 | // FIXME: Remove the duplicated checks once broadcast ResizeOp is removed. |
988 | auto idivCheck = [](const int64_t lhs, |
989 | const int64_t rhs) -> std::optional<int64_t> { |
990 | if (lhs % rhs != 0) |
991 | return std::nullopt; |
992 | return lhs / rhs; |
993 | }; |
994 | |
995 | const int64_t oh = outputType.getDimSize(1); |
996 | const int64_t ow = outputType.getDimSize(2); |
997 | const int64_t ih = inputType.getDimSize(1); |
998 | const int64_t iw = inputType.getDimSize(2); |
999 | |
1000 | if (ih != ShapedType::kDynamic) { |
1001 | const std::optional<int64_t> calculatedOutHeightMinusOne = |
1002 | idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD); |
1003 | if (!calculatedOutHeightMinusOne.has_value()) { |
1004 | op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + " |
1005 | "border_y ") |
1006 | << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * " |
1007 | << scaleYN << " - " << offsetY << " + " << borderY << ") / " |
1008 | << scaleYD; |
1009 | return false; |
1010 | } |
1011 | const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1; |
1012 | if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) { |
1013 | op->emitOpError("calculated output height did not match expected: ") |
1014 | << "calculated=" << calculatedOutHeight << ", expected=" << oh; |
1015 | return false; |
1016 | } |
1017 | } |
1018 | |
1019 | if (iw != ShapedType::kDynamic) { |
1020 | const std::optional<int64_t> calculatedOutWidthMinusOne = |
1021 | idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD); |
1022 | if (!calculatedOutWidthMinusOne.has_value()) { |
1023 | op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + " |
1024 | "border_x ") |
1025 | << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * " |
1026 | << scaleXN << " - " << offsetX << " + " << borderX << ") / " |
1027 | << scaleXD; |
1028 | return false; |
1029 | } |
1030 | const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1; |
1031 | if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) { |
1032 | op->emitOpError("calculated output width did not match expected: ") |
1033 | << "calculated=" << calculatedOutWidth << ", expected=" << ow; |
1034 | return false; |
1035 | } |
1036 | } |
1037 | |
1038 | return true; |
1039 | } |
1040 | |
1041 | bool checkErrorIfMul(Operation *op) { |
1042 | auto mul = dyn_cast<tosa::MulOp>(op); |
1043 | if (!mul) |
1044 | return true; |
1045 | |
1046 | // REQUIRE(0 <= shift && shift <= 63); |
1047 | // REQUIRE(is_same<in_t,int32_t>() || shift == 0); |
1048 | ElementsAttr shift_elem; |
1049 | if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) { |
1050 | return true; |
1051 | } |
1052 | int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt(); |
1053 | auto inputElemType = getElementTypeOrSelf(mul.getInput1()); |
1054 | if (inputElemType.isInteger(32)) { |
1055 | // 0 <= shift <= 63 for int32_t type |
1056 | if (shift < 0 || shift > 63) { |
1057 | op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: " |
1058 | << shift; |
1059 | return false; |
1060 | } |
1061 | } else { |
1062 | // shift must be 0 for all other types |
1063 | if (shift != 0) { |
1064 | op->emitOpError() << "requires shift = 0 for all input data types that " |
1065 | "are not int32_t, but got: " |
1066 | << shift; |
1067 | return false; |
1068 | } |
1069 | } |
1070 | |
1071 | return true; |
1072 | } |
1073 | |
1074 | bool checkErrorIfTable(Operation *op) { |
1075 | auto table = dyn_cast<tosa::TableOp>(op); |
1076 | if (!table) |
1077 | return true; |
1078 | |
1079 | // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513 |
1080 | const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType()); |
1081 | const int tableSize = inputElemType.isInteger(8) ? 256 : 513; |
1082 | |
1083 | const ShapeAdaptor tableShape(table.getTable().getType()); |
1084 | if (tableShape.hasStaticShape()) { |
1085 | const auto numElements = tableShape.getNumElements(); |
1086 | if (numElements != tableSize) { |
1087 | op->emitOpError() << "requires table size of " << tableSize << ", got " |
1088 | << numElements; |
1089 | return false; |
1090 | } |
1091 | } |
1092 | |
1093 | return true; |
1094 | } |
1095 | |
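| // ERROR_IF checks for tosa.rescale: reject unsupported combinations of |
| // integer width, signedness, scale32 and rounding mode. |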
1096 | bool checkErrorIfRescale(Operation *op) { |
1097 | auto rescale = dyn_cast<tosa::RescaleOp>(op); |
1098 | if (!rescale) |
1099 | return true; |
1100 | |
1101 | auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType()); |
1102 | auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType()); |
1103 | if (!inputType || !outputType || !inputType.getElementType().isInteger() || |
1104 | !outputType.getElementType().isInteger()) |
1105 | return true; |
1106 | |
1107 | auto inElemType = inputType.getElementType(); |
1108 | auto outElemType = outputType.getElementType(); |
1109 | auto inWidth = inElemType.getIntOrFloatBitWidth(); |
1110 | auto outWidth = outElemType.getIntOrFloatBitWidth(); |
1111 | |
1112 | bool inputUnsigned = rescale.getInputUnsigned(); |
1113 | bool outputUnsigned = rescale.getOutputUnsigned(); |
1114 | |
1115 | bool scale32 = rescale.getScale32(); |
1116 | auto roundingMode = rescale.getRoundingMode(); |
1117 | |
1118 | // ERROR_IF(scale32 && is_same<in_t,i48_t>()) |
1119 | if (scale32 && inWidth == 48) { |
1120 | op->emitOpError() << "scale32 is not allowed with 48-bit input."; |
1121 | return false; |
1122 | } |
1123 | |
1124 | // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND)) |
1125 | if (!scale32 && roundingMode == "DOUBLE_ROUND") { |
1126 | op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true."; |
1127 | return false; |
1128 | } |
1129 | |
1130 | // ERROR_IF(input_unsigned && output_unsigned) |
1131 | if (inputUnsigned && outputUnsigned) { |
1132 | op->emitOpError() << "input and output cannot be both unsigned."; |
1133 | return false; |
1134 | } |
1135 | |
1136 | // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned) |
1137 | if (outWidth == 32 && inputUnsigned) { |
1138 | op->emitOpError() << "i32 output type is not allowed with unsigned input."; |
1139 | return false; |
1140 | } |
1141 | |
1142 | // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned) |
1143 | if (inWidth == 32 && outputUnsigned) { |
1144 | op->emitOpError() << "i32 input type is not allowed with unsigned output."; |
1145 | return false; |
1146 | } |
1147 | |
1148 | // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned) |
1149 | if (inWidth == 48 && outputUnsigned) { |
1150 | op->emitOpError() << "i48 input type is not allowed with unsigned output."; |
1151 | return false; |
1152 | } |
1153 | |
1154 | // ERROR_IF(is_same<in_t, i48_t> && input_unsigned) |
1155 | if (inWidth == 48 && inputUnsigned) { |
1156 | op->emitOpError() << "i48 input type cannot be unsigned."; |
1157 | return false; |
1158 | } |
1159 | |
1160 | // ERROR_IF(is_same<in_t, i32_t> && input_unsigned) |
1161 | if (inWidth == 32 && inputUnsigned) { |
1162 | op->emitOpError() << "i32 input type cannot be unsigned."; |
1163 | return false; |
1164 | } |
1165 | |
1166 | // ERROR_IF(is_same<out_t, i32_t> && output_unsigned) |
1167 | if (outWidth == 32 && outputUnsigned) { |
1168 | op->emitOpError() << "i32 output type cannot be unsigned."; |
1169 | return false; |
1170 | } |
1171 | |
1172 | return true; |
1173 | } |
1174 | |
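| // ERROR_IF check for tosa.pad: all padding values must be non-negative. |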
1175 | bool checkErrorIfPad(Operation *op) { |
1176 | auto pad = dyn_cast<tosa::PadOp>(op); |
1177 | if (!pad) |
1178 | return true; |
1179 | |
1180 | DenseIntElementsAttr paddingAttr; |
1181 | if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr))) |
1182 | // Pad verifier will catch this |
1183 | return true; |
1184 | |
1185 | for (const APInt &val : paddingAttr.getValues<APInt>()) { |
1186 | if (val.getSExtValue() < 0) { |
1187 | op->emitOpError() << "padding values must all be non-negative, got " |
1188 | << val.getSExtValue(); |
1189 | return false; |
1190 | } |
1191 | } |
1192 | |
1193 | return true; |
1194 | } |
1195 | |
1196 | // Returns true if the operation takes no input operands, excluding attributes. |
1197 | static bool isNullaryOperation(Operation *op) { |
1198 | if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) || |
1199 | isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op)) |
1200 | return true; |
1201 | return false; |
1202 | } |
1203 | |
1204 | bool checkErrorIfCondIf(Operation *op) { |
1205 | auto ifOp = dyn_cast<tosa::IfOp>(op); |
1206 | if (!ifOp) |
1207 | return true; |
1208 | |
1209 | // The types and shapes of operands between the input/output list and the |
1210 | // internal regions are validated by the operation verifier. However, the |
1211 | // simplified form - where redundant operand notations are omitted - is not |
1212 | // conformant to the specification. According to the specification, all |
1213 | // operands passed into an operation must be explicitly declared in each |
1214 | // operation's structure. This section of code verifies that the |
1215 | // operation's form complies with this requirement. |
1216 | |
1217 | // Returns true if the region uses no external input operands. |
1218 | auto isNullaryRegion = [](Region ®ion) -> bool { |
1219 | bool noLiveInValue = true; |
1220 | region.walk([&noLiveInValue](Operation *op) { |
1221 | if (!isNullaryOperation(op)) { |
1222 | noLiveInValue = false; |
1223 | return WalkResult::interrupt(); |
1224 | } |
1225 | return WalkResult::advance(); |
1226 | }); |
1227 | return noLiveInValue; |
1228 | }; |
1229 | |
1230 | mlir::Region &thenGraph = ifOp.getThenGraph(); |
1231 | mlir::Region &elseGraph = ifOp.getElseGraph(); |
1232 | bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph); |
1233 | bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph); |
1234 | bool isInputListEmpty = ifOp.getInputList().size() == 0; |
1235 | |
1236 | if ((isInputListEmpty != isThenGraphNullaryRegion) || |
1237 | (isInputListEmpty != isElseGraphNullaryRegion)) { |
1238 | op->emitOpError() |
1239 | << "the current simplified form is not strictly conformant to the " |
1240 | "spec, please use the generic format\n"; |
1241 | return false; |
1242 | } |
1243 | |
1244 | return true; |
1245 | } |
1246 | |
1247 | LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { |
1248 | if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || |
1249 | !checkErrorIfTable(op) || !checkErrorIfRescale(op) || |
1250 | !checkErrorIfPad(op) || !checkErrorIfCondIf(op)) |
1251 | return failure(); |
1252 | return success(); |
1253 | } |
1254 | |
1255 | bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { |
1256 | if (isa<FloatType>(type)) { |
1257 | return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, |
1258 | Float8E5M2Type>(type); |
1259 | } else if (auto intTy = dyn_cast<IntegerType>(type)) { |
1260 | if (intTy.isSignless()) { |
1261 | switch (intTy.getWidth()) { |
1262 | case 1: |
1263 | case 4: |
1264 | case 8: |
1265 | case 16: |
1266 | case 32: |
1267 | case 48: |
1268 | return true; |
1269 | } |
1270 | } else if (allowUnsigned && intTy.isUnsigned()) { |
1271 | switch (intTy.getWidth()) { |
1272 | case 8: |
1273 | case 16: |
1274 | case 32: |
1275 | return true; |
1276 | } |
1277 | } |
1278 | } else if (mlir::isa<tosa::shapeType>(type)) { |
1279 | return true; |
1280 | } |
1281 | return false; |
1282 | } |
1283 | |
1284 | void TosaValidation::runOnOperation() { |
1285 | configLevelAndProfile(); |
1286 | |
1287 | TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); |
1288 | if (!tosaDialect) |
1289 | return; |
1290 | |
1291 | getOperation().walk([&](Operation *op) { |
1292 | if (op->getDialect() != tosaDialect) |
1293 | return; |
1294 | |
1295 | // validate operator element types: |
1296 | // - rescale operator is allowed to have ui8/ui16/ui32 |
1297 | // operands/results |
1298 | // - perform valid element type check at the beginning to |
1299 | // protect rest of code against quantized element types |
1300 | const bool opIsRescale = isa<tosa::RescaleOp>(op); |
1301 | for (Value operand : op->getOperands()) { |
1302 | auto elementTy = getElementTypeOrSelf(operand); |
1303 | if (!isValidElementType(elementTy, opIsRescale)) { |
1304 | op->emitOpError() << "is not profile-aligned: element type " |
1305 | << elementTy << " is not legal"; |
1306 | return signalPassFailure(); |
1307 | } |
1308 | } |
1309 | for (Type resultTy : op->getResultTypes()) { |
1310 | auto elementTy = getElementTypeOrSelf(resultTy); |
1311 | if (!isValidElementType(elementTy, opIsRescale)) { |
1312 | op->emitOpError() << "is not profile-aligned: element type " |
1313 | << elementTy << " is not legal"; |
1314 | return signalPassFailure(); |
1315 | } |
1316 | } |
1317 | |
1318 | if (strictOpSpecAlignment && |
1319 | failed(profileComp.checkProfile(op, targetEnv))) |
1320 | return signalPassFailure(); |
1321 | |
1322 | if (strictOpSpecAlignment && |
1323 | failed(profileComp.checkExtension(op, targetEnv))) |
1324 | return signalPassFailure(); |
1325 | |
1326 | if (!allowInvalidOpDatatypeCombinations && |
1327 | failed(profileComp.checkInvalid(op))) |
1328 | return signalPassFailure(); |
1329 | |
1330 | // Some uses of TOSA rely on the constant operands of particular |
1331 | // operations. |
1332 | if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op))) |
1333 | signalPassFailure(); |
1334 | |
1335 | // do level checks |
1336 | if (failed(applyLevelCheck(op))) |
1337 | signalPassFailure(); |
1338 | |
1339 | // check additional attribute restrictions |
1340 | if (failed(applyAttributeCheck(op))) |
1341 | signalPassFailure(); |
1342 | |
1343 | // do variable type checks |
1344 | if (failed(applyVariableCheck(op))) |
1345 | signalPassFailure(); |
1346 | |
1347 | // do error if checks |
1348 | if (strictOpSpecAlignment && failed(applyErrorIfCheck(op))) |
1349 | signalPassFailure(); |
1350 | }); |
1351 | } |
1352 | } // namespace |
1353 |