1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
17#include "mlir/Dialect/Quant/IR/Quant.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
20#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/TypeUtilities.h"
26#include "mlir/Interfaces/InferTypeOpInterface.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31#include <numeric>
32
33using namespace mlir;
34using namespace mlir::tosa;
35
36#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
37#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
38
39//===----------------------------------------------------------------------===//
40// Tosa dialect interface includes.
41//===----------------------------------------------------------------------===//
42
43#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47
48namespace {
49#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50
51//===----------------------------------------------------------------------===//
52// Dialect Function Inliner Interface.
53//===----------------------------------------------------------------------===//
54struct TosaInlinerInterface : public DialectInlinerInterface {
55 using DialectInlinerInterface::DialectInlinerInterface;
56
57 //===--------------------------------------------------------------------===//
58 // Analysis Hooks.
59 //===--------------------------------------------------------------------===//
60
61 /// All operations can be inlined by default.
62 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63 IRMapping &map) const final {
64 return true;
65 }
66
67 /// All regions with If and While parent operators can be inlined.
68 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69 IRMapping &map) const final {
70 return (isa<tosa::IfOp>(Val: dest->getParentOp()) ||
71 isa<tosa::WhileOp>(Val: dest->getParentOp()));
72 }
73};
74
75/// This class implements the bytecode interface for the Tosa dialect.
76struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77 TosaDialectBytecodeInterface(Dialect *dialect)
78 : BytecodeDialectInterface(dialect) {}
79
80 //===--------------------------------------------------------------------===//
81 // Attributes
82
83 Attribute readAttribute(DialectBytecodeReader &reader) const override {
84 return ::readAttribute(context: getContext(), reader);
85 }
86
87 LogicalResult writeAttribute(Attribute attr,
88 DialectBytecodeWriter &writer) const override {
89 return ::writeAttribute(attribute: attr, writer);
90 }
91
92 //===--------------------------------------------------------------------===//
93 // Types
94
95 Type readType(DialectBytecodeReader &reader) const override {
96 return ::readType(context: getContext(), reader);
97 }
98
99 LogicalResult writeType(Type type,
100 DialectBytecodeWriter &writer) const override {
101 return ::writeType(type, writer);
102 }
103
104 void writeVersion(DialectBytecodeWriter &writer) const final {
105 // TODO: Populate.
106 }
107
108 std::unique_ptr<DialectVersion>
109 readVersion(DialectBytecodeReader &reader) const final {
110 // TODO: Populate
111 reader.emitError(msg: "Dialect does not support versioning");
112 return nullptr;
113 }
114
115 LogicalResult upgradeFromVersion(Operation *topLevelOp,
116 const DialectVersion &version) const final {
117 return success();
118 }
119};
120
121} // namespace
122
123//===----------------------------------------------------------------------===//
124// TOSA control flow support.
125//===----------------------------------------------------------------------===//
126
127/// Returns the while loop body.
128SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129 return {&getBodyGraph()};
130}
131
132//===----------------------------------------------------------------------===//
133// TOSA variable operator support.
134//===----------------------------------------------------------------------===//
135
136static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
137 return to_vector(Range: llvm::map_range(C&: shape, F: [](int64_t dim) {
138 return dim == -1 ? ShapedType::kDynamic : dim;
139 }));
140}
141
142// returns type of variable op
143RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144 Type elementType = variableOp.getType();
145 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146 auto shape = convertToMlirShape(shape: to_vector(Range: varShapeAttr.getValues<int64_t>()));
147 return RankedTensorType::get(shape, elementType);
148}
149
150//===----------------------------------------------------------------------===//
151// Tosa dialect initialization.
152//===----------------------------------------------------------------------===//
153
154void TosaDialect::initialize() {
155 addTypes<
156#define GET_TYPEDEF_LIST
157#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158 >();
159 addOperations<
160#define GET_OP_LIST
161#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162 >();
163 addAttributes<
164#define GET_ATTRDEF_LIST
165#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166 >();
167 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168 declarePromisedInterfaces<
169 mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175 GreaterEqualOp, MatMulOp>();
176}
177
178Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179 Type type, Location loc) {
180 // Tosa dialect constants only support ElementsAttr unlike standard dialect
181 // constant which supports all attributes.
182 if (llvm::isa<shapeType>(Val: type) && llvm::isa<DenseIntElementsAttr>(Val: value)) {
183 return builder.create<tosa::ConstShapeOp>(
184 location: loc, args&: type, args: llvm::cast<DenseIntElementsAttr>(Val&: value));
185 }
186 if (llvm::isa<ElementsAttr>(Val: value))
187 return builder.create<tosa::ConstOp>(location: loc, args&: type,
188 args: llvm::cast<ElementsAttr>(Val&: value));
189 return nullptr;
190}
191
192//===----------------------------------------------------------------------===//
193// Parsers and printers
194//===----------------------------------------------------------------------===//
195
196namespace {
197
198ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199 DenseElementsAttr &varShapeAttr,
200 TypeAttr &typeAttr) {
201 if (auto shapedType = dyn_cast<ShapedType>(Val&: parsedType)) {
202 if (!shapedType.hasRank())
203 return parser.emitError(loc: parser.getCurrentLocation())
204 << "expected ranked type";
205
206 auto elementType = shapedType.getElementType();
207 typeAttr = TypeAttr::get(type: elementType);
208 ArrayRef<int64_t> shape = shapedType.getShape();
209 Builder builder(parser.getContext());
210 varShapeAttr = builder.getIndexTensorAttr(values: convertFromMlirShape(shape));
211 return success();
212 }
213 return parser.emitError(loc: parser.getCurrentLocation())
214 << "expected shaped type";
215}
216
217} // namespace
218
219// parses the optional initial value or type for a tosa variable
220// with initial value:
221// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222//
223// without initial value:
224// tosa.variable @name : tensor<1x8xf32>
225ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227 Attribute &initialValueAttr) {
228 if (succeeded(Result: parser.parseOptionalEqual())) {
229 if (failed(Result: parser.parseAttribute(result&: initialValueAttr))) {
230 return parser.emitError(loc: parser.getCurrentLocation())
231 << "expected attribute";
232 }
233 if (auto typedAttr = dyn_cast<TypedAttr>(Val&: initialValueAttr)) {
234 return getShapeAndElementType(parser, parsedType: typedAttr.getType(), varShapeAttr,
235 typeAttr);
236 }
237 return parser.emitError(loc: parser.getCurrentLocation())
238 << "expected Typed attr";
239 }
240
241 initialValueAttr = nullptr;
242 Type parsedType;
243 if (failed(Result: parser.parseColonType(result&: parsedType))) {
244 return parser.emitError(loc: parser.getCurrentLocation())
245 << "expected type after colon";
246 }
247 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248}
249
250void mlir::tosa::printVariableOpTypeOrInitialValue(
251 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252 TypeAttr typeAttr, Attribute initialValueAttr) {
253 bool needsSpace = false;
254 if (!dyn_cast_or_null<TypedAttr>(Val&: initialValueAttr)) {
255 auto shape =
256 convertToMlirShape(shape: to_vector(Range: varShapeAttr.getValues<int64_t>()));
257 Type elementType = typeAttr.getValue();
258 RankedTensorType tensorType =
259 RankedTensorType::get(shape: ArrayRef<int64_t>(shape), elementType);
260 auto tensorTypeAttr = TypeAttr::get(type: tensorType);
261 p << ": ";
262 p.printAttribute(attr: tensorTypeAttr);
263 needsSpace = true; // subsequent attr value needs a space separator
264 }
265 if (initialValueAttr) {
266 if (needsSpace)
267 p << ' ';
268 p << "= ";
269 p.printAttribute(attr: initialValueAttr);
270 }
271}
272
273//===----------------------------------------------------------------------===//
274// Tosa utilities.
275//===----------------------------------------------------------------------===//
276
/// Exact integer division.
///
/// Returns `lhs / rhs` when `rhs` divides `lhs` without remainder, and
/// `std::nullopt` otherwise. `std::nullopt` is also returned for the two
/// inputs on which the built-in `%` and `/` operators are undefined:
///  * `rhs == 0` (division by zero), and
///  * `lhs == INT64_MIN` with `rhs == -1` (the quotient is not representable).
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (rhs == 0)
    return std::nullopt;

  // Bit pattern of the most negative int64_t, compared in unsigned space so
  // that the check itself cannot overflow.
  constexpr uint64_t kInt64MinBits = uint64_t{1} << 63;
  if (rhs == -1 && static_cast<uint64_t>(lhs) == kInt64MinBits)
    return std::nullopt;

  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
282
283Type getStorageElementTypeOrSelf(Type type) {
284 auto srcType = getElementTypeOrSelf(type);
285 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(Val&: srcType))
286 srcType = quantType.getStorageType();
287 return srcType;
288}
289
290Type getStorageElementTypeOrSelf(Value value) {
291 return getStorageElementTypeOrSelf(type: value.getType());
292}
293
294static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
295 Value valZp, StringRef name) {
296 Type eType = getStorageElementTypeOrSelf(type: val.getType());
297 Type eZpType = getStorageElementTypeOrSelf(type: valZp.getType());
298
299 bool bothInts =
300 mlir::isa<IntegerType>(Val: eType) && mlir::isa<IntegerType>(Val: eZpType);
301 bool sameBitWidth =
302 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
303
304 if (!bothInts || !sameBitWidth) {
305 return op->emitOpError()
306 << "expected " << name << " and " << name
307 << "_zp to both be integer of the same bitwidth, but got " << eType
308 << " vs. " << eZpType;
309 }
310 return success();
311}
312
313// Create a pad-const const tensor with value of `val` of required data-type
314Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
315 Value src, int32_t val) {
316 const auto srcType = getElementTypeOrSelf(val: src);
317 const auto srcElemType = getStorageElementTypeOrSelf(value: src);
318 const auto padConstType = mlir::RankedTensorType::get(shape: {1}, elementType: srcType);
319 const auto padConstEType = mlir::RankedTensorType::get(shape: {1}, elementType: srcElemType);
320 const auto padConstAttr{
321 llvm::isa<FloatType>(Val: srcElemType)
322 ? DenseElementsAttr::get(type: padConstEType,
323 values: builder.getFloatAttr(type: srcElemType, value: val))
324 : DenseElementsAttr::get(type: padConstEType,
325 values: builder.getIntegerAttr(type: srcElemType, value: val))};
326 return builder.create<tosa::ConstOp>(location: loc, args: padConstType, args: padConstAttr);
327}
328
329//===----------------------------------------------------------------------===//
330// TOSA Operator Verifiers.
331//===----------------------------------------------------------------------===//
332
333template <typename T>
334static LogicalResult verifyConvOp(T op) {
335 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
336 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
337
338 auto inputEType = inputType.getElementType();
339 auto weightEType = weightType.getElementType();
340 auto biasEType =
341 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
342 auto resultEType =
343 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
344 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
345 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
346
347 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
348 inputEType = quantType.getStorageType();
349
350 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
351 weightEType = quantType.getStorageType();
352
353 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
354 biasEType = quantType.getStorageType();
355
356 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
357 resultEType = quantType.getStorageType();
358
359 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
360 // for now, only enforce bias element type == result element type for
361 // float types.
362 op.emitOpError(
363 "expect both bias and result to have same element type, got ")
364 << biasEType << " and " << resultEType;
365 return failure();
366 }
367
368 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
369 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
370 if (inputEType != weightEType) {
371 op.emitOpError(
372 "expect both input and weight to have same element type, got ")
373 << inputEType << " and " << weightEType;
374 return failure();
375 }
376 }
377
378 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
379 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
380
381 // Either both must be float or both non-float.
382 if (inputIsFloat != weightIsFloat) {
383 op.emitOpError(
384 "expect both input and weight to be float or not together, got ")
385 << inputEType << " and " << weightEType;
386 return failure();
387 }
388
389 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
390 if (inputEType != inputZpEType) {
391 return op.emitOpError("expect both input and its zero point are the same "
392 "element type, got ")
393 << inputEType << " and " << inputZpEType;
394 }
395
396 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
397 if (weightEType != weightZpEType) {
398 return op.emitOpError("expect both weight and its zero point are the same "
399 "element type, got ")
400 << weightEType << " and " << weightZpEType;
401 }
402
403 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
404 if (succeeded(Result: maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
405 return failure();
406
407 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
408 if (succeeded(Result: maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
409 return failure();
410
411 return success();
412}
413
414LogicalResult tosa::ConstOp::verify() {
415
416 auto attrType = llvm::dyn_cast<TensorType>(Val: getValuesAttr().getType());
417 auto outputType = llvm::dyn_cast<TensorType>(Val: getOutput().getType());
418
419 if (!attrType || !outputType) {
420 emitOpError(message: "expected tensors for attr/result type");
421 return failure();
422 }
423
424 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
425 Val: outputType.getElementType())) {
426 if (result.getStorageType() == attrType.getElementType())
427 return success();
428 }
429
430 if (attrType.getElementType() != outputType.getElementType()) {
431 emitOpError(message: "expected same attr/result element types");
432 return failure();
433 }
434
435 return success();
436}
437
438template <typename T>
439static LogicalResult verifyConvOpModes(T op) {
440 auto inputEType =
441 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
442
443 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
444 inputEType = quantType.getStorageType();
445
446 auto accType = op.getAccType();
447 if (inputEType.isInteger(8) && !accType.isInteger(32))
448 return op.emitOpError("accumulator type for i8 tensor is not i32");
449
450 if (inputEType.isInteger(16) && !accType.isInteger(48))
451 return op.emitOpError("accumulator type for i16 tensor is not i48");
452
453 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
454 return op.emitOpError("accumulator type for f8 tensor is not f16");
455
456 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
457 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
458
459 if (inputEType.isBF16() && !accType.isF32())
460 return op.emitOpError("accumulator type for bf16 tensor is not f32");
461
462 if (inputEType.isF32() && !accType.isF32())
463 return op.emitOpError("accumulator type for f32 tensor is not f32");
464
465 auto resultEType =
466 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
467
468 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
469 resultEType = quantType.getStorageType();
470
471 return success();
472}
473
474//===----------------------------------------------------------------------===//
475// ERROR_IF functions.
476// ERROR_IF is a predicate that must set an error if the condition holds.
477//===----------------------------------------------------------------------===//
478
479template <typename T>
480static LogicalResult verifyConvOpErrorIf(T op) {
481 llvm::ArrayRef<int64_t> padding = op.getPad();
482 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
483 return op.emitOpError("expect all padding values to be >= 0, got ")
484 << padding;
485
486 llvm::ArrayRef<int64_t> strides = op.getStride();
487 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
488 return op.emitOpError("expect all stride values to be >= 1, got ")
489 << strides;
490
491 llvm::ArrayRef<int64_t> dilations = op.getDilation();
492 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
493 return op.emitOpError("expect all dilation values to be >= 1, got ")
494 << dilations;
495
496 const RankedTensorType outputType =
497 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
498 if (!outputType)
499 // Skip following checks if output is not ranked
500 return success();
501
502 const RankedTensorType inputType =
503 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
504 const RankedTensorType weightType =
505 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
506
507 if (inputType && weightType) {
508 const auto verifyOutputSize =
509 [&op](const int64_t inputSize, const int64_t kernelSize,
510 const int64_t outputSize, const int64_t padBefore,
511 const int64_t padAfter, const int64_t stride,
512 const int64_t dilation, const llvm::StringRef dimName,
513 const llvm::StringRef dimAxis,
514 const llvm::StringRef padBeforeName,
515 const llvm::StringRef padAfterName) -> LogicalResult {
516 if (inputSize == ShapedType::kDynamic ||
517 kernelSize == ShapedType::kDynamic)
518 return success();
519
520 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
521
522 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
523 lhs: inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
524 rhs: stride);
525 if (!calculatedOutSizeMinusOne.has_value())
526 return op.emitOpError("expected input_")
527 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
528 << padAfterName << " - (kernel_" << dimName
529 << " - 1) * dilation_" << dimAxis
530 << " to be wholly divisible by stride_" << dimAxis << ", got ("
531 << inputSize << " - 1 + " << padBefore << " + " << padAfter
532 << " - (" << kernelSize << " - 1) * " << dilation << ") / "
533 << stride;
534
535 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
536 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
537 return op.emitOpError("calculated output ")
538 << dimName << " did not match expected: "
539 << "calculated=" << calculatedOutSize
540 << ", expected=" << outputSize;
541
542 return success();
543 };
544
545 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
546 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
547 if (failed(verifyOutputSize(
548 inputType.getDimSize(idx: 1), weightType.getDimSize(idx: 1),
549 outputType.getDimSize(idx: 1), padding[0], padding[1], strides[0],
550 dilations[0], "height", "y", "top", "bottom")))
551 return failure();
552
553 if (failed(verifyOutputSize(
554 inputType.getDimSize(idx: 2), weightType.getDimSize(idx: 2),
555 outputType.getDimSize(idx: 2), padding[2], padding[3], strides[1],
556 dilations[1], "width", "x", "left", "right")))
557 return failure();
558 }
559
560 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
561 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
562 if (failed(verifyOutputSize(
563 inputType.getDimSize(idx: 1), weightType.getDimSize(idx: 0),
564 outputType.getDimSize(idx: 1), padding[0], padding[1], strides[0],
565 dilations[0], "height", "y", "top", "bottom")))
566 return failure();
567
568 if (failed(verifyOutputSize(
569 inputType.getDimSize(idx: 2), weightType.getDimSize(idx: 1),
570 outputType.getDimSize(idx: 2), padding[2], padding[3], strides[1],
571 dilations[1], "width", "x", "left", "right")))
572 return failure();
573 }
574
575 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
576 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
577 if (failed(verifyOutputSize(
578 inputType.getDimSize(idx: 1), weightType.getDimSize(idx: 1),
579 outputType.getDimSize(idx: 1), padding[0], padding[1], strides[0],
580 dilations[0], "depth", "d", "front", "back")))
581 return failure();
582
583 if (failed(verifyOutputSize(
584 inputType.getDimSize(idx: 2), weightType.getDimSize(idx: 2),
585 outputType.getDimSize(idx: 2), padding[2], padding[3], strides[1],
586 dilations[1], "height", "y", "top", "bottom")))
587 return failure();
588
589 if (failed(verifyOutputSize(
590 inputType.getDimSize(idx: 3), weightType.getDimSize(idx: 3),
591 outputType.getDimSize(idx: 3), padding[4], padding[5], strides[2],
592 dilations[2], "width", "x", "left", "right")))
593 return failure();
594 }
595 }
596
597 const RankedTensorType biasType =
598 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
599 if (!biasType)
600 // Skip following checks if bias is not ranked
601 return success();
602
603 const int64_t biasChannels = biasType.getDimSize(idx: 0);
604 const int64_t outputChannels =
605 outputType.getDimSize(idx: outputType.getRank() - 1);
606 if (biasChannels == ShapedType::kDynamic ||
607 outputChannels == ShapedType::kDynamic)
608 // Skip following checks if biasChannels or outputChannels is dynamic dim
609 return success();
610
611 if (biasChannels != outputChannels && biasChannels != 1)
612 return op.emitOpError(
613 "bias channels expected to be equal to output channels (")
614 << outputChannels << ") or 1, got " << biasChannels;
615
616 return success();
617}
618
619// Verify whether same type and shape of the given two types.
620static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
621 StringRef name1, Type type2,
622 StringRef name2) {
623 auto shapeType1 = dyn_cast<ShapedType>(Val&: type1);
624 auto shapeType2 = dyn_cast<ShapedType>(Val&: type2);
625 if (!shapeType1 || !shapeType2)
626 return failure();
627
628 auto elemType1 = shapeType1.getElementType();
629 auto elemType2 = shapeType2.getElementType();
630 if (elemType1 != elemType2)
631 return op->emitOpError()
632 << "require same element type for " << name1 << " (" << elemType1
633 << ") and " << name2 << " (" << elemType2 << ")";
634
635 if (failed(Result: verifyCompatibleShape(type1, type2)))
636 return op->emitOpError()
637 << "require same shapes for " << name1 << " (" << type1 << ") and "
638 << name2 << " (" << type2 << ")";
639
640 return success();
641}
642
643// Verify whether same length, type, and shape of the given two tensor lists.
644static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
645 StringRef name1,
646 ValueRange list2,
647 StringRef name2) {
648 if (list1.size() != list2.size())
649 return op->emitOpError()
650 << "require same number of values in " << name1 << " ("
651 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
652
653 for (auto [type1, type2] :
654 llvm::zip_equal(t: list1.getTypes(), u: list2.getTypes())) {
655 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
656 return failure();
657 }
658
659 return success();
660}
661
662static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
663 ShapeAdaptor shapeAdaptor(type);
664 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
665 return success();
666
667 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
668}
669
670// Returns the first declaration point prior to this operation or failure if
671// not found.
672static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
673 StringRef symName) {
674 ModuleOp module = op->getParentOfType<ModuleOp>();
675 tosa::VariableOp varOp = nullptr;
676
677 // TODO: Adopt SymbolTable trait to Varible ops.
678 // Currently, the variable's definition point is searched via walk(),
679 // starting from the top-level ModuleOp and stopping at the point of use. Once
680 // TOSA control flow and variable extensions reach the complete state, may
681 // leverage MLIR's Symbol Table functionality to look up symbol and enhance
682 // the search to a TOSA specific graph traversal over the IR structure.
683 module.walk(callback: [&](Operation *tempOp) {
684 // Reach this op itself.
685 if (tempOp == op) {
686 return WalkResult::interrupt();
687 }
688
689 if (auto tosaOp = dyn_cast<tosa::VariableOp>(Val: tempOp)) {
690 if (symName == tosaOp.getName()) {
691 varOp = tosaOp;
692 return WalkResult::interrupt();
693 }
694 }
695
696 return WalkResult::advance();
697 });
698
699 if (varOp)
700 return varOp;
701
702 return failure();
703}
704
705template <typename T>
706static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
707 StringRef symName = op.getName();
708 FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
709 if (failed(Result: varOp))
710 return op->emitOpError("'")
711 << symName << "' has not been declared by 'tosa.variable'";
712
713 // Verify type and shape
714 auto variableType = getVariableType(variableOp: varOp.value());
715 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
716 "the input tensor")
717 .failed())
718 return failure();
719
720 return success();
721}
722
723// verify that inType and outType have same element types
724template <typename T>
725static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
726 auto inputType = llvm::dyn_cast<TensorType>(Val&: inType);
727 auto outputType = llvm::dyn_cast<TensorType>(Val&: outType);
728 if (!inputType) {
729 op.emitOpError("expect shaped tensor for input, got ") << inType;
730 return failure();
731 }
732 if (!outputType) {
733 op.emitOpError("expect shaped tensor for output, got ") << outType;
734 return failure();
735 }
736 auto inputElementType = inputType.getElementType();
737 auto outputElementType = outputType.getElementType();
738 auto inputQuantType =
739 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: inputElementType);
740 auto outputQuantType =
741 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: outputElementType);
742 if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
743 (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
744 inputElementType != outputElementType) {
745 // only check if both element types are int/index/float/UniformQuantized
746 // eg, not sure how to check quant::QuantizedType
747 // this happens in test_conv2d_q_grouped_convolution in
748 // tfl-to-tosa-pipeline.mlir
749 op.emitOpError("expect input and output to have same element type, got ")
750 << inputElementType << " and " << outputElementType;
751 return failure();
752 }
753 return success();
754}
755
756LogicalResult tosa::ArgMaxOp::verify() {
757 const ShapedType resultType = llvm::cast<ShapedType>(Val: getType());
758
759 // Ensure output is of 32-bit integer
760 if (const auto resultETy = resultType.getElementType();
761 !resultETy.isIntOrIndex())
762 return emitOpError(message: "result tensor is not of integer type");
763
764 const auto inputType = llvm::cast<ShapedType>(Val: getInput().getType());
765 if (!inputType.hasRank())
766 return success();
767
768 // Ensure axis is within the tensor rank
769 const int64_t axis = getAxisAttr().getInt();
770 if (((axis < 0) || axis >= inputType.getRank()))
771 return emitOpError(message: "specified axis is outside the rank of the tensor");
772
773 if (!resultType.hasRank())
774 return success();
775
776 const ArrayRef<int64_t> inputShape = inputType.getShape();
777 const ArrayRef<int64_t> outputShape = resultType.getShape();
778 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
779 expectedOutputShape.erase(CI: expectedOutputShape.begin() + axis);
780 if (failed(Result: verifyCompatibleShape(shape1: expectedOutputShape, shape2: outputShape)))
781 return emitOpError(message: "expected output shape '")
782 << expectedOutputShape << "', got '" << outputShape << "'";
783
784 return success();
785}
786
787template <typename T>
788static LogicalResult verifyPoolingOp(T op) {
789 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
790 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
791 return op.emitOpError("expect all kernel values to be >= 1, got ")
792 << kernel;
793
794 const llvm::ArrayRef<int64_t> strides = op.getStride();
795 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
796 return op.emitOpError("expect all stride values to be >= 1, got ")
797 << strides;
798
799 const llvm::ArrayRef<int64_t> padding = op.getPad();
800 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
801 return op.emitOpError("expect all padding values to be >= 0, got ")
802 << padding;
803
804 // Padding must be less than kernel size to avoid a divide-by-zero
805 const int64_t kernelX = kernel[1];
806 const int64_t padLeft = padding[2];
807 const int64_t padRight = padding[3];
808 if (padRight >= kernelX || padLeft >= kernelX)
809 return op.emitOpError("expected left/right padding to be less than the "
810 "width of the kernel, got pad_left=")
811 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
812
813 const int64_t kernelY = kernel[0];
814 const int64_t padTop = padding[0];
815 const int64_t padBottom = padding[1];
816 if (padTop >= kernelY || padBottom >= kernelY)
817 return op.emitOpError("expected top/bottom padding to be less than the "
818 "height of the kernel, got pad_top=")
819 << padTop << ", pad_bottom=" << padBottom
820 << ", kernel_y=" << kernelY;
821
822 const auto inputType =
823 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
824 const auto outputType =
825 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
826 if (!inputType || !outputType)
827 return success();
828
829 const auto verifyOutputSize =
830 [&op](const int64_t inputSize, const int64_t outputSize,
831 const int64_t kernelSize, const int64_t strideSize,
832 const int64_t padBefore, const int64_t padAfter,
833 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
834 const llvm::StringRef padBeforeName,
835 const llvm::StringRef padAfterName) -> LogicalResult {
836 if (ShapedType::isDynamic(dValue: inputSize))
837 return success();
838
839 const std::optional<int64_t> calculatedOutSizeMinusOne =
840 idivCheck(lhs: inputSize + padBefore + padAfter - kernelSize, rhs: strideSize);
841 if (!calculatedOutSizeMinusOne.has_value())
842 return op.emitOpError("expected input_")
843 << dimName << " + pad_" << padBeforeName << " + pad_"
844 << padAfterName << " - kernel_" << dimAxis
845 << " to be wholly divisible by stride_" << dimAxis << ", got ("
846 << inputSize << " + " << padBefore << " + " << padAfter << " - "
847 << kernelSize << ") / " << strideSize;
848
849 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
850 if (ShapedType::isStatic(dValue: outputSize) && calculatedOutSize != outputSize)
851 return op.emitOpError("calculated output ")
852 << dimName << " did not match expected: "
853 << "calculated=" << calculatedOutSize
854 << ", expected=" << outputSize;
855
856 return success();
857 };
858
859 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
860 kernel[0], strides[0], padding[0], padding[1],
861 "height", "y", "top", "bottom")))
862 return failure();
863
864 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
865 kernel[1], strides[1], padding[2], padding[3],
866 "width", "x", "left", "right")))
867 return failure();
868
869 return success();
870}
871
872LogicalResult tosa::AvgPool2dOp::verify() {
873 if (failed(Result: verifyPoolingOp(op: *this)))
874 return failure();
875
876 const Type inputETy = getStorageElementTypeOrSelf(type: getInput().getType());
877 const Type resultETy = getStorageElementTypeOrSelf(type: getOutput().getType());
878 const Type inputZpETy = getStorageElementTypeOrSelf(type: getInputZp().getType());
879 const Type outputZpETy = getStorageElementTypeOrSelf(type: getOutputZp().getType());
880
881 auto accType = getAccType();
882 if (llvm::isa<IntegerType>(Val: inputETy) && !accType.isInteger(width: 32))
883 return emitOpError(message: "accumulator type for integer tensor is not i32");
884
885 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
886 return emitOpError(message: "accumulator type for f16 tensor is not f16/f32");
887
888 if (inputETy.isBF16() && !accType.isF32())
889 return emitOpError(message: "accumulator type for bf16 tensor is not f32");
890
891 if (inputETy.isF32() && !accType.isF32())
892 return emitOpError(message: "accumulator type for f32 tensor is not f32");
893
894 if (inputETy != inputZpETy)
895 return emitOpError(message: "expect both input and its zero point are the same "
896 "element type, got ")
897 << inputETy << " and " << inputZpETy;
898
899 if (resultETy != outputZpETy)
900 return emitOpError(message: "expect both output and its zero point are the same "
901 "element type, got ")
902 << resultETy << " and " << outputZpETy;
903
904 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
905 if (succeeded(Result: maybeIZp) && verifyInputZeroPoint(zp: *maybeIZp).failed())
906 return failure();
907
908 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
909 if (succeeded(Result: maybeOZp) && verifyOutputZeroPoint(zp: *maybeOZp).failed())
910 return failure();
911
912 return success();
913}
914
915LogicalResult tosa::ClampOp::verify() {
916 mlir::Type inputETy =
917 llvm::cast<ShapedType>(Val: getInput().getType()).getElementType();
918 if (auto quantType =
919 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: inputETy)) {
920 inputETy = quantType.getStorageType();
921 }
922 mlir::Type outputETy =
923 llvm::cast<ShapedType>(Val: getOutput().getType()).getElementType();
924 if (auto quantType =
925 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: outputETy)) {
926 outputETy = quantType.getStorageType();
927 }
928 if (inputETy != outputETy)
929 return emitOpError(message: "input/output element types are incompatible.");
930
931 auto maxValAttr = getMaxValAttr();
932 auto minValAttr = getMinValAttr();
933
934 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
935
936 if (inputETy.isInteger(width: dataTypeBitWidth)) {
937 // if input datatype is integer, check that the min_val/max_val attributes
938 // are integer attributes, and that their type is the same as the input's
939 // datatype
940 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(Val&: maxValAttr);
941 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(Val&: minValAttr);
942 if (!intMaxValAttr || !intMinValAttr ||
943 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
944 (intMaxValAttr.getType() != inputETy))
945 return emitOpError(message: "min/max attributes types are incompatible with "
946 "input/output element types.");
947
948 const bool isUnsigned = cast<IntegerType>(Val&: inputETy).isUnsigned();
949 const APInt minVal = intMinValAttr.getValue();
950 const APInt maxVal = intMaxValAttr.getValue();
951 if (isUnsigned ? maxVal.ult(RHS: minVal) : maxVal.slt(RHS: minVal))
952 return emitOpError(message: "expected min_val <= max_val, got min_val=")
953 << minValAttr << ", max_val=" << maxValAttr;
954 } else {
955 // otherwise, input datatype is float, check that the min_val/max_val
956 // attributes share the same type and that their type is the same as the
957 // input's datatype
958 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(Val&: maxValAttr);
959 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(Val&: minValAttr);
960 if (!floatMaxValAttr || !floatMinValAttr ||
961 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
962 (floatMaxValAttr.getType() != inputETy))
963 return emitOpError(message: "min/max attributes types are incompatible with "
964 "input/output element types.");
965
966 const APFloat minVal = floatMinValAttr.getValue();
967 const APFloat maxVal = floatMaxValAttr.getValue();
968 if (minVal.isNaN() || maxVal.isNaN())
969 return emitOpError(message: "min/max attributes should not be 'NaN', got min_val=")
970 << minValAttr << ", max_val=" << maxValAttr;
971
972 if (maxVal < minVal)
973 return emitOpError(message: "expected min_val <= max_val, got min_val=")
974 << minValAttr << ", max_val=" << maxValAttr;
975 }
976
977 return success();
978}
979
980//===----------------------------------------------------------------------===//
981// TOSA Operator Quantization Builders.
982//===----------------------------------------------------------------------===//
983
984/// This builder is called on all convolution operators except TransposeConv,
985/// which has specialized output shape semantics. The builder also defines the
986/// bitwidth of the output given the bit width of the input & weight content.
987static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
988 Type outputType, Value input, Value weight,
989 Value bias, DenseI64ArrayAttr pad,
990 DenseI64ArrayAttr stride,
991 DenseI64ArrayAttr dilation,
992 TypeAttr accType) {
993 auto zps = createZPsAsConst(builder, input, weight);
994 result.addOperands(newOperands: {input, weight, bias, zps.first, zps.second});
995 result.addAttribute(name: "pad", attr: pad);
996 result.addAttribute(name: "stride", attr: stride);
997 result.addAttribute(name: "dilation", attr: dilation);
998 result.addAttribute(name: "acc_type", attr: accType);
999 Type finalOutputType = outputType;
1000 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1001 if (quantAttr) {
1002 finalOutputType =
1003 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1004 }
1005 result.addTypes(newTypes: finalOutputType);
1006}
1007
1008/// Handles tosa.transpose_conv2d which has outpad and output shape
1009/// attributes.
1010static void
1011buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1012 Type outputType, Value input, Value weight,
1013 Value bias, DenseI64ArrayAttr outpad,
1014 DenseI64ArrayAttr stride, TypeAttr accType) {
1015 auto zps = createZPsAsConst(builder, input, weight);
1016 result.addOperands(newOperands: {input, weight, bias, zps.first, zps.second});
1017 result.addAttribute(name: "out_pad", attr: outpad);
1018 result.addAttribute(name: "stride", attr: stride);
1019 result.addAttribute(name: "acc_type", attr: accType);
1020 Type finalOutputType = outputType;
1021 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1022 if (quantAttr) {
1023 finalOutputType =
1024 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1025 }
1026 result.addTypes(newTypes: finalOutputType);
1027}
1028
1029/// The tosa.matmul op is also intended to be generated where a fully_connected
1030/// op must be constructed where the weight is not a constant. In this case,
1031/// the fully_connected op must be expressed using matmul.
1032/// TODO: Add link to the leglization document explaining this.
1033static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1034 OperationState &result, Type outputType,
1035 Value a, Value b) {
1036 auto zps = createZPsAsConst(builder, input: a, weight: b);
1037 result.addOperands(newOperands: {a, b, zps.first, zps.second});
1038
1039 Type finalOutputType{outputType};
1040 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1041 auto eType = getStorageElementTypeOrSelf(type: a.getType());
1042 auto inputBits = eType.getIntOrFloatBitWidth();
1043
1044 auto outputShapedType = llvm::dyn_cast<ShapedType>(Val&: outputType);
1045 assert(outputShapedType && "Output must be a shaped type");
1046
1047 IntegerType accElementType;
1048 if (inputBits == 16)
1049 accElementType = builder.getIntegerType(width: 48);
1050 else
1051 accElementType = builder.getI32Type();
1052
1053 finalOutputType = outputShapedType.clone(elementType: accElementType);
1054 }
1055 result.addTypes(newTypes: finalOutputType);
1056}
1057
1058/// Both the tosa.avg_pool2d and unary ops use the same
1059/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1060/// has additional parameters not part of the unary ops.
1061static void
1062buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1063 Type outputType, Value input,
1064 DenseArrayAttr kernel, DenseArrayAttr stride,
1065 DenseArrayAttr pad, TypeAttr accType) {
1066 const Location loc{result.location};
1067 int64_t inputZp{0};
1068 int64_t outputZp{0};
1069
1070 if (auto quantAttr =
1071 buildUnaryOpQuantizationAttr(builder, input, outputRawType: outputType)) {
1072 inputZp = quantAttr.getInputZp();
1073 outputZp = quantAttr.getOutputZp();
1074 }
1075 const std::optional<Value> inputZpOp =
1076 createZeroPointTensor(builder, loc, srcElemType: input.getType(), zp: inputZp);
1077 if (!inputZpOp) {
1078 (void)emitError(
1079 loc,
1080 message: "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1081 }
1082 const std::optional<Value> outputZpOp =
1083 createZeroPointTensor(builder, loc, srcElemType: outputType, zp: outputZp);
1084 if (!outputZpOp) {
1085 (void)emitError(loc, message: "Failed to create output zero point tensor for "
1086 "quantized AVG_POOL2D op");
1087 }
1088
1089 if (inputZpOp && outputZpOp) {
1090 result.addOperands(newOperands: {input, inputZpOp.value(), outputZpOp.value()});
1091 } else {
1092 // failed to create one or more zero points above: just add input as
1093 // operands this will trigger error in building the op because of missing
1094 // zero points
1095 result.addOperands(newOperands: {input});
1096 }
1097 result.addAttribute(name: "kernel", attr: kernel);
1098 result.addAttribute(name: "stride", attr: stride);
1099 result.addAttribute(name: "pad", attr: pad);
1100 result.addAttribute(name: "acc_type", attr: accType);
1101 result.types.push_back(Elt: outputType);
1102}
1103
1104/// This builder is called on single-parameter negate operator
1105/// to construct input and output zero points based on their
1106/// types.
1107static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1108 OperationState &result, Type outputType,
1109 Value input) {
1110 const Location loc{result.location};
1111 int64_t input1Zp{0};
1112 int64_t outputZp{0};
1113 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputRawType: outputType);
1114 if (quantAttr) {
1115 input1Zp = quantAttr.getInputZp();
1116 outputZp = quantAttr.getOutputZp();
1117 }
1118 const std::optional<Value> input1ZpOp =
1119 createZeroPointTensor(builder, loc, srcElemType: input.getType(), zp: input1Zp);
1120 if (!input1ZpOp) {
1121 (void)emitError(
1122 loc, message: "Failed to create input1 zero point for quantized NEGATE op");
1123 }
1124
1125 const std::optional<Value> outputZpOp =
1126 createZeroPointTensor(builder, loc, srcElemType: input.getType(), zp: outputZp);
1127 if (!outputZpOp) {
1128 (void)emitError(
1129 loc, message: "Failed to create output zero point for quantized NEGATE op");
1130 }
1131
1132 if (input1ZpOp && outputZpOp) {
1133 result.addOperands(newOperands: {input, input1ZpOp.value(), outputZpOp.value()});
1134 } else {
1135 // failed to create one or more zero points above: just add input as
1136 // operands. This will trigger error in building the op because of
1137 // missing zero points
1138 result.addOperands(newOperands: {input});
1139 }
1140
1141 result.types.push_back(Elt: outputType);
1142}
1143
1144/// This builder is called on TOSA pad operator that needs to create its own
1145/// OptionalAttr quantization_attr parameter to scale the padding values
1146/// correctly. No pad_const is interpreted as zero-padding.
1147static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1148 Type outputType, Value input,
1149 Value paddings) {
1150 const Location loc{result.location};
1151 int32_t zp{0};
1152 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1153 if (quantAttr) {
1154 zp = static_cast<int32_t>(quantAttr.getInputZp());
1155 }
1156 const auto padConstOp{createPadConstTensor(builder, loc, src: input, val: zp)};
1157 result.addOperands(newOperands: {input, paddings, padConstOp});
1158 result.types.push_back(Elt: outputType);
1159}
1160
1161static void buildVariableOp(OpBuilder &builder, OperationState &result,
1162 StringRef name, Type variableType,
1163 Attribute initialValue) {
1164 const Location loc{result.location};
1165 auto nameAttr = builder.getStringAttr(bytes: name);
1166
1167 auto shapedType = dyn_cast<ShapedType>(Val&: variableType);
1168 if (!shapedType) {
1169 (void)emitError(loc, message: "variable type must be a shaped type");
1170 return;
1171 }
1172 if (!shapedType.hasRank()) {
1173 (void)emitError(loc, message: "variable type must be a ranked type");
1174 return;
1175 }
1176
1177 auto elementType = shapedType.getElementType();
1178 auto elementTypeAttr = TypeAttr::get(type: elementType);
1179 ArrayRef<int64_t> shape = shapedType.getShape();
1180 auto varShapeAttr = builder.getIndexTensorAttr(values: convertFromMlirShape(shape));
1181
1182 result.addAttribute(name: "name", attr: nameAttr);
1183 result.addAttribute(name: "var_shape", attr: varShapeAttr);
1184 result.addAttribute(name: "type", attr: elementTypeAttr);
1185 result.addAttribute(name: "initial_value", attr: initialValue);
1186}
1187
1188//===----------------------------------------------------------------------===//
1189// TOSA Operator Return Type Inference.
1190//===----------------------------------------------------------------------===//
1191
1192static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1193 SmallVector<int64_t> &outShape) {
1194 int64_t outRank = 0;
1195 for (int i = 0, e = operands.size(); i != e; ++i) {
1196 auto shape = operands.getShape(index: i);
1197 if (!shape.hasRank()) {
1198 // TODO(jennik): Update function to have better case handling for
1199 // invalid operands and for ranked tensors.
1200 return failure();
1201 }
1202 outRank = std::max<int64_t>(a: outRank, b: shape.getRank());
1203 }
1204
1205 outShape.resize(N: outRank, NV: 1);
1206
1207 for (int i = 0, e = operands.size(); i != e; ++i) {
1208 auto shape = operands.getShape(index: i);
1209 auto rankDiff = outShape.size() - shape.getRank();
1210
1211 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1212 auto dim1 = outShape[i + rankDiff];
1213 auto dim2 = shape.getDimSize(index: i);
1214 auto resolvedDim = dim1;
1215
1216 if (dim1 == 1) {
1217 resolvedDim = dim2;
1218 } else if (dim2 == 1) {
1219 resolvedDim = dim1;
1220 } else if (dim1 != dim2) {
1221 return failure();
1222 }
1223 outShape[i + rankDiff] = resolvedDim;
1224 }
1225 }
1226
1227 return success();
1228}
1229
1230LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1231 MLIRContext *context, ::std::optional<Location> location,
1232 ArgMaxOp::Adaptor adaptor,
1233 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1234 ShapeAdaptor inputShape(adaptor.getInput().getType());
1235 IntegerAttr axis = adaptor.getProperties().axis;
1236 int32_t axisVal = axis.getValue().getSExtValue();
1237
1238 if (!inputShape.hasRank()) {
1239 inferredReturnShapes.push_back(Elt: ShapedTypeComponents());
1240 return success();
1241 }
1242
1243 SmallVector<int64_t> outShape;
1244 outShape.reserve(N: inputShape.getRank() - 1);
1245 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1246 if (i == axisVal)
1247 continue;
1248 outShape.push_back(Elt: inputShape.getDimSize(index: i));
1249 }
1250
1251 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outShape));
1252 return success();
1253}
1254
1255LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1256 MLIRContext *context, ::std::optional<Location> location,
1257 RFFT2dOp::Adaptor adaptor,
1258 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1259 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1260
1261 if (!inputShape.hasRank())
1262 return failure();
1263
1264 llvm::SmallVector<int64_t> outputShape;
1265 outputShape.resize(N: 3, NV: ShapedType::kDynamic);
1266 outputShape[0] = inputShape.getDimSize(index: 0);
1267 outputShape[1] = inputShape.getDimSize(index: 1);
1268 int64_t inWidth = inputShape.getDimSize(index: 2);
1269
1270 // Note that we can support this calculation symbolically
1271 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1272 if (inWidth != ShapedType::kDynamic)
1273 outputShape[2] = inWidth / 2 + 1;
1274
1275 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
1276 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
1277
1278 return success();
1279}
1280
1281static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1282 const llvm::StringRef dimName) {
1283 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1284 if (!isPowerOfTwo)
1285 return op->emitOpError(message: "expected ")
1286 << dimName << " to be a power of two, got " << dimSize;
1287
1288 return success();
1289}
1290
1291LogicalResult tosa::RFFT2dOp::verify() {
1292 const auto outputTypes = getResultTypes();
1293 if (failed(Result: verifyCompatibleShapes(types: outputTypes)))
1294 return emitOpError(message: "expected output shapes to match, got ") << outputTypes;
1295
1296 const auto inputType =
1297 llvm::dyn_cast<RankedTensorType>(Val: getInputReal().getType());
1298 if (!inputType)
1299 return success();
1300
1301 const int64_t height = inputType.getDimSize(idx: 1);
1302 if (ShapedType::isStatic(dValue: height) &&
1303 failed(Result: verifyDimIsPowerOfTwo(op: *this, dimSize: height, dimName: "height")))
1304 return failure();
1305
1306 const int64_t width = inputType.getDimSize(idx: 2);
1307 if (ShapedType::isStatic(dValue: width) &&
1308 failed(Result: verifyDimIsPowerOfTwo(op: *this, dimSize: width, dimName: "width")))
1309 return failure();
1310
1311 const auto outputType = llvm::dyn_cast<RankedTensorType>(Val: outputTypes[0]);
1312 if (!outputType)
1313 return success();
1314
1315 // Batch and height input/output dimensions should match
1316 if (failed(Result: verifyCompatibleShape(shape1: inputType.getShape().drop_back(),
1317 shape2: outputType.getShape().drop_back())))
1318 return emitOpError(message: "expected batch and height dimensions of input/output "
1319 "to match, got input=")
1320 << inputType << " output=" << outputType;
1321
1322 // Output width dimension expected to be input_width / 2 + 1
1323 const int64_t outputWidth = outputType.getDimSize(idx: 2);
1324 if (ShapedType::isStatic(dValue: width) && ShapedType::isStatic(dValue: outputWidth) &&
1325 (outputWidth != (width / 2) + 1))
1326 return emitOpError(
1327 message: "expected output width to be equal to input_width / 2 + 1, got ")
1328 << outputWidth;
1329
1330 return success();
1331}
1332
1333LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1334 MLIRContext *context, ::std::optional<Location> location,
1335 FFT2dOp::Adaptor adaptor,
1336 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1337 inferredReturnShapes.push_back(
1338 Elt: ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1339 inferredReturnShapes.push_back(
1340 Elt: ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1341 return success();
1342}
1343
1344LogicalResult tosa::FFT2dOp::verify() {
1345 const auto inputRealType =
1346 llvm::dyn_cast<RankedTensorType>(Val: getInputReal().getType());
1347 const auto inputImagType =
1348 llvm::dyn_cast<RankedTensorType>(Val: getInputImag().getType());
1349 if (!inputRealType || !inputImagType)
1350 return success();
1351
1352 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1353 return ShapedType::isDynamic(dValue: a) ? a : b;
1354 };
1355
1356 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(idx: 1),
1357 inputImagType.getDimSize(idx: 1));
1358 if (ShapedType::isStatic(dValue: height) &&
1359 failed(Result: verifyDimIsPowerOfTwo(op: *this, dimSize: height, dimName: "height")))
1360 return failure();
1361
1362 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(idx: 2),
1363 inputImagType.getDimSize(idx: 2));
1364 if (ShapedType::isStatic(dValue: width) &&
1365 failed(Result: verifyDimIsPowerOfTwo(op: *this, dimSize: width, dimName: "width")))
1366 return failure();
1367
1368 return success();
1369}
1370
1371LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1372 MLIRContext *context, ::std::optional<Location> location,
1373 ConcatOp::Adaptor adaptor,
1374 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1375 // Infer all dimension sizes by reducing based on inputs.
1376 const Properties &prop = adaptor.getProperties();
1377 int32_t axis = prop.axis.getValue().getSExtValue();
1378 llvm::SmallVector<int64_t> outputShape;
1379 bool hasRankedInput = false;
1380 for (auto operand : adaptor.getOperands()) {
1381 ShapeAdaptor operandShape(operand.getType());
1382 if (!operandShape.hasRank())
1383 continue;
1384
1385 // Copy the Operand's rank.
1386 if (!hasRankedInput)
1387 outputShape.resize(N: operandShape.getRank(), NV: ShapedType::kDynamic);
1388
1389 // Copy shapes until the dim is non-dynamic.
1390 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1391 if (i == axis || operandShape.isDynamicDim(index: i))
1392 continue;
1393 if (outputShape[i] == ShapedType::kDynamic)
1394 outputShape[i] = operandShape.getDimSize(index: i);
1395 if (outputShape[i] != operandShape.getDimSize(index: i))
1396 return emitOptionalError(loc: location,
1397 args: "Cannot concat tensors with different sizes"
1398 " on the non-axis dimension ",
1399 args&: i);
1400 }
1401
1402 hasRankedInput = true;
1403 }
1404
1405 if (adaptor.getInput1().empty())
1406 return failure();
1407
1408 Type inputType =
1409 llvm::cast<TensorType>(Val: adaptor.getInput1().getType()[0]).getElementType();
1410 if (!hasRankedInput) {
1411 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(inputType));
1412 return success();
1413 }
1414
1415 // Determine the dimension size along the concatenation axis.
1416 int64_t concatDimSize = 0;
1417 for (auto operand : adaptor.getOperands()) {
1418 ShapeAdaptor operandShape(operand.getType());
1419
1420 // We need to know the length of the concatenation axis of all inputs to
1421 // determine the dimension size of the output shape.
1422 if (!operandShape.hasRank() || operandShape.isDynamicDim(index: axis)) {
1423 concatDimSize = ShapedType::kDynamic;
1424 break;
1425 }
1426
1427 concatDimSize += operandShape.getDimSize(index: axis);
1428 }
1429
1430 outputShape[axis] = concatDimSize;
1431
1432 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape, inputType));
1433 return success();
1434}
1435
1436LogicalResult tosa::ConcatOp::verify() {
1437 // check that each input has same element type as output
1438 auto outType = getOutput().getType();
1439 const Operation::operand_range inputList = getInput1();
1440
1441 // Check there is at least one input
1442 if (inputList.empty())
1443 return emitOpError(message: "expect at least one input");
1444
1445 if (!llvm::all_of(Range: inputList, P: [&](auto input) {
1446 return succeeded(verifySameElementTypes(
1447 *this, /* inType = */ input.getType(), outType));
1448 })) {
1449 return failure();
1450 }
1451
1452 const int32_t axis = getAxis();
1453 ShapeAdaptor firstRankedInputShape = nullptr;
1454 for (const auto &input : inputList) {
1455 const Type inputType = input.getType();
1456 ShapeAdaptor currShape(inputType);
1457 if (currShape.hasRank()) {
1458 firstRankedInputShape = currShape;
1459 // Check axis is in expected range
1460 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1461 return emitOpError(message: "expect axis to be within range 0 < axis < "
1462 "rank(input1[firstRankedTensorIdx]), got ")
1463 << axis;
1464 break;
1465 }
1466 }
1467
1468 const auto allOperandsHasRank = [](const Value input) {
1469 return ShapeAdaptor(input.getType()).hasRank();
1470 };
1471 if (llvm::all_of(Range: inputList, P: allOperandsHasRank)) {
1472 const int64_t firstInputRank = firstRankedInputShape.getRank();
1473
1474 for (const auto &[index, input] : llvm::enumerate(First: inputList.drop_front())) {
1475 const ShapeAdaptor inputShape(input.getType());
1476 const int64_t inputRank = inputShape.getRank();
1477 const size_t operandNum = index + 1;
1478
1479 // Check that each operand has the same rank
1480 if (inputRank != firstInputRank)
1481 return emitOpError(
1482 message: "expect all operands to have the same rank, but got ")
1483 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1484 << operandNum;
1485
1486 // Check non-axis dims match
1487 for (int i = 0; i < inputRank; i++) {
1488 const int64_t inputDim = inputShape.getDimSize(index: i);
1489 const int64_t firstInputDim = firstRankedInputShape.getDimSize(index: i);
1490 if (i == axis || firstRankedInputShape.isDynamicDim(index: i) ||
1491 inputShape.isDynamicDim(index: i))
1492 continue;
1493 if (inputDim != firstInputDim)
1494 return emitOpError(message: "expect all operand shapes to have the same sizes "
1495 "on non-axis dimensions, but got ")
1496 << inputDim << " vs " << firstInputDim << " at index " << i
1497 << " on operands 0 and " << operandNum;
1498 }
1499 }
1500
1501 // ERROR_IF(axis_sum != shape[axis]);
1502 int64_t axisSum = 0;
1503 for (const auto &input : inputList) {
1504 const ShapeAdaptor inputShape(input.getType());
1505 if (inputShape.isDynamicDim(index: axis)) {
1506 // make axisSum negative to indicate invalid value
1507 axisSum = -1;
1508 break;
1509 }
1510 axisSum += inputShape.getDimSize(index: axis);
1511 }
1512 const ShapeAdaptor outputShape(outType);
1513 if (axisSum >= 0 && outputShape.hasRank() &&
1514 !outputShape.isDynamicDim(index: axis) &&
1515 axisSum != outputShape.getDimSize(index: axis))
1516 return emitOpError(message: "requires sum of axis dimensions of input1 "
1517 "equal to output axis dimension, got ")
1518 << axisSum << " and " << outputShape.getDimSize(index: axis);
1519 }
1520
1521 return success();
1522}
1523
1524LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1525 MLIRContext *context, ::std::optional<Location> location,
1526 ValueShapeRange operands, DictionaryAttr attributes,
1527 OpaqueProperties properties, RegionRange regions,
1528 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1529 auto elementType = IntegerType::get(context, /*width=*/1);
1530
1531 llvm::SmallVector<int64_t> outShape;
1532 if (resolveBroadcastShape(operands, outShape).failed()) {
1533 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(elementType));
1534 return success();
1535 }
1536
1537 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outShape, elementType));
1538 return success();
1539}
1540
1541bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1542 if (l.size() != r.size() || l.size() != 1)
1543 return false;
1544 return succeeded(Result: verifyCompatibleShape(type1: l[0], type2: r[0]));
1545}
1546
1547LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1548 MLIRContext *context, ::std::optional<Location> location,
1549 MatMulOp::Adaptor adaptor,
1550 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1551 ShapeAdaptor lhsShape(adaptor.getA().getType());
1552 ShapeAdaptor rhsShape(adaptor.getB().getType());
1553
1554 // All shapes are dynamic.
1555 SmallVector<int64_t> outShape;
1556 outShape.resize(N: 3, NV: ShapedType::kDynamic);
1557
1558 if (lhsShape.hasRank()) {
1559 outShape[0] = lhsShape.getDimSize(index: 0);
1560 outShape[1] = lhsShape.getDimSize(index: 1);
1561 }
1562
1563 if (rhsShape.hasRank()) {
1564 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(index: 0)
1565 : outShape[0];
1566 outShape[2] = rhsShape.getDimSize(index: 2);
1567 }
1568
1569 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outShape));
1570 return success();
1571}
1572
1573LogicalResult MatMulOp::verify() {
1574 auto aType = llvm::dyn_cast<ShapedType>(Val: getA().getType());
1575 auto bType = llvm::dyn_cast<ShapedType>(Val: getB().getType());
1576
1577 // Must be shaped tensor types
1578 if (!aType)
1579 return emitOpError(message: "expect a shaped tensor for input a, got ")
1580 << getA().getType();
1581
1582 if (!bType)
1583 return emitOpError(message: "expect a shaped tensor for input b, got ")
1584 << getB().getType();
1585
1586 auto aElementType = aType.getElementType();
1587 auto bElementType = bType.getElementType();
1588
1589 auto aQuantizedEType =
1590 llvm::dyn_cast<quant::UniformQuantizedType>(Val&: aElementType);
1591 auto bQuantizedEType =
1592 llvm::dyn_cast<quant::UniformQuantizedType>(Val&: bElementType);
1593
1594 if (aQuantizedEType || bQuantizedEType) {
1595 if (!aQuantizedEType || !bQuantizedEType) {
1596 return emitOpError(message: "expect operands to be both quantized or both not "
1597 "quantized, got ")
1598 << aElementType << " and " << bElementType;
1599 }
1600 // both a and b have quantized element types
1601 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1602 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1603 if (aQuantWidth != bQuantWidth) {
1604 return emitOpError(message: "expect quantized operands to have same widths, got ")
1605 << aQuantWidth << " and " << bQuantWidth;
1606 }
1607 } else {
1608 // non-quantized element types
1609 if (aElementType != bElementType) {
1610 return emitOpError(message: "expect same element type for inputs a and b, got ")
1611 << aElementType << " and " << bElementType;
1612 }
1613 }
1614
1615 // check a_zp and b_zp
1616 auto aEType = getStorageElementTypeOrSelf(type: aType);
1617 auto aZpEType = getStorageElementTypeOrSelf(type: getAZp().getType());
1618 if (aEType != aZpEType) {
1619 return emitOpError(message: "expect input a and a_zp have the same "
1620 "element type, got ")
1621 << aEType << " and " << aZpEType;
1622 }
1623
1624 auto bEType = getStorageElementTypeOrSelf(type: bType);
1625 auto bZpEType = getStorageElementTypeOrSelf(type: getBZp().getType());
1626 if (bEType != bZpEType) {
1627 return emitOpError(message: "expect input b and b_zp have the same "
1628 "element type, got ")
1629 << bEType << " and " << bZpEType;
1630 }
1631
1632 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1633 if (succeeded(Result: maybeAZp) && verifyAZeroPoint(zp: *maybeAZp).failed())
1634 return failure();
1635
1636 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1637 if (succeeded(Result: maybeBZp) && verifyBZeroPoint(zp: *maybeBZp).failed())
1638 return failure();
1639
1640 return success();
1641}
1642
1643LogicalResult tosa::PadOp::inferReturnTypeComponents(
1644 MLIRContext *context, ::std::optional<Location> location,
1645 PadOp::Adaptor adaptor,
1646 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1647 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1648 auto paddingRank =
1649 cast<tosa::shapeType>(Val: adaptor.getPadding().getType()).getRank();
1650 SmallVector<int64_t> outputShape;
1651
1652 // If the input rank is unknown, we can infer the output rank using the
1653 // padding shape's rank divided by 2.
1654 if (!inputShape.hasRank()) {
1655 outputShape.resize(N: paddingRank / 2, NV: ShapedType::kDynamic);
1656 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
1657 return success();
1658 }
1659
1660 SmallVector<int64_t> paddingValues;
1661 // If the paddings value is not a constant, all dimensions must be dynamic.
1662 if (!tosa::getConstShapeValues(op: adaptor.getPadding().getDefiningOp(),
1663 result_shape&: paddingValues)) {
1664 outputShape.resize(N: inputShape.getRank(), NV: ShapedType::kDynamic);
1665 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
1666 return success();
1667 }
1668
1669 outputShape.reserve(N: inputShape.getRank());
1670 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1671 if (inputShape.isDynamicDim(index: i)) {
1672 outputShape.push_back(Elt: ShapedType::kDynamic);
1673 continue;
1674 }
1675 auto padFront = paddingValues[i * 2];
1676 auto padBack = paddingValues[i * 2 + 1];
1677 if (padFront < 0 || padBack < 0) {
1678 // if either padding for dim i is -1, output dim is unknown
1679 outputShape.push_back(Elt: ShapedType::kDynamic);
1680 continue;
1681 }
1682
1683 outputShape.push_back(Elt: inputShape.getDimSize(index: i) + padFront + padBack);
1684 }
1685
1686 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
1687 return success();
1688}
1689
1690LogicalResult tosa::PadOp::verify() {
1691 if (verifySameElementTypes(op: *this, /* inType = */ getInput1().getType(),
1692 /* outType = */ getOutput().getType())
1693 .failed()) {
1694 return failure();
1695 }
1696
1697 if (auto padConst = getPadConst()) {
1698 if (verifySameElementTypes(op: *this, /* inType = */ padConst.getType(),
1699 /* outType = */ getOutput().getType())
1700 .failed()) {
1701 return failure();
1702 }
1703 }
1704
1705 RankedTensorType inputType =
1706 llvm::dyn_cast<RankedTensorType>(Val: getInput1().getType());
1707 RankedTensorType outputType =
1708 llvm::dyn_cast<RankedTensorType>(Val: getOutput().getType());
1709 if (!inputType || !outputType)
1710 return success();
1711
1712 auto inputRank = inputType.getRank();
1713 auto outputRank = outputType.getRank();
1714 if (inputRank != outputRank)
1715 return emitOpError() << "expect same input and output tensor rank, but got "
1716 << "inputRank: " << inputRank
1717 << ", outputRank: " << outputRank;
1718
1719 DenseIntElementsAttr paddingAttr;
1720 if (!matchPattern(value: getPadding(), pattern: m_Constant(bind_value: &paddingAttr))) {
1721 return failure();
1722 }
1723
1724 auto paddingValues = paddingAttr.getValues<APInt>();
1725 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1726 return emitOpError() << "padding tensor must have " << inputRank
1727 << " * 2 = " << inputRank * 2 << " elements, but got "
1728 << paddingValues.size();
1729
1730 auto inputShape = inputType.getShape();
1731 auto outputShape = outputType.getShape();
1732
1733 for (int64_t i = 0; i < inputRank; ++i) {
1734 int64_t padStart = paddingValues[i * 2].getSExtValue();
1735 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1736
1737 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1738 return emitOpError()
1739 << "invalid padding values at dimension " << i
1740 << ": values must be non-negative or -1 for dynamic padding, got ["
1741 << padStart << ", " << padEnd << "]";
1742 }
1743
1744 // Skip shape verification for dynamic input/output
1745 if (inputShape[i] == ShapedType::kDynamic ||
1746 outputShape[i] == ShapedType::kDynamic)
1747 continue;
1748
1749 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1750 return emitOpError() << "mismatch in output shape at dimension " << i
1751 << ": expected " << inputShape[i] << " + "
1752 << padStart << " + " << padEnd << " = "
1753 << (inputShape[i] + padStart + padEnd)
1754 << ", but got " << outputShape[i];
1755 }
1756 }
1757
1758 return success();
1759}
1760
1761LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1762 MLIRContext *context, ::std::optional<Location> location,
1763 SliceOp::Adaptor adaptor,
1764 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1765
1766 Type inputType = getElementTypeOrSelf(type: adaptor.getInput1().getType());
1767 SmallVector<int64_t> start;
1768 SmallVector<int64_t> size;
1769
1770 if (!tosa::getConstShapeValues(op: adaptor.getStart().getDefiningOp(), result_shape&: start) ||
1771 !tosa::getConstShapeValues(op: adaptor.getSize().getDefiningOp(), result_shape&: size)) {
1772 auto rank = cast<tosa::shapeType>(Val: adaptor.getSize().getType()).getRank();
1773 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1774 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(fallback, inputType));
1775 return success();
1776 }
1777
1778 // if size[i] is -1, all remaining elements in dimension i are included
1779 // in the slice, similar to TF.
1780 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1781 // initialize outputShape to all unknown
1782 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1783 if (inputShape.hasRank()) {
1784 for (size_t i = 0; i < size.size(); i++) {
1785 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1786 (ShapedType::isDynamic(dValue: inputShape.getDimSize(index: i)) ||
1787 start[i] < inputShape.getDimSize(index: i))) {
1788 // size[i] is not 0 and not < -1, and start[i] is in valid range
1789 if (ShapedType::isDynamic(dValue: inputShape.getDimSize(index: i))) {
1790 // input shape has unknown dim[i] - only valid if size[i] > 0
1791 if (size[i] > 0) {
1792 outputShape[i] = size[i];
1793 }
1794 } else {
1795 // input shape has known dim[i]
1796 if (size[i] == -1) {
1797 outputShape[i] = inputShape.getDimSize(index: i) - start[i];
1798 } else if (start[i] + size[i] <= inputShape.getDimSize(index: i)) {
1799 // start[i] + size[i] is within bound of input shape's dim[i]
1800 outputShape[i] = size[i];
1801 }
1802 }
1803 }
1804 }
1805 } else {
1806 outputShape = convertToMlirShape(shape: size);
1807 }
1808 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
1809 return success();
1810}
1811
1812LogicalResult tosa::SliceOp::verify() {
1813 if (verifySameElementTypes(op: *this, /* inType = */ getInput1().getType(),
1814 /* outType = */ getOutput().getType())
1815 .failed())
1816 return failure();
1817
1818 const ShapeAdaptor inputShape(getInput1().getType());
1819 if (inputShape.hasRank()) {
1820 const auto inputRank = inputShape.getRank();
1821 const ShapeAdaptor outputShape(getOutput().getType());
1822 if (outputShape.hasRank() && inputRank != outputShape.getRank())
1823 return emitOpError(
1824 message: "expect input1 and output to have the same ranks, got ")
1825 << inputRank << " and " << outputShape.getRank();
1826
1827 const auto startShapeRank =
1828 llvm::cast<tosa::shapeType>(Val: getStart().getType()).getRank();
1829 if (inputRank != startShapeRank)
1830 return emitOpError(message: "length of start is not equal to rank of input shape");
1831
1832 const auto sizeShapeRank =
1833 llvm::cast<tosa::shapeType>(Val: getSize().getType()).getRank();
1834 if (inputRank != sizeShapeRank)
1835 return emitOpError(message: "length of size is not equal to rank of input shape");
1836 }
1837
1838 return success();
1839}
1840
1841LogicalResult tosa::MulOp::inferReturnTypeComponents(
1842 MLIRContext *context, ::std::optional<Location> location,
1843 ValueShapeRange operands, DictionaryAttr attributes,
1844 OpaqueProperties properties, RegionRange regions,
1845 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1846 // mul op's output shape only depend on input1 and input2, not on shift
1847 ValueShapeRange twoInputs = operands.drop_back();
1848 llvm::SmallVector<int64_t> outShape;
1849 if (resolveBroadcastShape(operands: twoInputs, outShape).failed()) {
1850 inferredReturnShapes.push_back(Elt: ShapedTypeComponents());
1851 } else {
1852 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outShape));
1853 }
1854 return success();
1855}
1856
1857LogicalResult tosa::MulOp::verify() {
1858 const Value output = getOutput();
1859 auto resElemType = getElementTypeOrSelf(val: output);
1860
1861 // Verify if the element type among operands and result match tosa
1862 // specification.
1863 if (auto resIntType = dyn_cast<IntegerType>(Val&: resElemType)) {
1864 IntegerType lhsIntType =
1865 dyn_cast<IntegerType>(Val: getElementTypeOrSelf(val: getInput1()));
1866 IntegerType rhsIntType =
1867 dyn_cast<IntegerType>(Val: getElementTypeOrSelf(val: getInput2()));
1868 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
1869 return emitOpError(message: "requires the same element type for all operands");
1870
1871 // Though the spec requires the element type of result to be i32, a more
1872 // relaxed way is provided at dialect level for easier cooperating with
1873 // other dialects.
1874 if (lhsIntType.getWidth() > resIntType.getWidth())
1875 return emitOpError(message: "invalid data type size for operands or result");
1876
1877 } else {
1878 // For other supported type, the spec requires requires the same element
1879 // type for all operands (excludes `shift` operand) and results.
1880 for (int i = 0; i < 2; ++i) {
1881 if (getElementTypeOrSelf(val: getOperand(i)) != resElemType)
1882 return emitOpError(
1883 message: "requires the same element type for all operands and results");
1884 }
1885
1886 // verify shift has value 0 for non-integer types
1887 ElementsAttr shift_elem;
1888 if (matchPattern(value: getShift(), pattern: m_Constant(bind_value: &shift_elem))) {
1889 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1890 if (shift != 0) {
1891 return emitOpError() << "require shift to be 0 for float type";
1892 }
1893 }
1894 }
1895
1896 // Verify the op has same ranks for all main operands (excludes extra operands
1897 // such as shift of mul op, so this is the only difference with the built-in
1898 // `SameOperandsAndResultRank` trait) and results types, if known.
1899 TypeRange operandTypes = getOperandTypes();
1900 ShapedType aType = cast<ShapedType>(Val: operandTypes[0]);
1901 ShapedType bType = cast<ShapedType>(Val: operandTypes[1]);
1902
1903 const bool aHasRank = aType.hasRank();
1904 const bool bHasRank = bType.hasRank();
1905 if (aHasRank && bHasRank) {
1906 const int64_t aRank = aType.getRank();
1907 const int64_t bRank = bType.getRank();
1908 if (aRank != bRank)
1909 return emitOpError(message: "a and b operands don't have matching ranks, got ")
1910 << aRank << " and " << bRank;
1911
1912 // check for broadcast compatible shapes
1913 SmallVector<int64_t> resultShape;
1914 if (!mlir::OpTrait::util::getBroadcastedShape(
1915 shape1: aType.getShape(), shape2: bType.getShape(), resultShape))
1916 return emitOpError(message: "a and b operands don't have broadcast-compatible "
1917 "shapes, got ")
1918 << aType << " and " << bType;
1919 }
1920
1921 ShapedType resultType = cast<ShapedType>(Val: output.getType());
1922 if (!resultType.hasRank())
1923 return success();
1924
1925 const int64_t resultRank = resultType.getRank();
1926 if (aHasRank && resultRank != aType.getRank())
1927 return emitOpError(message: "result type has different rank than a, got ")
1928 << resultRank << " vs " << aType.getRank();
1929 if (bHasRank && resultRank != bType.getRank())
1930 return emitOpError(message: "result type has different rank than b, got ")
1931 << resultRank << " vs " << bType.getRank();
1932
1933 return success();
1934}
1935
1936LogicalResult tosa::TableOp::inferReturnTypeComponents(
1937 MLIRContext *context, ::std::optional<Location> location,
1938 TableOp::Adaptor adaptor,
1939 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1940 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1941
1942 if (!inputShape.hasRank()) {
1943 inferredReturnShapes.push_back(Elt: ShapedTypeComponents());
1944 return success();
1945 }
1946
1947 inferredReturnShapes.resize(N: 1);
1948 inputShape.getDims(res&: inferredReturnShapes[0]);
1949 return success();
1950}
1951
1952LogicalResult tosa::TableOp::verify() {
1953 TensorType inputType = getInput1().getType();
1954 TensorType outputType = getOutput().getType();
1955
1956 if (inputType.hasRank() && outputType.hasRank() &&
1957 inputType.getRank() != outputType.getRank())
1958 return emitOpError()
1959 << "expected input tensor rank to equal result tensor rank";
1960
1961 auto inputDims = inputType.getShape();
1962 auto outputDims = outputType.getShape();
1963 for (auto it : llvm::enumerate(First: llvm::zip(t&: inputDims, u&: outputDims))) {
1964 int64_t dim = it.index();
1965 auto [inputDim, outputDim] = it.value();
1966 if (ShapedType::isStatic(dValue: outputDim) && outputDim != inputDim) {
1967 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1968 << " doesn't match dim(input, " << dim
1969 << ") = " << inputDim;
1970 }
1971 }
1972 return success();
1973}
1974
1975LogicalResult
1976tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1977 // Multiples must be constants.
1978 DenseIntElementsAttr multiplesAttr;
1979 if (!matchPattern(value: getMultiples(), pattern: m_Constant(bind_value: &multiplesAttr)))
1980 return failure();
1981 multiples = llvm::to_vector(
1982 Range: llvm::map_range(C: multiplesAttr.getValues<APInt>(),
1983 F: [](const APInt &val) { return val.getSExtValue(); }));
1984 return success();
1985}
1986
1987LogicalResult tosa::TileOp::inferReturnTypeComponents(
1988 MLIRContext *context, ::std::optional<Location> location,
1989 TileOp::Adaptor adaptor,
1990 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1991 Type inputType = getElementTypeOrSelf(type: adaptor.getInput1().getType());
1992 SmallVector<int64_t> multiples;
1993 if (!tosa::getConstShapeValues(op: adaptor.getMultiples().getDefiningOp(),
1994 result_shape&: multiples)) {
1995 auto rank =
1996 cast<tosa::shapeType>(Val: adaptor.getMultiples().getType()).getRank();
1997 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1998 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(fallback, inputType));
1999 return success();
2000 } else {
2001 multiples = convertToMlirShape(shape: multiples);
2002 }
2003
2004 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2005 SmallVector<int64_t> outputShape;
2006 if (!inputShape.hasRank()) {
2007 outputShape.resize(N: multiples.size(), NV: ShapedType::kDynamic);
2008 inferredReturnShapes.push_back(
2009 Elt: ShapedTypeComponents(outputShape, inputType));
2010 return success();
2011 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2012 return failure();
2013
2014 // Any non dynamic dimension can be multiplied to a known size.
2015 outputShape.reserve(N: multiples.size());
2016 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2017 if (multiples[i] == ShapedType::kDynamic) {
2018 outputShape.push_back(Elt: ShapedType::kDynamic);
2019 } else {
2020 int64_t dim = inputShape.getDimSize(index: i);
2021 if (dim != ShapedType::kDynamic)
2022 dim *= multiples[i];
2023 outputShape.push_back(Elt: dim);
2024 }
2025 }
2026
2027 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape, inputType));
2028 return success();
2029}
2030
2031LogicalResult tosa::TileOp::verify() {
2032 if (verifySameElementTypes(op: *this, /* intype = */ inType: getInput1().getType(),
2033 /* outType = */ getOutput().getType())
2034 .failed()) {
2035 return failure();
2036 }
2037 ShapedType inputType = llvm::cast<ShapedType>(Val: getInput1().getType());
2038 ShapedType outputType = llvm::cast<ShapedType>(Val: getType());
2039
2040 shapeType multiplesType =
2041 llvm::cast<tosa::shapeType>(Val: getMultiples().getType());
2042
2043 auto multiplesRank = multiplesType.getRank();
2044
2045 if (inputType.hasRank()) {
2046 if (inputType.getRank() != multiplesRank)
2047 return emitOpError(message: "expect 'multiples' to have rank ")
2048 << inputType.getRank() << " but got " << multiplesRank << ".";
2049 if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2050 return emitOpError(message: "expect same input and output tensor rank.");
2051 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2052 return emitOpError(message: "expect 'multiples' array to have length ")
2053 << outputType.getRank() << " but got " << multiplesRank << ".";
2054
2055 SmallVector<int64_t> multiples;
2056 if (getConstantMultiples(multiples).succeeded() &&
2057 llvm::any_of(Range&: multiples, P: [](int64_t v) { return v <= 0 && v != -1; }))
2058 return emitOpError(
2059 message: "expect element of 'multiples' to be positive integer or -1.");
2060
2061 return success();
2062}
2063
2064bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2065 if (l.size() != r.size() || l.size() != 1)
2066 return false;
2067 return getElementTypeOrSelf(type: l[0]) == getElementTypeOrSelf(type: r[0]);
2068}
2069
2070LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2071 MLIRContext *context, ::std::optional<Location> location,
2072 ReshapeOp::Adaptor adaptor,
2073 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2074 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2075 Type inputType = getElementTypeOrSelf(type: adaptor.getInput1().getType());
2076 llvm::SmallVector<int64_t> newShapeValue;
2077 if (!tosa::getConstShapeValues(op: adaptor.getShape().getDefiningOp(),
2078 result_shape&: newShapeValue)) {
2079 auto rank = cast<tosa::shapeType>(Val: adaptor.getShape().getType()).getRank();
2080 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2081 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(fallback, inputType));
2082 return success();
2083 } else {
2084 newShapeValue = convertToMlirShape(shape: newShapeValue);
2085 }
2086
2087 // We cannot infer from the total number of elements so we must take the
2088 // shape attribute as exact.
2089 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2090 inferredReturnShapes.push_back(
2091 Elt: ShapedTypeComponents(newShapeValue, inputType));
2092 return success();
2093 }
2094
2095 // Determine the number of elements covered by the slice of all static
2096 // dimensions. This allows us to infer the length of the remaining dynamic
2097 // dimension.
2098 int64_t numElements = inputShape.getNumElements();
2099 int64_t staticMul = 1;
2100 for (auto val : newShapeValue) {
2101 if (ShapedType::isStatic(dValue: val)) {
2102 staticMul *= val;
2103 }
2104 }
2105
2106 // Determine the length of the dynamic dimension.
2107 for (auto &val : newShapeValue) {
2108 if (ShapedType::isDynamic(dValue: val))
2109 val = numElements / staticMul;
2110 }
2111
2112 inferredReturnShapes.push_back(
2113 Elt: ShapedTypeComponents(newShapeValue, inputType));
2114 return success();
2115}
2116
2117llvm::LogicalResult tosa::ReshapeOp::verify() {
2118 if (verifySameElementTypes(op: *this, /* inType = */ getInput1().getType(),
2119 /* outType = */ getOutput().getType())
2120 .failed()) {
2121 return failure();
2122 }
2123 TensorType inputType = getInput1().getType();
2124
2125 SmallVector<int64_t> shapeValues;
2126 if (!tosa::getConstShapeValues(op: getShape().getDefiningOp(), result_shape&: shapeValues)) {
2127 // skip following checks if shape is not constant
2128 return mlir::success();
2129 }
2130
2131 int missingDims = llvm::count(Range&: shapeValues, Element: -1);
2132 if (missingDims > 1)
2133 return emitOpError() << "expected at most one target dimension to be -1";
2134
2135 const auto outputType = dyn_cast<RankedTensorType>(Val: getType());
2136 if (!outputType)
2137 return success();
2138
2139 if ((int64_t)shapeValues.size() != outputType.getRank())
2140 return emitOpError() << "new shape does not match result rank";
2141
2142 for (auto [newShapeDim, outputShapeDim] :
2143 zip(t&: shapeValues, u: outputType.getShape())) {
2144 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2145 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2146 return emitOpError() << "new shape is inconsistent with result shape";
2147
2148 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2149 return emitOpError() << "new shape has invalid tensor dimension size "
2150 << newShapeDim;
2151 }
2152
2153 if (inputType.hasStaticShape()) {
2154 int64_t inputElementsNum = inputType.getNumElements();
2155 if (outputType.hasStaticShape()) {
2156 int64_t outputElementsNum = outputType.getNumElements();
2157 if (inputElementsNum != outputElementsNum) {
2158 return emitOpError() << "cannot reshape " << inputElementsNum
2159 << " elements into " << outputElementsNum;
2160 }
2161 }
2162
2163 int64_t newShapeElementsNum = std::accumulate(
2164 first: shapeValues.begin(), last: shapeValues.end(), init: 1LL,
2165 binary_op: [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2166 bool isStaticNewShape =
2167 llvm::all_of(Range&: shapeValues, P: [](int64_t s) { return s > 0; });
2168 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2169 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2170 return emitOpError() << "cannot reshape " << inputElementsNum
2171 << " elements into " << newShapeElementsNum;
2172 }
2173 }
2174
2175 return mlir::success();
2176}
2177
2178// return failure if val is not a constant
2179// set zp to -1 if val is non-zero float or val is not integer nor float
2180// otherwise set zp to val's constant value
2181static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2182 ElementsAttr zpAttr;
2183 if (!matchPattern(value: val, pattern: m_Constant(bind_value: &zpAttr))) {
2184 return failure();
2185 }
2186
2187 Type zpElemType = zpAttr.getElementType();
2188
2189 if (llvm::isa<FloatType>(Val: zpElemType)) {
2190 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2191 return 0;
2192 }
2193 // return non-zero value to trigger error check
2194 return -1;
2195 }
2196
2197 if (llvm::isa<IntegerType>(Val: zpElemType)) {
2198 if (signExtend)
2199 return zpAttr.getValues<APInt>()[0].getSExtValue();
2200 else
2201 return zpAttr.getValues<APInt>()[0].getZExtValue();
2202 }
2203
2204 // return non-zero value to trigger error check
2205 return -1;
2206}
2207
2208template <typename T>
2209static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2210 const std::string &operand) {
2211 Type zpElemType = getElementTypeOrSelf(val);
2212
2213 if (!zpElemType.isInteger(width: 8) && zp != 0) {
2214 // convert operand to lower case for error message
2215 std::string lower = operand;
2216 std::transform(first: lower.begin(), last: lower.end(), result: lower.begin(), unary_op: ::tolower);
2217 return op.emitOpError()
2218 << lower << " zero point must be zero for non-int8 integer types";
2219 }
2220
2221 return success();
2222}
2223
2224static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2225 const int64_t &zp,
2226 const std::string &operand) {
2227 bool isInputZp = (operand == "Input");
2228
2229 bool tensorUnsigned =
2230 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2231 StringRef tensorName = isInputZp ? "input" : "output";
2232
2233 Type zpElemType = getElementTypeOrSelf(val: zpVal);
2234
2235 if (zp != 0) {
2236 if (!zpElemType.isInteger(width: 8) &&
2237 !(zpElemType.isInteger(width: 16) && tensorUnsigned)) {
2238 return op.emitOpError()
2239 << "expect " << tensorName << "_zp of 0, got " << zp;
2240 }
2241 if (zpElemType.isInteger(width: 16) && tensorUnsigned && zp != 32768) {
2242 return op.emitOpError() << "expect " << tensorName
2243 << "_zp of 0 or 32768 for unsigned int16 "
2244 << tensorName << ", got " << zp;
2245 }
2246 }
2247
2248 return success();
2249}
2250
// Defines, for operand OPERAND_NAME of op OP, two member functions:
//  - get<OPERAND_NAME>ZeroPoint(): passes the <OPERAND_NAME>Zp operand to
//    getZeroPoint() with SIGN_EXTEND as the sign-extension flag.
//  - verify<OPERAND_NAME>ZeroPoint(zp): passes the op, the <OPERAND_NAME>Zp
//    operand, `zp` and the stringized OPERAND_NAME to verifyZeroPoint().
// SIGN_EXTEND is expanded inside the member function body, so it may refer to
// other members of OP.
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

// Ops whose zero points are always read with sign extension.
ZERO_POINT_HELPER(Conv2DOp, Input, true)
ZERO_POINT_HELPER(Conv2DOp, Weight, true)
ZERO_POINT_HELPER(Conv3DOp, Input, true)
ZERO_POINT_HELPER(Conv3DOp, Weight, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
ZERO_POINT_HELPER(MatMulOp, A, true)
ZERO_POINT_HELPER(MatMulOp, B, true)
ZERO_POINT_HELPER(NegateOp, Input1, true)
ZERO_POINT_HELPER(NegateOp, Output, true)
// For rescale, the zero point is zero-extended when the corresponding
// input/output is flagged as unsigned and sign-extended otherwise.
ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
#undef ZERO_POINT_HELPER
2276
2277LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2278 MLIRContext *context, ::std::optional<Location> location,
2279 TransposeOp::Adaptor adaptor,
2280 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2281 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2282
2283 // If input rank and permutation length is unknown, the output rank is
2284 // unknown.
2285 if (!inputShape.hasRank()) {
2286 inferredReturnShapes.push_back(Elt: ShapedTypeComponents());
2287 return success();
2288 }
2289
2290 const auto inputRank = inputShape.getRank();
2291
2292 // This would imply the number of permutations does not match the rank of
2293 // the input which is illegal.
2294 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2295 return failure();
2296 }
2297
2298 SmallVector<int64_t> outputShape;
2299 // Rank-0 means no permutations matter.
2300 if (inputRank == 0) {
2301 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2302 return success();
2303 }
2304
2305 // Check whether the input dimensions are all the same.
2306 bool allTheSame = true;
2307 for (int i = 1, s = inputRank; i < s; i++) {
2308 if (inputShape.getDimSize(index: 0) != inputShape.getDimSize(index: i)) {
2309 allTheSame = false;
2310 break;
2311 }
2312 }
2313
2314 // If all of the input dimensions are the same we don't care about the
2315 // permutation.
2316 if (allTheSame) {
2317 outputShape.resize(N: inputRank, NV: inputShape.getDimSize(index: 0));
2318 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2319 return success();
2320 }
2321
2322 outputShape.resize(N: inputRank, NV: ShapedType::kDynamic);
2323
2324 // Constant permutation values must be within the input rank.
2325 if (llvm::any_of(Range: adaptor.getPerms(),
2326 P: [inputRank](const auto i) { return i >= inputRank; }))
2327 return failure();
2328
2329 outputShape.reserve(N: inputRank);
2330 for (int i = 0, s = inputRank; i < s; i++) {
2331 outputShape[i] = inputShape.getDimSize(index: adaptor.getPerms()[i]);
2332 }
2333
2334 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2335 return success();
2336}
2337
2338LogicalResult tosa::TransposeOp::verify() {
2339 if (verifySameElementTypes(op: *this, /* inType = */ getInput1().getType(),
2340 /* outType = */ getOutput().getType())
2341 .failed()) {
2342 return failure();
2343 }
2344
2345 const ShapeAdaptor inputShape(getInput1().getType());
2346 const ShapeAdaptor outputShape(getOutput().getType());
2347
2348 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2349
2350 if (inputShape.hasRank() &&
2351 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2352 return emitOpError() << "expected perms attribute to have size "
2353 << inputShape.getRank()
2354 << " (input rank) but got size "
2355 << constantPerms.size();
2356
2357 if (inputShape.hasRank() && outputShape.hasRank() &&
2358 inputShape.getRank() != outputShape.getRank())
2359 return emitOpError()
2360 << "expected input tensor rank to equal result tensor rank";
2361
2362 if (outputShape.hasRank() &&
2363 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2364 return emitOpError() << "expected perms attribute to have size "
2365 << outputShape.getRank()
2366 << " (output rank) but got size "
2367 << constantPerms.size();
2368
2369 if (!llvm::all_of(Range: constantPerms,
2370 P: [&constantPerms](int32_t s) {
2371 return s >= 0 &&
2372 static_cast<size_t>(s) < constantPerms.size();
2373 }) ||
2374 !isPermutationVector(interchange: llvm::to_vector(Range: llvm::map_range(
2375 C: constantPerms, F: [](int32_t v) -> int64_t { return v; }))))
2376 return emitOpError() << "expected valid permutation indices";
2377
2378 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2379 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2380 inputShape.getNumElements() != outputShape.getNumElements())
2381 return emitOpError() << "expected input1 and output to have same numbers "
2382 "of elements, got "
2383 << inputShape.getNumElements() << " and "
2384 << outputShape.getNumElements();
2385
2386 // Verify that the types of the input and output tensors are properly
2387 // permuted.
2388 if (inputShape.hasRank() && outputShape.hasRank()) {
2389 for (auto i = 0; i < outputShape.getRank(); i++) {
2390 if (inputShape.isDynamicDim(index: constantPerms[i]) ||
2391 outputShape.isDynamicDim(index: i))
2392 continue;
2393
2394 if (inputShape.getDimSize(index: constantPerms[i]) != outputShape.getDimSize(index: i))
2395 return emitOpError()
2396 << "expected output tensor dim " << i << " to match "
2397 << "input dim " << constantPerms[i] << " with value of "
2398 << inputShape.getDimSize(index: constantPerms[i]);
2399 }
2400 }
2401
2402 return success();
2403}
2404
2405LogicalResult TransposeOp::reifyResultShapes(
2406 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2407
2408 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2409
2410 Value input = getInput1();
2411 auto inputType = cast<TensorType>(Val: input.getType());
2412
2413 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2414 for (auto dim : transposePerms) {
2415 int32_t dimInInput = transposePerms[dim];
2416 if (inputType.isDynamicDim(idx: dimInInput))
2417 returnedDims[dim] =
2418 builder.create<tensor::DimOp>(location: getLoc(), args&: input, args&: dimInInput)
2419 .getResult();
2420 else
2421 returnedDims[dim] =
2422 builder.getIndexAttr(value: inputType.getDimSize(idx: dimInInput));
2423 }
2424
2425 reifiedReturnShapes.emplace_back(Args: std::move(returnedDims));
2426 return success();
2427}
2428
2429LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2430 MLIRContext *context, ::std::optional<Location> location,
2431 GatherOp::Adaptor adaptor,
2432 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2433 llvm::SmallVector<int64_t> outputShape;
2434 outputShape.resize(N: 3, NV: ShapedType::kDynamic);
2435
2436 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2437 if (valuesShape.hasRank()) {
2438 outputShape[0] = valuesShape.getDimSize(index: 0);
2439 outputShape[2] = valuesShape.getDimSize(index: 2);
2440 }
2441
2442 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2443 if (indicesShape.hasRank()) {
2444 if (outputShape[0] == ShapedType::kDynamic)
2445 outputShape[0] = indicesShape.getDimSize(index: 0);
2446 if (outputShape[1] == ShapedType::kDynamic)
2447 outputShape[1] = indicesShape.getDimSize(index: 1);
2448 }
2449
2450 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2451 return success();
2452}
2453
2454LogicalResult tosa::GatherOp::verify() {
2455 if (verifySameElementTypes(op: *this, /* inType = */ getValues().getType(),
2456 /* outType = */ getOutput().getType())
2457 .failed()) {
2458 return failure();
2459 }
2460
2461 const ShapeAdaptor valuesShape(getValues().getType());
2462 const ShapeAdaptor indicesShape(getIndices().getType());
2463 const ShapeAdaptor outputShape(getOutput().getType());
2464
2465 int64_t N = ShapedType::kDynamic;
2466 int64_t W = ShapedType::kDynamic;
2467 int64_t C = ShapedType::kDynamic;
2468
2469 if (valuesShape.hasRank()) {
2470 N = valuesShape.getDimSize(index: 0);
2471 C = valuesShape.getDimSize(index: 2);
2472 }
2473 if (indicesShape.hasRank()) {
2474 const int64_t indicesN = indicesShape.getDimSize(index: 0);
2475 W = indicesShape.getDimSize(index: 1);
2476 if (N == ShapedType::kDynamic)
2477 N = indicesN;
2478 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2479 return emitOpError() << "requires indices dimension 0 to have size " << N
2480 << ", got " << indicesN;
2481 }
2482 if (outputShape.hasRank()) {
2483 const int64_t outputN = outputShape.getDimSize(index: 0);
2484 const int64_t outputW = outputShape.getDimSize(index: 1);
2485 const int64_t outputC = outputShape.getDimSize(index: 2);
2486 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2487 N != outputN)
2488 return emitOpError() << "requires output dimension 0 to have size " << N
2489 << ", got " << outputN;
2490
2491 if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2492 W != outputW)
2493 return emitOpError() << "requires output dimension 1 to have size " << W
2494 << ", got " << outputW;
2495 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2496 C != outputC)
2497 return emitOpError() << "requires output dimension 2 to have size " << C
2498 << ", got " << outputC;
2499 }
2500 return success();
2501}
2502
2503LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2504 MLIRContext *context, ::std::optional<Location> location,
2505 ResizeOp::Adaptor adaptor,
2506 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2507 llvm::SmallVector<int64_t, 4> outputShape;
2508 outputShape.resize(N: 4, NV: ShapedType::kDynamic);
2509
2510 ShapeAdaptor inputShape(adaptor.getInput().getType());
2511 if (!inputShape.hasRank())
2512 return failure();
2513
2514 outputShape[0] = inputShape.getDimSize(index: 0);
2515 outputShape[3] = inputShape.getDimSize(index: 3);
2516 int64_t inputHeight = inputShape.getDimSize(index: 1);
2517 int64_t inputWidth = inputShape.getDimSize(index: 2);
2518
2519 if ((inputHeight == ShapedType::kDynamic) ||
2520 (inputWidth == ShapedType::kDynamic))
2521 return failure();
2522
2523 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2524 if (!tosa::getConstShapeValues(op: adaptor.getScale().getDefiningOp(),
2525 result_shape&: scaleInt) ||
2526 !tosa::getConstShapeValues(op: adaptor.getOffset().getDefiningOp(),
2527 result_shape&: offsetInt) ||
2528 !tosa::getConstShapeValues(op: adaptor.getBorder().getDefiningOp(),
2529 result_shape&: borderInt)) {
2530 return failure();
2531 }
2532
2533 // Compute the output shape based on attributes: scale, offset, and border.
2534 const int64_t outputHeight =
2535 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2536 scaleInt[1]) +
2537 1;
2538
2539 const int64_t outputWidth =
2540 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2541 scaleInt[3]) +
2542 1;
2543
2544 if (outputHeight < 0 || outputWidth < 0) {
2545 return emitOptionalError(
2546 loc: location,
2547 args: "calculated output height and width must be non-negative, "
2548 "got height = ",
2549 args: outputHeight, args: ", width = ", args: outputWidth);
2550 }
2551
2552 outputShape[1] = outputHeight;
2553 outputShape[2] = outputWidth;
2554 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2555 return success();
2556}
2557
2558LogicalResult tosa::ResizeOp::verify() {
2559 const Value input = getInput();
2560 const Value output = getOutput();
2561 const RankedTensorType inputType =
2562 llvm::dyn_cast<RankedTensorType>(Val: input.getType());
2563 const RankedTensorType outputType =
2564 llvm::dyn_cast<RankedTensorType>(Val: output.getType());
2565
2566 SmallVector<int64_t> scaleValues;
2567 SmallVector<int64_t> offsetValues;
2568 SmallVector<int64_t> borderValues;
2569 if (!tosa::getConstShapeValues(op: getScale().getDefiningOp(), result_shape&: scaleValues) ||
2570 !tosa::getConstShapeValues(op: getOffset().getDefiningOp(), result_shape&: offsetValues) ||
2571 !tosa::getConstShapeValues(op: getBorder().getDefiningOp(), result_shape&: borderValues)) {
2572 // Skip following checks if shape is not constant
2573 return success();
2574 }
2575
2576 if (llvm::any_of(Range&: scaleValues, P: [](int64_t s) { return s <= 0; }))
2577 return emitOpError(message: "expect all scale values to be > 0, got ")
2578 << scaleValues;
2579
2580 const int64_t scaleYN = scaleValues[0];
2581 const int64_t scaleYD = scaleValues[1];
2582 const int64_t scaleXN = scaleValues[2];
2583 const int64_t scaleXD = scaleValues[3];
2584
2585 const int64_t offsetY = offsetValues[0];
2586 const int64_t offsetX = offsetValues[1];
2587
2588 const int64_t borderY = borderValues[0];
2589 const int64_t borderX = borderValues[1];
2590
2591 if (!inputType)
2592 return success();
2593 if (!outputType)
2594 return success();
2595
2596 const int64_t oh = outputType.getDimSize(idx: 1);
2597 const int64_t ow = outputType.getDimSize(idx: 2);
2598 const int64_t ih = inputType.getDimSize(idx: 1);
2599 const int64_t iw = inputType.getDimSize(idx: 2);
2600
2601 // Don't check with input height that could be broadcast (ih != 1)
2602 // since Linalg, a consumer of TOSA, expects broadcasting support
2603 // in resize to be available. Taking the cautious approach for now,
2604 // we can consider removing support for broadcasting later.
2605 if (ih != ShapedType::kDynamic && ih != 1) {
2606 const std::optional<int64_t> calculatedOutHeightMinusOne =
2607 idivCheck(lhs: (ih - 1) * scaleYN - offsetY + borderY, rhs: scaleYD);
2608 if (!calculatedOutHeightMinusOne.has_value())
2609 return emitOpError(message: "expected (input_height - 1) * scale_y_n - offset_y + "
2610 "border_y ")
2611 << "to be wholly divisible by scale_y_d, got ((" << ih
2612 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2613 << ") / " << scaleYD;
2614 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2615 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2616 return emitOpError(message: "calculated output height did not match expected: ")
2617 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2618 }
2619
2620 // Don't check with input width that could be broadcast (iw != 1)
2621 // since Linalg, a consumer of TOSA, expects broadcasting support
2622 // in resize to be available. Taking the cautious approach for now,
2623 // we can consider removing support for broadcasting later.
2624 if (iw != ShapedType::kDynamic && iw != 1) {
2625 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2626 const std::optional<int64_t> calculatedOutWidthMinusOne =
2627 idivCheck(lhs: scaledInWidth, rhs: scaleXD);
2628 if (!calculatedOutWidthMinusOne.has_value())
2629 return emitOpError(message: "expected (input_width - 1) * scale_x_n - offset_x + "
2630 "border_x ")
2631 << "to be wholly divisible by scale_x_d, got ((" << iw
2632 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2633 << ") / " << scaleXD;
2634 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2635 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2636 return emitOpError(message: "calculated output width did not match expected: ")
2637 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2638 }
2639
2640 return success();
2641}
2642
2643LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2644 MLIRContext *context, ::std::optional<Location> location,
2645 ScatterOp::Adaptor adaptor,
2646 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2647 llvm::SmallVector<int64_t> outputShape;
2648 outputShape.resize(N: 3, NV: ShapedType::kDynamic);
2649
2650 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2651 if (valuesInShape.hasRank()) {
2652 outputShape[0] = valuesInShape.getDimSize(index: 0);
2653 outputShape[1] = valuesInShape.getDimSize(index: 1);
2654 outputShape[2] = valuesInShape.getDimSize(index: 2);
2655 }
2656
2657 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2658 if (indicesShape.hasRank()) {
2659 if (outputShape[0] == ShapedType::kDynamic)
2660 outputShape[0] = indicesShape.getDimSize(index: 0);
2661 }
2662
2663 ShapeAdaptor inputShape(adaptor.getInput().getType());
2664 if (inputShape.hasRank()) {
2665 if (outputShape[0] == ShapedType::kDynamic)
2666 outputShape[0] = inputShape.getDimSize(index: 0);
2667 if (outputShape[2] == ShapedType::kDynamic)
2668 outputShape[2] = inputShape.getDimSize(index: 2);
2669 }
2670
2671 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2672 return success();
2673}
2674
2675LogicalResult tosa::ScatterOp::verify() {
2676 if (verifySameElementTypes(op: *this, /* inType = */ getValuesIn().getType(),
2677 /* outType = */ getValuesOut().getType())
2678 .failed() ||
2679 verifySameElementTypes(op: *this, /* inType = */ getInput().getType(),
2680 /* outType = */ getValuesOut().getType())
2681 .failed()) {
2682 return failure();
2683 }
2684
2685 const ShapeAdaptor valuesInShape(getValuesIn().getType());
2686 const ShapeAdaptor indicesShape(getIndices().getType());
2687 const ShapeAdaptor inputShape(getInput().getType());
2688 const ShapeAdaptor outputShape(getValuesOut().getType());
2689
2690 int64_t N = ShapedType::kDynamic;
2691 int64_t K = ShapedType::kDynamic;
2692 int64_t W = ShapedType::kDynamic;
2693 int64_t C = ShapedType::kDynamic;
2694 if (valuesInShape.hasRank()) {
2695 N = valuesInShape.getDimSize(index: 0);
2696 K = valuesInShape.getDimSize(index: 1);
2697 C = valuesInShape.getDimSize(index: 2);
2698 }
2699 if (indicesShape.hasRank()) {
2700 const int64_t indicesN = indicesShape.getDimSize(index: 0);
2701 W = indicesShape.getDimSize(index: 1);
2702 if (N == ShapedType::kDynamic)
2703 N = indicesN;
2704 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2705 return emitOpError() << "requires indices dimension 0 to have size " << N
2706 << ", got " << indicesN;
2707 }
2708 if (inputShape.hasRank()) {
2709 const int64_t inputN = inputShape.getDimSize(index: 0);
2710 const int64_t inputW = inputShape.getDimSize(index: 1);
2711 const int64_t inputC = inputShape.getDimSize(index: 2);
2712 if (N == ShapedType::kDynamic)
2713 N = inputN;
2714 else if (inputN != ShapedType::kDynamic && N != inputN)
2715 return emitOpError() << "requires input dimension 0 to have size " << N
2716 << ", got " << inputN;
2717 if (W == ShapedType::kDynamic)
2718 W = inputW;
2719 else if (inputW != ShapedType::kDynamic && W != inputW)
2720 return emitOpError() << "requires input dimension 1 to have size " << W
2721 << ", got " << inputW;
2722
2723 if (C == ShapedType::kDynamic)
2724 C = inputC;
2725 else if (inputC != ShapedType::kDynamic && C != inputC)
2726 return emitOpError() << "requires input dimension 2 to have size " << C
2727 << ", got " << inputC;
2728 }
2729 if (outputShape.hasRank()) {
2730 const int64_t outputN = outputShape.getDimSize(index: 0);
2731 const int64_t outputK = outputShape.getDimSize(index: 1);
2732 const int64_t outputC = outputShape.getDimSize(index: 2);
2733 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2734 N != outputN)
2735 return emitOpError() << "requires values_out dimension 0 to have size "
2736 << N << ", got " << outputN;
2737 if (K == ShapedType::kDynamic)
2738 K = outputK;
2739 else if (outputK != ShapedType::kDynamic && K != outputK)
2740 return emitOpError() << "requires values_out dimension 1 to have size "
2741 << K << ", got " << outputK;
2742 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2743 C != outputC)
2744 return emitOpError() << "requires values_out dimension 2 to have size "
2745 << C << ", got " << outputC;
2746 }
2747 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2748 return emitOpError() << "requires dimensions K >= W, got K=" << K
2749 << " and W=" << W;
2750
2751 return success();
2752}
2753
2754static LogicalResult ReduceInferReturnTypes(
2755 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2756 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2757 int64_t axisVal = axis.getValue().getSExtValue();
2758 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2759 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(inputType));
2760 return success();
2761 }
2762
2763 SmallVector<int64_t> outputShape;
2764 operandShape.getDims(res&: outputShape);
2765 outputShape[axisVal] = 1;
2766 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape, inputType));
2767 return success();
2768}
2769
2770#define COMPATIBLE_RETURN_TYPES(OP) \
2771 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2772 if (l.size() != r.size() || l.size() != 1) \
2773 return false; \
2774 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2775 return false; \
2776 return succeeded(verifyCompatibleShape(l[0], r[0])); \
2777 }
2778
2779#define REDUCE_SHAPE_INFER(OP) \
2780 LogicalResult OP::inferReturnTypeComponents( \
2781 MLIRContext *context, ::std::optional<Location> location, \
2782 OP::Adaptor adaptor, \
2783 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2784 Type inputType = \
2785 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2786 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2787 const Properties &prop = adaptor.getProperties(); \
2788 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2789 inferredReturnShapes); \
2790 } \
2791 COMPATIBLE_RETURN_TYPES(OP)
2792
2793REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2794REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2795REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2796REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2797REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2798REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2799#undef REDUCE_SHAPE_INFER
2800COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2801#undef COMPATIBLE_RETURN_TYPES
2802
2803template <typename T>
2804static LogicalResult verifyReduceOp(T op) {
2805 // All TOSA reduce Ops have input, output and axis.
2806 TensorType inputType = op.getInput().getType();
2807 TensorType outputType = op.getOutput().getType();
2808 int32_t reduceAxis = op.getAxis();
2809
2810 if (reduceAxis < 0) {
2811 op.emitOpError("reduce axis must not be negative");
2812 return failure();
2813 }
2814 if (inputType.hasRank()) {
2815 int64_t inputRank = inputType.getRank();
2816 // We allow for a special case where the input/output shape has rank 0 and
2817 // axis is also 0.
2818 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2819 op.emitOpError("expect input tensor rank (")
2820 << inputRank << ") to be larger than reduce axis (" << reduceAxis
2821 << ")";
2822 return failure();
2823 }
2824 }
2825 if (outputType.hasRank()) {
2826 int64_t outputRank = outputType.getRank();
2827 if (inputType.hasRank() && outputRank != inputType.getRank()) {
2828 op.emitOpError(
2829 "expect output tensor rank to be equal to input tensor rank");
2830 return failure();
2831 }
2832 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2833 op.emitOpError("expect output tensor rank (")
2834 << outputRank << ") to be larger than reduce axis (" << reduceAxis
2835 << ")";
2836 return failure();
2837 }
2838 // We can only verify the reduced dimension size to be 1 if this is not
2839 // the special case of output rank == 0.
2840 if (outputRank != 0) {
2841 auto outputShape = outputType.getShape();
2842 if (!outputType.isDynamicDim(idx: reduceAxis) &&
2843 outputShape[reduceAxis] != 1) {
2844 op.emitOpError("expect reduced dimension size to be 1, got ")
2845 << outputShape[reduceAxis];
2846 return failure();
2847 }
2848 }
2849 }
2850 return success();
2851}
2852
2853LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(op: *this); }
2854LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(op: *this); }
2855LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(op: *this); }
2856LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(op: *this); }
2857LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(op: *this); }
2858LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(op: *this); }
2859
2860static LogicalResult NAryInferReturnTypes(
2861 const ValueShapeRange &operands,
2862 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2863 llvm::SmallVector<int64_t> outShape;
2864 if (resolveBroadcastShape(operands, outShape).failed()) {
2865 inferredReturnShapes.push_back(Elt: ShapedTypeComponents());
2866 } else {
2867 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outShape));
2868 }
2869 return success();
2870}
2871
2872#define NARY_SHAPE_INFER(OP) \
2873 LogicalResult OP::inferReturnTypeComponents( \
2874 MLIRContext *context, ::std::optional<Location> location, \
2875 ValueShapeRange operands, DictionaryAttr attributes, \
2876 OpaqueProperties properties, RegionRange regions, \
2877 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2878 return NAryInferReturnTypes(operands, inferredReturnShapes); \
2879 }
2880
2881NARY_SHAPE_INFER(tosa::AbsOp)
2882NARY_SHAPE_INFER(tosa::AddOp)
2883NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2884NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2885NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2886NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2887NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2888NARY_SHAPE_INFER(tosa::CastOp)
2889NARY_SHAPE_INFER(tosa::CeilOp)
2890NARY_SHAPE_INFER(tosa::ClampOp)
2891NARY_SHAPE_INFER(tosa::ClzOp)
2892NARY_SHAPE_INFER(tosa::CosOp)
2893NARY_SHAPE_INFER(tosa::ExpOp)
2894NARY_SHAPE_INFER(tosa::FloorOp)
2895NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2896NARY_SHAPE_INFER(tosa::GreaterOp)
2897NARY_SHAPE_INFER(tosa::IdentityOp)
2898NARY_SHAPE_INFER(tosa::IntDivOp)
2899NARY_SHAPE_INFER(tosa::LogOp)
2900NARY_SHAPE_INFER(tosa::LogicalAndOp)
2901NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2902NARY_SHAPE_INFER(tosa::LogicalNotOp)
2903NARY_SHAPE_INFER(tosa::LogicalOrOp)
2904NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2905NARY_SHAPE_INFER(tosa::LogicalXorOp)
2906NARY_SHAPE_INFER(tosa::MaximumOp)
2907NARY_SHAPE_INFER(tosa::MinimumOp)
2908NARY_SHAPE_INFER(tosa::PowOp)
2909NARY_SHAPE_INFER(tosa::ReciprocalOp)
2910NARY_SHAPE_INFER(tosa::ReverseOp)
2911NARY_SHAPE_INFER(tosa::RsqrtOp)
2912NARY_SHAPE_INFER(tosa::SinOp)
2913NARY_SHAPE_INFER(tosa::SelectOp)
2914NARY_SHAPE_INFER(tosa::SubOp)
2915NARY_SHAPE_INFER(tosa::TanhOp)
2916NARY_SHAPE_INFER(tosa::ErfOp)
2917NARY_SHAPE_INFER(tosa::SigmoidOp)
2918#undef PRED_SHAPE_INFER
2919
2920LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2921 MLIRContext *context, ::std::optional<Location> location,
2922 NegateOp::Adaptor adaptor,
2923 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2924 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2925 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(inputShape));
2926 return success();
2927}
2928
2929LogicalResult tosa::NegateOp::verify() {
2930 // Verify same element type
2931 const Type input1Type = getInput1().getType();
2932 const Type outputType = getOutput().getType();
2933 if (verifySameElementTypes(op: *this, inType: input1Type, outType: outputType).failed())
2934 return failure();
2935
2936 // Verify same shape
2937 const SmallVector<Type, 2> types = {input1Type, outputType};
2938 if (failed(Result: verifyCompatibleShapes(types)))
2939 return emitOpError() << "requires the same shape for input1 and output";
2940
2941 const Type input1EType = getStorageElementTypeOrSelf(type: getInput1().getType());
2942 const Type input1ZpEType =
2943 getStorageElementTypeOrSelf(type: getInput1Zp().getType());
2944 if (input1EType != input1ZpEType) {
2945 return emitOpError(message: "expect both input1 and its zero point are the same "
2946 "element type, got ")
2947 << input1EType << " and " << input1ZpEType;
2948 }
2949 const Type outputEType = getStorageElementTypeOrSelf(type: getOutput().getType());
2950 const Type outputZpEType =
2951 getStorageElementTypeOrSelf(type: getOutputZp().getType());
2952 if (outputEType != outputZpEType) {
2953 return emitOpError(message: "expect both output and its zero point are the same "
2954 "element type, got ")
2955 << outputEType << " and " << outputZpEType;
2956 }
2957
2958 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2959 if (succeeded(Result: maybeIZp) && verifyInput1ZeroPoint(zp: *maybeIZp).failed())
2960 return failure();
2961
2962 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2963 if (succeeded(Result: maybeOZp) && verifyOutputZeroPoint(zp: *maybeOZp).failed())
2964 return failure();
2965
2966 return success();
2967}
2968
2969static LogicalResult poolingInferReturnTypes(
2970 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2971 ArrayRef<int64_t> pad,
2972 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2973 llvm::SmallVector<int64_t> outputShape;
2974 outputShape.resize(N: 4, NV: ShapedType::kDynamic);
2975
2976 // We only know the rank if the input type is unranked.
2977 if (!inputShape) {
2978 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
2979 return success();
2980 }
2981
2982 // Batch and number of channels are identical for pooling layer.
2983 outputShape[0] = inputShape.getDimSize(index: 0);
2984 outputShape[3] = inputShape.getDimSize(index: 3);
2985
2986 int64_t height = inputShape.getDimSize(index: 1);
2987 int64_t width = inputShape.getDimSize(index: 2);
2988
2989 if (ShapedType::isStatic(dValue: height)) {
2990 int64_t padded = height + pad[0] + pad[1] - kernel[0];
2991 outputShape[1] = padded / stride[0] + 1;
2992 }
2993
2994 if (ShapedType::isStatic(dValue: width)) {
2995 int64_t padded = width + pad[2] + pad[3] - kernel[1];
2996 outputShape[2] = padded / stride[1] + 1;
2997 }
2998
2999 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
3000 return success();
3001}
3002
3003LogicalResult Conv2DOp::inferReturnTypeComponents(
3004 MLIRContext *context, ::std::optional<Location> location,
3005 Conv2DOp::Adaptor adaptor,
3006 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3007 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3008
3009 int64_t inputWidth = ShapedType::kDynamic;
3010 int64_t inputHeight = ShapedType::kDynamic;
3011 int64_t weightWidth = ShapedType::kDynamic;
3012 int64_t weightHeight = ShapedType::kDynamic;
3013
3014 // Input shape describes input width/height and batch.
3015
3016 ShapeAdaptor inputShape(adaptor.getInput().getType());
3017 if (inputShape.hasRank()) {
3018 outputShape[0] = inputShape.getDimSize(index: 0);
3019 inputHeight = inputShape.getDimSize(index: 1);
3020 inputWidth = inputShape.getDimSize(index: 2);
3021 }
3022
3023 // Weight shapes describes the filter width/height and the output channels.
3024 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3025 if (weightShape.hasRank()) {
3026 outputShape[3] = weightShape.getDimSize(index: 0);
3027 weightHeight = weightShape.getDimSize(index: 1);
3028 weightWidth = weightShape.getDimSize(index: 2);
3029 }
3030
3031 // Bias shape can describe the output channels.
3032 ShapeAdaptor biasShape(adaptor.getBias().getType());
3033 if (biasShape.hasRank()) {
3034 outputShape[3] = ShapedType::isDynamic(dValue: outputShape[3])
3035 ? biasShape.getDimSize(index: 0)
3036 : outputShape[3];
3037 }
3038
3039 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3040 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3041 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3042
3043 if (ShapedType::isStatic(dValue: inputHeight) && ShapedType::isStatic(dValue: weightHeight)) {
3044 int64_t inputSize = inputHeight + padding[0] + padding[1];
3045 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3046 int64_t unstridedResult = inputSize - filterSize + 1;
3047 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3048 }
3049
3050 if (ShapedType::isStatic(dValue: inputWidth) && ShapedType::isStatic(dValue: weightWidth)) {
3051 int64_t inputSize = inputWidth + padding[2] + padding[3];
3052 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3053 int64_t unstridedResult = inputSize - filterSize + 1;
3054 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3055 }
3056
3057 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
3058 return success();
3059}
3060
3061LogicalResult Conv2DOp::verify() {
3062 if (verifyConvOp(op: *this).failed() || verifyConvOpModes(op: *this).failed() ||
3063 verifyConvOpErrorIf(op: *this).failed())
3064 return failure();
3065 return success();
3066}
3067
3068LogicalResult Conv3DOp::inferReturnTypeComponents(
3069 MLIRContext *context, ::std::optional<Location> location,
3070 Conv3DOp::Adaptor adaptor,
3071 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3072 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3073
3074 int64_t inputWidth = ShapedType::kDynamic;
3075 int64_t inputHeight = ShapedType::kDynamic;
3076 int64_t inputDepth = ShapedType::kDynamic;
3077
3078 int64_t weightWidth = ShapedType::kDynamic;
3079 int64_t weightHeight = ShapedType::kDynamic;
3080 int64_t weightDepth = ShapedType::kDynamic;
3081
3082 // Input shape describes input width/height and batch.
3083 ShapeAdaptor inputShape(adaptor.getInput().getType());
3084 if (inputShape.hasRank()) {
3085 outputShape[0] = inputShape.getDimSize(index: 0);
3086 inputDepth = inputShape.getDimSize(index: 1);
3087 inputHeight = inputShape.getDimSize(index: 2);
3088 inputWidth = inputShape.getDimSize(index: 3);
3089 }
3090
3091 // Weight shapes describes the filter width/height and the output channels.
3092 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3093 if (weightShape.hasRank()) {
3094 outputShape[4] = weightShape.getDimSize(index: 0);
3095 weightDepth = weightShape.getDimSize(index: 1);
3096 weightHeight = weightShape.getDimSize(index: 2);
3097 weightWidth = weightShape.getDimSize(index: 3);
3098 }
3099
3100 // Bias shape can describe the output channels.
3101 ShapeAdaptor biasShape(adaptor.getBias().getType());
3102 if (biasShape.hasRank() && ShapedType::isDynamic(dValue: outputShape[4])) {
3103 outputShape[4] = biasShape.getDimSize(index: 0);
3104 }
3105
3106 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3107 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3108 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3109
3110 if (ShapedType::isStatic(dValue: inputDepth) && ShapedType::isStatic(dValue: weightDepth)) {
3111 int32_t inputSize = inputDepth + pad[0] + pad[1];
3112 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3113 int32_t unstridedResult = inputSize - filterSize + 1;
3114 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3115 }
3116
3117 if (ShapedType::isStatic(dValue: inputHeight) && ShapedType::isStatic(dValue: weightHeight)) {
3118 int32_t inputSize = inputHeight + pad[2] + pad[3];
3119 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3120 int32_t unstridedResult = inputSize - filterSize + 1;
3121 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3122 }
3123
3124 if (ShapedType::isStatic(dValue: inputWidth) && ShapedType::isStatic(dValue: weightWidth)) {
3125 int32_t inputSize = inputWidth + pad[4] + pad[5];
3126 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3127 int32_t unstridedResult = inputSize - filterSize + 1;
3128 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3129 }
3130
3131 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
3132 return success();
3133}
3134
3135LogicalResult Conv3DOp::verify() {
3136 if (verifyConvOp(op: *this).failed() || verifyConvOpModes(op: *this).failed() ||
3137 verifyConvOpErrorIf(op: *this).failed())
3138 return failure();
3139 return success();
3140}
3141
3142LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3143 MLIRContext *context, ::std::optional<Location> location,
3144 AvgPool2dOp::Adaptor adaptor,
3145 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3146 ShapeAdaptor inputShape(adaptor.getInput().getType());
3147 const Properties &prop = adaptor.getProperties();
3148 return poolingInferReturnTypes(inputShape, kernel: prop.kernel, stride: prop.stride, pad: prop.pad,
3149 inferredReturnShapes);
3150}
3151
3152LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3153 MLIRContext *context, ::std::optional<Location> location,
3154 MaxPool2dOp::Adaptor adaptor,
3155 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3156 ShapeAdaptor inputShape(adaptor.getInput().getType());
3157 const Properties &prop = adaptor.getProperties();
3158 return poolingInferReturnTypes(inputShape, kernel: prop.kernel, stride: prop.stride, pad: prop.pad,
3159 inferredReturnShapes);
3160}
3161
3162LogicalResult MaxPool2dOp::verify() {
3163 if (failed(Result: verifySameElementTypes(op: *this, /* intype = */ inType: getInput().getType(),
3164 /* outType = */ getOutput().getType())))
3165 return failure();
3166
3167 if (failed(Result: verifyPoolingOp(op: *this)))
3168 return failure();
3169
3170 return success();
3171}
3172
3173LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3174 MLIRContext *context, ::std::optional<Location> location,
3175 DepthwiseConv2DOp::Adaptor adaptor,
3176 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3177 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3178
3179 int64_t inputWidth = ShapedType::kDynamic;
3180 int64_t inputHeight = ShapedType::kDynamic;
3181 int64_t inputChannels = ShapedType::kDynamic;
3182
3183 int64_t weightWidth = ShapedType::kDynamic;
3184 int64_t weightHeight = ShapedType::kDynamic;
3185 int64_t depthChannels = ShapedType::kDynamic;
3186
3187 // Input shape describes input width/height and batch.
3188 ShapeAdaptor inputShape(adaptor.getInput().getType());
3189 if (inputShape.hasRank()) {
3190 outputShape[0] = inputShape.getDimSize(index: 0);
3191 inputHeight = inputShape.getDimSize(index: 1);
3192 inputWidth = inputShape.getDimSize(index: 2);
3193 inputChannels = inputShape.getDimSize(index: 3);
3194 }
3195
3196 // Weight shapes describes the filter width/height and the output channels.
3197 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3198 if (weightShape.hasRank()) {
3199 weightHeight = weightShape.getDimSize(index: 0);
3200 weightWidth = weightShape.getDimSize(index: 1);
3201 inputChannels = ShapedType::isDynamic(dValue: inputChannels)
3202 ? weightShape.getDimSize(index: 2)
3203 : inputChannels;
3204 depthChannels = weightShape.getDimSize(index: 3);
3205 }
3206
3207 // If both inputChannels and depthChannels are available we can determine
3208 // the output channels.
3209 if (ShapedType::isStatic(dValue: inputChannels) &&
3210 ShapedType::isStatic(dValue: depthChannels)) {
3211 outputShape[3] = inputChannels * depthChannels;
3212 }
3213
3214 // Bias shape can describe the output channels.
3215 ShapeAdaptor biasShape(adaptor.getBias().getType());
3216 if (biasShape.hasRank()) {
3217 outputShape[3] = ShapedType::isDynamic(dValue: outputShape[3])
3218 ? biasShape.getDimSize(index: 0)
3219 : outputShape[3];
3220 }
3221
3222 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3223 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3224 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3225
3226 if (ShapedType::isStatic(dValue: inputHeight) && ShapedType::isStatic(dValue: weightHeight)) {
3227 int64_t inputSize = inputHeight + padding[0] + padding[1];
3228 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3229 int64_t unstridedResult = inputSize - filterSize + 1;
3230 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3231 }
3232
3233 if (ShapedType::isStatic(dValue: inputWidth) && ShapedType::isStatic(dValue: weightWidth)) {
3234 int64_t inputSize = inputWidth + padding[2] + padding[3];
3235 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3236 int64_t unstridedResult = inputSize - filterSize + 1;
3237 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3238 }
3239
3240 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
3241 return success();
3242}
3243
3244LogicalResult DepthwiseConv2DOp::verify() {
3245 if (verifyConvOp(op: *this).failed() || verifyConvOpModes(op: *this).failed() ||
3246 verifyConvOpErrorIf(op: *this).failed())
3247 return failure();
3248 return success();
3249}
3250
3251LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3252 MLIRContext *context, ::std::optional<Location> location,
3253 TransposeConv2DOp::Adaptor adaptor,
3254 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3255 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3256
3257 int64_t inputWidth = ShapedType::kDynamic;
3258 int64_t inputHeight = ShapedType::kDynamic;
3259 int64_t weightWidth = ShapedType::kDynamic;
3260 int64_t weightHeight = ShapedType::kDynamic;
3261
3262 // Input shape describes input width/height and batch.
3263 ShapeAdaptor inputShape(adaptor.getInput().getType());
3264 if (inputShape.hasRank()) {
3265 outputShape[0] = ShapedType::isDynamic(dValue: outputShape[0])
3266 ? inputShape.getDimSize(index: 0)
3267 : outputShape[0];
3268 inputHeight = inputShape.getDimSize(index: 1);
3269 inputWidth = inputShape.getDimSize(index: 2);
3270 }
3271
3272 // Weight shapes describes the filter width/height and the output channels.
3273 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3274 if (weightShape.hasRank()) {
3275 outputShape[3] = ShapedType::isDynamic(dValue: outputShape[3])
3276 ? weightShape.getDimSize(index: 0)
3277 : outputShape[3];
3278 weightHeight = weightShape.getDimSize(index: 1);
3279 weightWidth = weightShape.getDimSize(index: 2);
3280 }
3281
3282 // Bias shape can describe the output channels.
3283 ShapeAdaptor biasShape(adaptor.getInput().getType());
3284 if (biasShape.hasRank()) {
3285 outputShape[3] = ShapedType::isDynamic(dValue: outputShape[3])
3286 ? biasShape.getDimSize(index: 0)
3287 : outputShape[3];
3288 }
3289
3290 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3291 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3292
3293 if (ShapedType::isStatic(dValue: inputHeight) && ShapedType::isStatic(dValue: weightHeight)) {
3294 int64_t calculateSize =
3295 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3296 outputShape[1] =
3297 ShapedType::isDynamic(dValue: outputShape[1]) ? calculateSize : outputShape[1];
3298 }
3299
3300 if (ShapedType::isStatic(dValue: inputWidth) && ShapedType::isStatic(dValue: weightWidth)) {
3301 int64_t calculateSize =
3302 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3303 outputShape[2] =
3304 ShapedType::isDynamic(dValue: outputShape[2]) ? calculateSize : outputShape[2];
3305 }
3306
3307 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(outputShape));
3308 return success();
3309}
3310
3311LogicalResult TransposeConv2DOp::verify() {
3312 if (verifyConvOp(op: *this).failed() || verifyConvOpModes(op: *this).failed())
3313 return failure();
3314
3315 const llvm::ArrayRef<int64_t> strides = getStride();
3316 const int64_t strideY = strides[0];
3317 const int64_t strideX = strides[1];
3318
3319 if (strideY < 1 || strideX < 1)
3320 return emitOpError(message: "expect all stride values to be >= 1, got [")
3321 << strides << "]";
3322
3323 const auto checkPadAgainstKernelDim =
3324 [this](int64_t pad_value, int64_t kernel_dim_size,
3325 llvm::StringRef pad_name,
3326 llvm::StringRef kernel_dim_name) -> LogicalResult {
3327 if (pad_value <= -kernel_dim_size)
3328 return emitOpError(message: "expected ")
3329 << pad_name << " > -" << kernel_dim_name
3330 << ", but got: " << pad_name << "=" << pad_value << " and "
3331 << kernel_dim_name << "=" << kernel_dim_size;
3332 return success();
3333 };
3334
3335 const llvm::ArrayRef<int64_t> padding = getOutPad();
3336 const int64_t outPadTop = padding[0];
3337 const int64_t outPadBottom = padding[1];
3338 const int64_t outPadLeft = padding[2];
3339 const int64_t outPadRight = padding[3];
3340
3341 const auto weightType =
3342 llvm::dyn_cast<RankedTensorType>(Val: getWeight().getType());
3343
3344 if (weightType) {
3345 const int64_t kernelHeight = weightType.getDimSize(idx: 1);
3346 if (ShapedType::isStatic(dValue: kernelHeight)) {
3347 if (failed(Result: checkPadAgainstKernelDim(outPadTop, kernelHeight,
3348 "out_pad_top", "KH")))
3349 return failure();
3350
3351 if (failed(Result: checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3352 "out_pad_bottom", "KH")))
3353 return failure();
3354 }
3355
3356 const int64_t kernelWidth = weightType.getDimSize(idx: 2);
3357 if (ShapedType::isStatic(dValue: kernelWidth)) {
3358 if (failed(Result: checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3359 "out_pad_left", "KW")))
3360 return failure();
3361
3362 if (failed(Result: checkPadAgainstKernelDim(outPadRight, kernelWidth,
3363 "out_pad_right", "KW")))
3364 return failure();
3365 }
3366 }
3367
3368 // Rest of the checks depend on the output type being a RankedTensorType
3369 const auto outputType =
3370 llvm::dyn_cast<RankedTensorType>(Val: getOutput().getType());
3371 if (!outputType)
3372 return success();
3373
3374 const auto inputType = llvm::dyn_cast<RankedTensorType>(Val: getInput().getType());
3375 if (inputType && weightType) {
3376 const int64_t inputHeight = inputType.getDimSize(idx: 1);
3377 const int64_t kernelHeight = weightType.getDimSize(idx: 1);
3378 const int64_t outputHeight = outputType.getDimSize(idx: 1);
3379
3380 if (ShapedType::isStatic(dValue: inputHeight) &&
3381 ShapedType::isStatic(dValue: outputHeight)) {
3382 if (outputHeight !=
3383 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3384 return emitOpError(
3385 message: "dimension mismatch: expected OH == (IH - 1) * stride_y "
3386 "+ out_pad_top + out_pad_bottom + KH, but got ")
3387 << outputHeight << " != (" << inputHeight << " - 1) * "
3388 << strideY << " + " << outPadTop << " + " << outPadBottom
3389 << " + " << kernelHeight;
3390 }
3391
3392 const int64_t inputWidth = inputType.getDimSize(idx: 2);
3393 const int64_t kernelWidth = weightType.getDimSize(idx: 2);
3394 const int64_t outputWidth = outputType.getDimSize(idx: 2);
3395
3396 if (ShapedType::isStatic(dValue: inputWidth) && ShapedType::isStatic(dValue: outputWidth)) {
3397 if (outputWidth !=
3398 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3399 return emitOpError(
3400 message: "dimension mismatch: expected OW == (IW - 1) * stride_x "
3401 "+ out_pad_left + out_pad_right + KW, but got ")
3402 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3403 << " + " << outPadLeft << " + " << outPadRight << " + "
3404 << kernelWidth;
3405 }
3406 }
3407
3408 const auto biasType = llvm::dyn_cast<RankedTensorType>(Val: getBias().getType());
3409
3410 if (!biasType)
3411 return success();
3412
3413 const int64_t biasChannels = biasType.getDimSize(idx: 0);
3414
3415 // Skip further checks if bias is dynamic
3416 if (biasChannels == ShapedType::kDynamic)
3417 return success();
3418
3419 const int64_t outputChannels = outputType.getDimSize(idx: 3);
3420 if (biasChannels != outputChannels && biasChannels != 1)
3421 return emitOpError(
3422 message: "bias channels expected to be equal to output channels (")
3423 << outputChannels << ") or 1, got " << biasChannels;
3424
3425 return success();
3426}
3427
3428LogicalResult RescaleOp::verify() {
3429 auto inputType = llvm::dyn_cast<ShapedType>(Val: getInput().getType());
3430 if (!inputType) {
3431 emitOpError(message: "expect shaped tensor for input, got ") << getInput().getType();
3432 return failure();
3433 }
3434
3435 auto inputElementType =
3436 getStorageElementTypeOrSelf(type: inputType.getElementType());
3437 if (!mlir::isa<IntegerType>(Val: inputElementType)) {
3438 emitOpError(message: "expect input to have integer element type, got ")
3439 << inputElementType;
3440 return failure();
3441 }
3442
3443 auto outputType = llvm::dyn_cast<ShapedType>(Val: getOutput().getType());
3444 if (!outputType) {
3445 emitOpError(message: "expect shaped tensor for output, got ")
3446 << getOutput().getType();
3447 return failure();
3448 }
3449
3450 auto outputElementType =
3451 getStorageElementTypeOrSelf(type: outputType.getElementType());
3452 if (!mlir::isa<IntegerType>(Val: outputElementType)) {
3453 emitOpError(message: "expect output to have integer element type, got ")
3454 << outputElementType;
3455 return failure();
3456 }
3457
3458 if (verifyRescaleValueAndZpTypes(op: *this, val: getInput(), valZp: getInputZp(), name: "input")
3459 .failed())
3460 return failure();
3461
3462 if (verifyRescaleValueAndZpTypes(op: *this, val: getOutput(), valZp: getOutputZp(), name: "output")
3463 .failed())
3464 return failure();
3465
3466 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3467 if (succeeded(Result: maybeIZp) && verifyInputZeroPoint(zp: *maybeIZp).failed())
3468 return failure();
3469
3470 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3471 if (succeeded(Result: maybeOZp) && verifyOutputZeroPoint(zp: *maybeOZp).failed())
3472 return failure();
3473
3474 auto multiplierType = llvm::dyn_cast<ShapedType>(Val: getMultiplier().getType());
3475 if (!multiplierType) {
3476 emitOpError(message: "expect shaped tensor for multiplier, got ")
3477 << getMultiplier().getType();
3478 return failure();
3479 }
3480
3481 auto shiftType = llvm::dyn_cast<ShapedType>(Val: getShift().getType());
3482 if (!shiftType) {
3483 emitOpError(message: "expect shaped tensor for shift, got ") << getShift().getType();
3484 return failure();
3485 }
3486
3487 // multiplier element type must be i32 for scale32 = true
3488 if (getScale32() && !multiplierType.getElementType().isInteger(width: 32)) {
3489 emitOpError(message: "expect i32 element type for multiplier for scale32=true, got ")
3490 << multiplierType.getElementType();
3491 return failure();
3492 }
3493
3494 // multiplier element type must be i16 for scale32 = false
3495 if (!getScale32() && !multiplierType.getElementType().isInteger(width: 16)) {
3496 emitOpError(
3497 message: "expect i16 element type for multiplier for scale32=false, got ")
3498 << multiplierType.getElementType();
3499 return failure();
3500 }
3501
3502 if (!inputType.hasRank())
3503 return success();
3504
3505 // multiplier/shift must have shape = {numChannels},
3506 // where numChannel is 1 if per_channel = false
3507 // otherwise numChannel is dimension in input shape's last axis
3508 int64_t numChannels = 1;
3509 if (getPerChannel()) {
3510 if (inputType.getRank() < 1) {
3511 emitOpError(message: "requires input to be at least rank 1 when per_channel is "
3512 "true, but got rank ")
3513 << inputType.getRank();
3514 return failure();
3515 }
3516 numChannels = inputType.getDimSize(idx: inputType.getRank() - 1);
3517 }
3518
3519 if (!multiplierType.hasRank())
3520 return success();
3521
3522 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3523 // multiplier input has rank 1 by dialect definition
3524 if (multiplierShape[0] != ShapedType::kDynamic &&
3525 multiplierShape[0] != numChannels) {
3526 emitOpError(message: "expect shape of { ")
3527 << numChannels << " } for multiplier input, got { "
3528 << multiplierShape[0] << " }";
3529 return failure();
3530 }
3531
3532 if (!shiftType.hasRank())
3533 return success();
3534
3535 ArrayRef<int64_t> shiftShape = shiftType.getShape();
3536 // shift input has rank 1 by dialect definition
3537 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3538 emitOpError(message: "expect shape of { ")
3539 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3540 return failure();
3541 }
3542
3543 return success();
3544}
3545
3546LogicalResult RescaleOp::inferReturnTypeComponents(
3547 MLIRContext *context, ::std::optional<Location> location,
3548 RescaleOp::Adaptor adaptor,
3549 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3550 ShapeAdaptor inputShape(adaptor.getInput().getType());
3551 inferredReturnShapes.push_back(Elt: ShapedTypeComponents(inputShape));
3552 return success();
3553}
3554
3555LogicalResult IfOp::inferReturnTypeComponents(
3556 MLIRContext *context, ::std::optional<Location> location,
3557 IfOp::Adaptor adaptor,
3558 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3559 llvm::SmallVector<tosa::YieldOp> yieldOps;
3560 for (Region *region : adaptor.getRegions()) {
3561 for (auto &block : *region)
3562 if (auto returnOp = dyn_cast<tosa::YieldOp>(Val: block.getTerminator()))
3563 yieldOps.push_back(Elt: returnOp);
3564 }
3565
3566 if (yieldOps.empty())
3567 return failure();
3568
3569 // Get the initial type information for the yield op.
3570 llvm::SmallVector<ValueKnowledge> resultKnowledge;
3571 resultKnowledge.reserve(N: yieldOps.front().getNumOperands());
3572 for (auto operand : yieldOps.front().getOperands()) {
3573 resultKnowledge.push_back(
3574 Elt: ValueKnowledge::getKnowledgeFromType(type: operand.getType()));
3575 }
3576
3577 for (auto yieldOp : yieldOps) {
3578 if (resultKnowledge.size() != yieldOp.getNumOperands())
3579 return failure();
3580
3581 for (const auto &it : llvm::enumerate(First: yieldOp.getOperands())) {
3582 int32_t index = it.index();
3583 auto meet = ValueKnowledge::meet(
3584 lhs: resultKnowledge[index],
3585 rhs: ValueKnowledge::getKnowledgeFromType(type: it.value().getType()));
3586 if (!meet)
3587 continue;
3588 resultKnowledge[index] = meet;
3589 }
3590 }
3591
3592 for (const ValueKnowledge &result : resultKnowledge) {
3593 inferredReturnShapes.push_back(Elt: result.getShapedTypeComponents());
3594 }
3595
3596 return success();
3597}
3598
3599LogicalResult WhileOp::inferReturnTypeComponents(
3600 MLIRContext *context, ::std::optional<Location> location,
3601 WhileOp::Adaptor adaptor,
3602 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3603 llvm::SmallVector<tosa::YieldOp> yieldOps;
3604 for (auto &block : adaptor.getBodyGraph())
3605 if (auto returnOp = dyn_cast<tosa::YieldOp>(Val: block.getTerminator()))
3606 yieldOps.push_back(Elt: returnOp);
3607
3608 // TOSA's while must have a tosa.yield as its terminator. If not found this
3609 // tosa.while is invalid.
3610 if (yieldOps.empty())
3611 return failure();
3612
3613 // Get the initial type information from the operand types.
3614 llvm::SmallVector<ValueKnowledge> resultKnowledge;
3615 resultKnowledge.reserve(N: yieldOps.front().getNumOperands());
3616 for (auto operand : yieldOps.front().getOperands()) {
3617 resultKnowledge.push_back(
3618 Elt: ValueKnowledge::getKnowledgeFromType(type: operand.getType()));
3619 }
3620
3621 for (auto yieldOp : yieldOps) {
3622 if (resultKnowledge.size() != yieldOp.getNumOperands())
3623 return failure();
3624
3625 for (const auto &it : llvm::enumerate(First: yieldOp.getOperands())) {
3626 int32_t index = it.index();
3627 if (auto meet = ValueKnowledge::meet(
3628 lhs: resultKnowledge[index],
3629 rhs: ValueKnowledge::getKnowledgeFromType(type: it.value().getType()))) {
3630 resultKnowledge[index] = meet;
3631 }
3632 }
3633 }
3634
3635 for (const ValueKnowledge &result : resultKnowledge) {
3636 inferredReturnShapes.push_back(Elt: result.getShapedTypeComponents());
3637 }
3638
3639 return success();
3640}
3641
3642std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3643 if (auto vt = llvm::dyn_cast<VectorType>(Val: getType()))
3644 return llvm::to_vector<4>(Range: vt.getShape());
3645 return std::nullopt;
3646}
3647
3648// parse and print of IfOp refer to the implementation of SCF dialect.
3649ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3650 // Create the regions for 'then'.
3651 result.regions.reserve(N: 2);
3652 Region *thenRegion = result.addRegion();
3653 Region *elseRegion = result.addRegion();
3654
3655 auto &builder = parser.getBuilder();
3656 OpAsmParser::UnresolvedOperand cond;
3657 // Create a i1 tensor type for the boolean condition.
3658 Type i1Type = RankedTensorType::get(shape: {}, elementType: builder.getIntegerType(width: 1));
3659 if (parser.parseOperand(result&: cond) ||
3660 parser.resolveOperand(operand: cond, type: i1Type, result&: result.operands))
3661 return failure();
3662 // Parse optional results type list.
3663 if (parser.parseOptionalArrowTypeList(result&: result.types))
3664 return failure();
3665 // Parse the 'then' region.
3666 if (parser.parseRegion(region&: *thenRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
3667 return failure();
3668
3669 // If we find an 'else' keyword then parse the 'else' region.
3670 if (!parser.parseOptionalKeyword(keyword: "else")) {
3671 if (parser.parseRegion(region&: *elseRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
3672 return failure();
3673 }
3674
3675 // Parse the optional attribute list.
3676 if (parser.parseOptionalAttrDict(result&: result.attributes))
3677 return failure();
3678 return success();
3679}
3680
3681void IfOp::print(OpAsmPrinter &p) {
3682 bool printBlockTerminators = false;
3683
3684 p << " " << getCondition();
3685 if (!getResults().empty()) {
3686 p << " -> (" << getResultTypes() << ")";
3687 // Print yield explicitly if the op defines values.
3688 printBlockTerminators = true;
3689 }
3690 p << ' ';
3691 p.printRegion(blocks&: getThenGraph(),
3692 /*printEntryBlockArgs=*/false,
3693 /*printBlockTerminators=*/printBlockTerminators);
3694
3695 // Print the 'else' regions if it exists and has a block.
3696 auto &elseRegion = getElseGraph();
3697 if (!elseRegion.empty()) {
3698 p << " else ";
3699 p.printRegion(blocks&: elseRegion,
3700 /*printEntryBlockArgs=*/false,
3701 /*printBlockTerminators=*/printBlockTerminators);
3702 }
3703
3704 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
3705}
3706
3707LogicalResult IfOp::verify() {
3708 if (errorIfTypeOrShapeMismatch(op: *this, list1: getThenGraph().front().getArguments(),
3709 name1: "'then_graph' arguments", list2: getInputList(),
3710 name2: "'input_list'")
3711 .failed())
3712 return failure();
3713
3714 if (errorIfTypeOrShapeMismatch(op: *this, list1: getElseGraph().front().getArguments(),
3715 name1: "'else_graph' arguments", list2: getInputList(),
3716 name2: "'input_list'")
3717 .failed())
3718 return failure();
3719
3720 auto thenYield = cast<tosa::YieldOp>(Val: getThenGraph().front().getTerminator());
3721 if (errorIfTypeOrShapeMismatch(op: *this, list1: thenYield.getInputs(),
3722 name1: "'then_graph' results", list2: getOutputList(),
3723 name2: "'output_list'")
3724 .failed())
3725 return failure();
3726
3727 auto elseYield = cast<tosa::YieldOp>(Val: getElseGraph().front().getTerminator());
3728 if (errorIfTypeOrShapeMismatch(op: *this, list1: elseYield.getInputs(),
3729 name1: "'else_graph' results", list2: getOutputList(),
3730 name2: "'output_list'")
3731 .failed())
3732 return failure();
3733
3734 auto condType = getCondition().getType();
3735 if (errorIfShapeNotSizeOne(op: *this, type: condType).failed())
3736 return emitOpError() << "'condition' must be a size 1 tensor, got "
3737 << condType;
3738
3739 return success();
3740}
3741
3742LogicalResult WhileOp::verify() {
3743 if (errorIfTypeOrShapeMismatch(op: *this, list1: getInputList(), name1: "'input_list'",
3744 list2: getOutputList(), name2: "'output_list'")
3745 .failed())
3746 return failure();
3747
3748 if (errorIfTypeOrShapeMismatch(op: *this, list1: getCondGraph().front().getArguments(),
3749 name1: "'cond_graph' arguments", list2: getInputList(),
3750 name2: "'input_list'")
3751 .failed())
3752 return failure();
3753
3754 if (errorIfTypeOrShapeMismatch(op: *this, list1: getBodyGraph().front().getArguments(),
3755 name1: "'body_graph' arguments", list2: getInputList(),
3756 name2: "'input_list'")
3757 .failed())
3758 return failure();
3759
3760 auto bodyYield = cast<tosa::YieldOp>(Val: getBodyGraph().front().getTerminator());
3761 if (errorIfTypeOrShapeMismatch(op: *this, list1: bodyYield.getInputs(),
3762 name1: "'body_graph' results", list2: getInputList(),
3763 name2: "'input_list'")
3764 .failed())
3765 return failure();
3766
3767 // Condition block output must be a single element tensor with a single bool
3768 // value.
3769 auto condYield = cast<tosa::YieldOp>(Val: getCondGraph().front().getTerminator());
3770 if (condYield.getInputs().size() != 1)
3771 return emitOpError() << "require 'cond_graph' only have one result";
3772
3773 auto condOutType = condYield.getInputs()[0].getType();
3774 if (errorIfShapeNotSizeOne(op: *this, type: condOutType).failed())
3775 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3776 << condOutType;
3777
3778 if (!getElementTypeOrSelf(type: condOutType).isInteger(width: 1))
3779 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3780 << condOutType;
3781
3782 return success();
3783}
3784
3785LogicalResult ReverseOp::verify() {
3786 if (verifySameElementTypes(op: *this, /* inType = */ getInput1().getType(),
3787 /* outType = */ getOutput().getType())
3788 .failed())
3789 return failure();
3790 TensorType inputType = getInput1().getType();
3791 TensorType outputType = getOutput().getType();
3792 int32_t reverseAxis = getAxis();
3793
3794 if (reverseAxis < 0)
3795 return emitOpError(message: "expected non-negative reverse axis");
3796 if (inputType.hasRank()) {
3797 int64_t inputRank = inputType.getRank();
3798 // We allow for a special case where the input/output shape has rank 0 and
3799 // axis is also 0.
3800 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3801 return emitOpError(message: "expect input tensor rank (")
3802 << inputRank << ") to be larger than reverse axis (" << reverseAxis
3803 << ")";
3804 }
3805 if (outputType.hasRank()) {
3806 int64_t outputRank = outputType.getRank();
3807 if (inputType.hasRank() && outputRank != inputType.getRank())
3808 return emitOpError(
3809 message: "expect output tensor rank to be equal to input tensor rank");
3810 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3811 return emitOpError(message: "expect output tensor rank (")
3812 << outputRank << ") to be larger than reverse axis ("
3813 << reverseAxis << ")";
3814 }
3815 return success();
3816}
3817
3818LogicalResult tosa::SelectOp::verify() {
3819 // verify input2 and input3 have same element type as output
3820 if (verifySameElementTypes(op: *this, /* inType = */ getOnTrue().getType(),
3821 /* outType = */ getOutput().getType())
3822 .failed() ||
3823 verifySameElementTypes(op: *this, /* inType = */ getOnFalse().getType(),
3824 /* outType = */ getOutput().getType())
3825 .failed()) {
3826 return failure();
3827 }
3828 // verify input1 has element type of bool
3829 auto predicateType = llvm::dyn_cast<ShapedType>(Val: getPred().getType());
3830 if (!predicateType) {
3831 return emitOpError(message: "expect shaped tensor for input1, got ")
3832 << getInput1().getType();
3833 }
3834 auto predicateElementType = predicateType.getElementType();
3835 if (!predicateElementType.isInteger(width: 1)) {
3836 return emitOpError(message: "expect element type of bool for input1, got ")
3837 << predicateElementType;
3838 }
3839
3840 return success();
3841}
3842
3843LogicalResult tosa::VariableOp::verify() {
3844 StringRef symName = getName();
3845 FailureOr<tosa::VariableOp> varOp = findVariableDecl(op: *this, symName);
3846 if (succeeded(Result: varOp))
3847 return emitOpError(message: "illegal to have multiple declaration of '")
3848 << symName << "'";
3849
3850 return success();
3851}
3852
3853LogicalResult tosa::VariableReadOp::verify() {
3854 if (verifyVariableOpErrorIf(op: *this, type: getOutput1().getType(), name: "'output1'")
3855 .failed())
3856 return failure();
3857
3858 return success();
3859}
3860
3861LogicalResult tosa::VariableWriteOp::verify() {
3862 if (verifyVariableOpErrorIf(op: *this, type: getInput1().getType(), name: "'input1'")
3863 .failed())
3864 return failure();
3865
3866 return success();
3867}
3868
3869// parse and print of WhileOp refer to the implementation of SCF dialect.
3870ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3871 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3872 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3873 Region *cond = result.addRegion();
3874 Region *body = result.addRegion();
3875
3876 OptionalParseResult listResult =
3877 parser.parseOptionalAssignmentList(lhs&: regionArgs, rhs&: operands);
3878 if (listResult.has_value() && failed(Result: listResult.value()))
3879 return failure();
3880
3881 FunctionType functionType;
3882 SMLoc typeLoc = parser.getCurrentLocation();
3883 if (failed(Result: parser.parseColonType(result&: functionType)))
3884 return failure();
3885
3886 result.addTypes(newTypes: functionType.getResults());
3887
3888 if (functionType.getNumInputs() != operands.size()) {
3889 return parser.emitError(loc: typeLoc)
3890 << "expected as many input types as operands "
3891 << "(expected " << operands.size() << " got "
3892 << functionType.getNumInputs() << ")";
3893 }
3894
3895 // Resolve input operands.
3896 if (failed(Result: parser.resolveOperands(operands, types: functionType.getInputs(),
3897 loc: parser.getCurrentLocation(),
3898 result&: result.operands)))
3899 return failure();
3900
3901 // Propagate the types into the region arguments.
3902 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3903 regionArgs[i].type = functionType.getInput(i);
3904
3905 return failure(IsFailure: parser.parseRegion(region&: *cond, arguments: regionArgs) ||
3906 parser.parseKeyword(keyword: "do") || parser.parseRegion(region&: *body) ||
3907 parser.parseOptionalAttrDictWithKeyword(result&: result.attributes));
3908}
3909
3910static void printInitializationList(OpAsmPrinter &parser,
3911 Block::BlockArgListType blocksArgs,
3912 ValueRange initializers,
3913 StringRef prefix = "") {
3914 assert(blocksArgs.size() == initializers.size() &&
3915 "expected same length of arguments and initializers");
3916 if (initializers.empty())
3917 return;
3918
3919 parser << prefix << '(';
3920 llvm::interleaveComma(
3921 c: llvm::zip(t&: blocksArgs, u&: initializers), os&: parser,
3922 each_fn: [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3923 parser << ")";
3924}
3925
3926void WhileOp::print(OpAsmPrinter &parser) {
3927 printInitializationList(parser, blocksArgs: getCondGraph().front().getArguments(),
3928 initializers: getInputList(), prefix: " ");
3929 parser << " : ";
3930 parser.printFunctionalType(inputs: getInputList().getTypes(),
3931 results: getResults().getTypes());
3932 parser << ' ';
3933 parser.printRegion(blocks&: getCondGraph(), /*printEntryBlockArgs=*/false);
3934 parser << " do ";
3935 parser.printRegion(blocks&: getBodyGraph());
3936 parser.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs());
3937}
3938
3939// Create a rank-1 const tensor for zero point of the source tensor.
3940std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3941 Location loc,
3942 Type srcElemType,
3943 int64_t zp) {
3944 srcElemType = getStorageElementTypeOrSelf(type: srcElemType);
3945 auto zpType = mlir::RankedTensorType::get(shape: {1}, elementType: srcElemType);
3946 if (llvm::isa<FloatType>(Val: srcElemType)) {
3947 auto zpAttr = DenseElementsAttr::get(
3948 type: zpType, values: builder.getFloatAttr(type: srcElemType, value: static_cast<double>(zp)));
3949 return builder.create<tosa::ConstOp>(location: loc, args&: zpType, args&: zpAttr);
3950 }
3951 if (llvm::isa<IntegerType>(Val: srcElemType)) {
3952 auto zpAttr =
3953 DenseElementsAttr::get(type: zpType, values: builder.getIntegerAttr(type: srcElemType, value: zp));
3954 return builder.create<tosa::ConstOp>(location: loc, args&: zpType, args&: zpAttr);
3955 }
3956 llvm::errs() << "zero point is not allowed for unsupported data types\n";
3957 return std::nullopt;
3958}
3959
3960//===----------------------------------------------------------------------===//
3961// TOSA Shape and Shape Operators Helper functions.
3962//===----------------------------------------------------------------------===//
3963
3964bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3965 return mlir::isa<tosa::shapeType>(Val: t);
3966}
3967
3968LogicalResult
3969mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3970 int rank) {
3971 if (rank < 0)
3972 return emitError() << "invalid rank (must be >= 0): " << rank;
3973 return success();
3974}
3975
3976LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
3977 for (auto v : op->getOperands()) {
3978 if (mlir::isa<::mlir::tosa::shapeType>(Val: v.getType())) {
3979 Operation *definingOp = v.getDefiningOp();
3980 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3981 return op->emitOpError(message: "shape operand is not compile time resolvable");
3982 }
3983 }
3984 }
3985 return success();
3986}
3987
3988LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
3989 for (auto type : op->getOperandTypes()) {
3990 if (!mlir::isa<mlir::tosa::shapeType>(Val: type)) {
3991 return op->emitOpError(message: "must have operands with tosa shape type");
3992 }
3993 }
3994 for (auto type : op->getResultTypes()) {
3995 if (!mlir::isa<mlir::tosa::shapeType>(Val: type)) {
3996 return op->emitOpError(message: "must have result with tosa shape type");
3997 }
3998 }
3999 return success();
4000}
4001
4002LogicalResult
4003OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4004 if (failed(Result: OpTrait::impl::verifyAtLeastNOperands(op, numOperands: 1)) ||
4005 failed(Result: verifyTosaShapeOperator(op)))
4006 return failure();
4007
4008 // delegate function that returns rank of shape type
4009 auto getRank = [](const Type type) {
4010 return mlir::cast<mlir::tosa::shapeType>(Val: type).getRank();
4011 };
4012 auto operandTypes = op->getOperandTypes();
4013 auto resultTypes = op->getResultTypes();
4014
4015 auto rank = getRank(*op->getOperandTypes().begin());
4016 for (auto type : operandTypes) {
4017 if (getRank(type) != rank) {
4018 return op->emitOpError(message: "operands don't have matching ranks");
4019 }
4020 }
4021 for (auto type : resultTypes) {
4022 if (getRank(type) != rank) {
4023 return op->emitOpError(message: "result shape has different rank than operands");
4024 }
4025 }
4026 return success();
4027}
4028
4029//===----------------------------------------------------------------------===//
4030// TOSA Shape Operators verify functions.
4031//===----------------------------------------------------------------------===//
4032
4033LogicalResult tosa::ConstShapeOp::verify() {
4034 // check one dimensional rank
4035 auto valuesRank = getValues().getType().getRank();
4036 if (valuesRank != 1)
4037 return emitOpError(message: "expect elements in attribute values with rank 1");
4038 // check that number of elements in values attr equal to rank of result shape
4039 auto count = getValues().getNumElements();
4040 auto rank = (cast<tosa::shapeType>(Val: getResult().getType())).getRank();
4041 if (!(count == rank || (count == 1 && rank == 0))) {
4042 return emitOpError(message: "expect number of elements in attribute values (")
4043 << count << ") to be equal to the rank (" << rank
4044 << ") for the result shape type";
4045 }
4046 return success();
4047}
4048
4049//===----------------------------------------------------------------------===//
4050// TOSA Attribute Definitions.
4051//===----------------------------------------------------------------------===//
4052
4053#define GET_ATTRDEF_CLASSES
4054#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4055
4056//===----------------------------------------------------------------------===//
4057// TOSA Type Definitions.
4058//===----------------------------------------------------------------------===//
4059#define GET_TYPEDEF_CLASSES
4060#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4061
4062//===----------------------------------------------------------------------===//
4063// TOSA Operator Definitions.
4064//===----------------------------------------------------------------------===//
4065
4066#define GET_OP_CLASSES
4067#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
4068

// Source: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp