//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate whether TOSA dialect input matches the specification for the given
// requirements.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
16
17#include <string>
18
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/Tosa/IR/TosaOps.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/IR/TypeUtilities.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27
28namespace mlir {
29namespace tosa {
30#define GEN_PASS_DEF_TOSAVALIDATION
31#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
32} // namespace tosa
33} // namespace mlir
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38namespace {
39
40static LogicalResult checkConstantOperandPad(Operation *op) {
41 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
42 DenseElementsAttr paddings;
43 if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
44 return op->emitOpError(message: "padding of pad is not constant");
45
46 DenseElementsAttr padConst;
47 // Assume this op is zero-padding if padConst is not presented.
48 if (padOp.getPadConst() &&
49 !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
50 return op->emitOpError(message: "pad_const of pad is not constant");
51 }
52 return success();
53}
54
55static LogicalResult checkConstantOperandTranspose(Operation *op) {
56 if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
57 DenseElementsAttr perms;
58 if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
59 return op->emitOpError(message: "perms of transpose is not constant");
60 }
61 return success();
62}
63
64static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
65 if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
66 DenseElementsAttr weight;
67 if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
68 return op->emitOpError(message: "weight of fully_connected is not constant");
69
70 DenseElementsAttr bias;
71 if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
72 return op->emitOpError(message: "bias of fully_connected is not constant");
73 }
74 return success();
75}
76
/// Limits imposed by a TOSA level. An all-zero value (TOSA_LEVEL_NONE) means
/// "no level selected".
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // @todo: MAX_LOG2_SIZE value and checks

  // const-qualified (and constexpr) so that either side of the comparison may
  // be a const object, e.g. one of the constexpr level tables below.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};

// Field order: MAX_RANK, MAX_KERNEL, MAX_STRIDE, MAX_SCALE.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
93
94//===----------------------------------------------------------------------===//
95// TOSA Validation Pass.
96//===----------------------------------------------------------------------===//
97
98struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
99public:
100 explicit TosaValidation() { populateConstantOperandChecks(); }
101 explicit TosaValidation(const TosaValidationOptions &options)
102 : TosaValidation() {
103 this->profile = options.profile;
104 this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
105 this->level = options.level;
106 }
107 void runOnOperation() final;
108
109 LogicalResult applyConstantOperandCheck(Operation *op) {
110 for (auto &checker : constCheckers) {
111 if (failed(checker(op)))
112 return failure();
113 }
114 return success();
115 }
116
117 LogicalResult applyLevelCheck(Operation *op);
118
119 // check variable read/write data types against variable declarations
120 LogicalResult applyVariableCheck(Operation *op);
121
122private:
123 void populateConstantOperandChecks() {
124 constCheckers.emplace_back(checkConstantOperandPad);
125 constCheckers.emplace_back(checkConstantOperandTranspose);
126 constCheckers.emplace_back(checkConstantOperandFullyConnected);
127 }
128
129 bool levelCheckKernel(Operation *op, int32_t v,
130 const std::string &checkDesc) {
131 if (v > tosaLevel.MAX_KERNEL) {
132 op->emitOpError() << "failed level check: " << checkDesc;
133 return false;
134 }
135 return true;
136 }
137
138 bool levelCheckStride(Operation *op, int32_t v,
139 const std::string &checkDesc) {
140 if (v > tosaLevel.MAX_STRIDE) {
141 op->emitOpError() << "failed level check: " << checkDesc;
142 return false;
143 }
144 return true;
145 }
146
147 bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
148 if (v > tosaLevel.MAX_SCALE) {
149 op->emitOpError() << "failed level check: " << checkDesc;
150 return false;
151 }
152 return true;
153 }
154
155 bool levelCheckRank(Operation *op, const Value &v,
156 const std::string &checkDesc) {
157 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
158 if (!type.hasRank()) {
159 op->emitOpError() << "failed level check: unranked tensor";
160 return false;
161 }
162 if (type.getRank() > tosaLevel.MAX_RANK) {
163 op->emitOpError() << "failed level check: " << checkDesc;
164 return false;
165 }
166 }
167 return true;
168 }
169
170 template <typename T>
171 bool levelCheckRanksFor(Operation *op) {
172 if (dyn_cast<T>(op)) {
173 // level check ranks of all operands and results
174 for (auto v : op->getOperands()) {
175 if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
176 return false;
177 }
178 for (auto v : op->getResults()) {
179 if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
180 return false;
181 }
182 }
183 return true;
184 }
185
186 bool levelCheckRanks(Operation *op) {
187#define CHECK_RANKS_FOR(tosaOp) \
188 if (!levelCheckRanksFor<tosaOp##Op>(op)) \
189 return false;
190
191 // tensor operators:
192 CHECK_RANKS_FOR(ArgMax);
193 // all activation functions:
194 CHECK_RANKS_FOR(Clamp);
195 CHECK_RANKS_FOR(Sigmoid);
196 CHECK_RANKS_FOR(Tanh);
197 // all elementwise binary operators:
198 CHECK_RANKS_FOR(Add);
199 CHECK_RANKS_FOR(ArithmeticRightShift);
200 CHECK_RANKS_FOR(BitwiseAnd);
201 CHECK_RANKS_FOR(BitwiseOr);
202 CHECK_RANKS_FOR(BitwiseXor);
203 CHECK_RANKS_FOR(Div);
204 CHECK_RANKS_FOR(LogicalAnd);
205 CHECK_RANKS_FOR(LogicalLeftShift);
206 CHECK_RANKS_FOR(LogicalRightShift);
207 CHECK_RANKS_FOR(LogicalOr);
208 CHECK_RANKS_FOR(LogicalXor);
209 CHECK_RANKS_FOR(Maximum);
210 CHECK_RANKS_FOR(Minimum);
211 CHECK_RANKS_FOR(Mul);
212 CHECK_RANKS_FOR(Pow);
213 CHECK_RANKS_FOR(Sub);
214 CHECK_RANKS_FOR(Table);
215 // all elementwise unary operators:
216 CHECK_RANKS_FOR(Abs);
217 CHECK_RANKS_FOR(BitwiseNot);
218 CHECK_RANKS_FOR(Ceil);
219 CHECK_RANKS_FOR(Clz);
220 CHECK_RANKS_FOR(Exp);
221 CHECK_RANKS_FOR(Floor);
222 CHECK_RANKS_FOR(Log);
223 CHECK_RANKS_FOR(LogicalNot);
224 CHECK_RANKS_FOR(Negate);
225 CHECK_RANKS_FOR(Reciprocal);
226 CHECK_RANKS_FOR(Rsqrt);
227 // all elementwise ternary operators:
228 CHECK_RANKS_FOR(Select);
229 // all comparison operators:
230 CHECK_RANKS_FOR(Equal);
231 CHECK_RANKS_FOR(Greater);
232 CHECK_RANKS_FOR(GreaterEqual);
233 // all reduction operators:
234 CHECK_RANKS_FOR(ReduceAll);
235 CHECK_RANKS_FOR(ReduceAny);
236 CHECK_RANKS_FOR(ReduceMax);
237 CHECK_RANKS_FOR(ReduceMin);
238 CHECK_RANKS_FOR(ReduceProd);
239 CHECK_RANKS_FOR(ReduceSum);
240 // all data layout operators:
241 CHECK_RANKS_FOR(Concat);
242 CHECK_RANKS_FOR(Pad);
243 CHECK_RANKS_FOR(Reshape);
244 CHECK_RANKS_FOR(Reverse);
245 CHECK_RANKS_FOR(Slice);
246 CHECK_RANKS_FOR(Tile);
247 CHECK_RANKS_FOR(Transpose);
248 // all type conversion operators:
249 CHECK_RANKS_FOR(Cast);
250 CHECK_RANKS_FOR(Rescale);
251 // all data nodes operators:
252 CHECK_RANKS_FOR(Const);
253 CHECK_RANKS_FOR(Identity);
254
255#undef CHECK_RANKS_FOR
256 return true;
257 }
258
259 // Pool Op: level check kernel/stride/pad values
260 template <typename T>
261 bool levelCheckPool(Operation *op) {
262 if (auto poolOp = dyn_cast<T>(op)) {
263 for (auto k : poolOp.getKernel()) {
264 if (!levelCheckKernel(op, v: k, checkDesc: "kernel <= MAX_KERNEL")) {
265 return false;
266 }
267 }
268 for (auto s : poolOp.getStride()) {
269 if (!levelCheckStride(op, v: s, checkDesc: "stride <= MAX_STRIDE")) {
270 return false;
271 }
272 }
273 for (auto p : poolOp.getPad()) {
274 if (!levelCheckKernel(op, v: p, checkDesc: "pad <= MAX_KERNEL")) {
275 return false;
276 }
277 }
278 }
279 return true;
280 }
281
282 // Conv Op: level check dilation/stride/pad values
283 template <typename T>
284 bool levelCheckConv(Operation *op) {
285 if (auto convOp = dyn_cast<T>(op)) {
286
287 for (auto k : convOp.getDilation()) {
288 if (!levelCheckKernel(op, v: k, checkDesc: "dilation <= MAX_KERNEL")) {
289 return false;
290 }
291 }
292 for (auto p : convOp.getPad()) {
293 if (!levelCheckKernel(op, v: p, checkDesc: "pad <= MAX_KERNEL")) {
294 return false;
295 }
296 }
297 for (auto s : convOp.getStride()) {
298 if (!levelCheckStride(op, v: s, checkDesc: "stride <= MAX_STRIDE")) {
299 return false;
300 }
301 }
302 auto dilation = convOp.getDilation();
303 if (ShapedType weightType =
304 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
305 auto shape = weightType.getShape();
306 if (isa<tosa::Conv2DOp>(op)) {
307 assert(shape.size() == 4);
308 assert(dilation.size() == 2);
309 if (!levelCheckKernel(op, v: dilation[0] * shape[1],
310 checkDesc: "dilation_y * KH <= MAX_KERNEL)") ||
311 !levelCheckKernel(op, v: dilation[1] * shape[2],
312 checkDesc: "dilation_x * KW <= MAX_KERNEL)"))
313 return false;
314 } else if (isa<tosa::Conv3DOp>(op)) {
315 assert(shape.size() == 5);
316 assert(dilation.size() == 3);
317 if (!levelCheckKernel(op, v: dilation[0] * shape[1],
318 checkDesc: "dilation_d * KD <= MAX_KERNEL)") ||
319 !levelCheckKernel(op, v: dilation[1] * shape[2],
320 checkDesc: "dilation_y * KH <= MAX_KERNEL)") ||
321 !levelCheckKernel(op, v: dilation[2] * shape[3],
322 checkDesc: "dilation_x * KW <= MAX_KERNEL)"))
323 return false;
324 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
325 assert(shape.size() == 4);
326 assert(dilation.size() == 2);
327 if (!levelCheckKernel(op, v: dilation[0] * shape[0],
328 checkDesc: "dilation_y * KH <= MAX_KERNEL)") ||
329 !levelCheckKernel(op, v: dilation[1] * shape[1],
330 checkDesc: "dilation_x * KW <= MAX_KERNEL)"))
331 return false;
332 }
333 }
334 }
335 return true;
336 }
337
338 // FFT op: level check H, W in input shape [N,H,W]
339 template <typename T>
340 bool levelCheckFFT(Operation *op) {
341 if (isa<T>(op)) {
342 for (auto v : op->getOperands()) {
343 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
344 auto shape = type.getShape();
345 assert(shape.size() == 3);
346 if (!levelCheckKernel(op, v: shape[1], checkDesc: "H <= MAX_KERNEL") ||
347 !levelCheckKernel(op, v: shape[2], checkDesc: "W <= MAX_KERNEL")) {
348 return false;
349 }
350 }
351 }
352 }
353 return true;
354 }
355
356 // TransposeConv2d op: level check kH/kW, outpad, and stride
357 bool levelCheckTransposeConv2d(Operation *op) {
358 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
359 if (ShapedType filterType =
360 dyn_cast<ShapedType>(transpose.getFilter().getType())) {
361 auto shape = filterType.getShape();
362 assert(shape.size() == 4);
363 // level check kernel sizes for kH and KW
364 if (!levelCheckKernel(op, v: shape[1], checkDesc: "KH <= MAX_KERNEL") ||
365 !levelCheckKernel(op, v: shape[2], checkDesc: "KW <= MAX_KERNEL")) {
366 return false;
367 }
368 }
369 for (auto p : transpose.getOutPad()) {
370 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
371 return false;
372 }
373 }
374 for (auto s : transpose.getStride()) {
375 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
376 return false;
377 }
378 }
379 }
380 return true;
381 }
382
383 // Resize op: level check max scales
384 bool levelCheckResize(Operation *op) {
385 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
386 auto scale = resize.getScale();
387 int16_t scaleYN = scale[0];
388 int16_t scaleYD = scale[1];
389 int16_t scaleXN = scale[2];
390 int16_t scaleXD = scale[3];
391 if (!levelCheckScale(op, scaleYN / scaleYD,
392 "scale_y_n/scale_y_d <= MAX_SCALE") ||
393 !levelCheckScale(op, scaleXN / scaleXD,
394 "scale_x_n/scale_x_d <= MAX_SCALE")) {
395 return false;
396 }
397 }
398 return true;
399 }
400
401 // configure profile and level values from pass options profileName and
402 // levelName
403 void configLevelAndProfile() {
404 tosaLevel = TOSA_LEVEL_NONE;
405 if (level == TosaLevelEnum::EightK) {
406 tosaLevel = TOSA_LEVEL_EIGHTK;
407 }
408 }
409
410 bool CheckVariable(Operation *op);
411 bool CheckVariableReadOrWrite(Operation *op);
412
413 bool isValidElementType(Type type);
414
415 SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
416 TosaLevel tosaLevel;
417 DenseMap<StringAttr, mlir::Type> variablesMap;
418};
419
420LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
421 if (tosaLevel == TOSA_LEVEL_NONE) {
422 // no need to do level checks
423 return success();
424 }
425
426 if (!levelCheckRanks(op)) {
427 return failure();
428 }
429
430 // additional level checks from spec 0.70
431 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
432 !levelCheckConv<tosa::Conv2DOp>(op) ||
433 !levelCheckConv<tosa::Conv3DOp>(op) ||
434 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
435 !levelCheckFFT<tosa::FFT2dOp>(op) ||
436 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
437 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
438 !levelCheckResize(op)) {
439 return failure();
440 }
441
442 return success();
443}
444
445inline bool CompatibleTypes(const mlir::Type &type,
446 const mlir::Type &declaredType) {
447 // for now, simply use type equality comparison
448 return type == declaredType;
449}
450
451bool TosaValidation::CheckVariable(Operation *op) {
452 if (isa<mlir::tosa::VariableOp>(op)) {
453 auto nameAttr = cast<mlir::StringAttr>(op->getAttr(name: "name"));
454
455 if (variablesMap.count(nameAttr)) {
456 op->emitOpError() << "name has already been declared";
457 return false;
458 }
459
460 auto typeAttr = cast<mlir::TypeAttr>(op->getAttr(name: "type"));
461 mlir::Type type = typeAttr.getValue();
462
463 variablesMap[nameAttr] = type;
464 }
465
466 return true;
467}
468
469bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
470 if (isa<mlir::tosa::VariableReadOp>(op) ||
471 isa<mlir::tosa::VariableWriteOp>(op)) {
472 auto nameAttr = cast<mlir::StringAttr>(op->getAttr(name: "name"));
473
474 if (!variablesMap.count(nameAttr)) {
475 op->emitOpError() << "name has not been declared";
476 return false;
477 }
478
479 auto varType = variablesMap[nameAttr];
480
481 for (auto v : op->getOperands()) {
482 auto type = v.getType();
483 if (!CompatibleTypes(type, varType)) {
484 op->emitOpError() << "operand type does not equal variable type";
485 return false;
486 }
487 }
488
489 for (auto v : op->getResults()) {
490 auto type = v.getType();
491 if (!CompatibleTypes(type, varType)) {
492 op->emitOpError() << "result type does not equal variable type";
493 return false;
494 }
495 }
496 }
497
498 return true;
499}
500
501LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
502 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
503 return failure();
504 }
505 return success();
506}
507
508bool TosaValidation::isValidElementType(Type type) {
509 if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
510 return false;
511 }
512 if (type.isF64()) {
513 return false;
514 }
515 if (auto intTy = dyn_cast<IntegerType>(type)) {
516 if (intTy.isUnsigned()) {
517 switch (intTy.getWidth()) {
518 case 8:
519 case 16:
520 return true;
521 default:
522 return false;
523 }
524 } else {
525 // Signless - treated as signed.
526 switch (intTy.getWidth()) {
527 case 1:
528 case 4:
529 case 8:
530 case 16:
531 case 32:
532 case 48:
533 case 64:
534 return true;
535 default:
536 return false;
537 }
538 }
539 return false;
540 }
541 return true;
542}
543
544void TosaValidation::runOnOperation() {
545 configLevelAndProfile();
546 getOperation().walk([&](Operation *op) {
547 for (Value operand : op->getOperands()) {
548 auto elementTy = getElementTypeOrSelf(val: operand);
549 if (!isValidElementType(type: elementTy)) {
550 op->emitOpError() << "is not profile-aligned: element type "
551 << elementTy << " is not legal";
552 return signalPassFailure();
553 }
554 }
555 for (Type resultTy : op->getResultTypes()) {
556 auto elementTy = getElementTypeOrSelf(resultTy);
557 if (!isValidElementType(elementTy)) {
558 op->emitOpError() << "is not profile-aligned: element type "
559 << elementTy << " is not legal";
560 return signalPassFailure();
561 }
562 }
563
564 // Some uses of TOSA rely on the constant operands of particular
565 // operations.
566 if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
567 signalPassFailure();
568
569 // do level checks
570 if (failed(result: applyLevelCheck(op)))
571 signalPassFailure();
572
573 // do variable type checks
574 if (failed(result: applyVariableCheck(op)))
575 signalPassFailure();
576 });
577}
578} // namespace
579

// Source: mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp