1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
17#include "mlir/Dialect/Quant/IR/Quant.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
20#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "mlir/Interfaces/InferTypeOpInterface.h"
28#include "mlir/Transforms/InliningUtils.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <numeric>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
39#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
40
41//===----------------------------------------------------------------------===//
42// Tosa dialect interface includes.
43//===----------------------------------------------------------------------===//
44
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49
50namespace {
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52
53//===----------------------------------------------------------------------===//
54// Dialect Function Inliner Interface.
55//===----------------------------------------------------------------------===//
56struct TosaInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
58
59 //===--------------------------------------------------------------------===//
60 // Analysis Hooks.
61 //===--------------------------------------------------------------------===//
62
63 /// All operations can be inlined by default.
64 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65 IRMapping &map) const final {
66 return true;
67 }
68
69 /// All regions with If and While parent operators can be inlined.
70 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71 IRMapping &map) const final {
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
74 }
75};
76
77/// This class implements the bytecode interface for the Tosa dialect.
78struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
79 TosaDialectBytecodeInterface(Dialect *dialect)
80 : BytecodeDialectInterface(dialect) {}
81
82 //===--------------------------------------------------------------------===//
83 // Attributes
84
85 Attribute readAttribute(DialectBytecodeReader &reader) const override {
86 return ::readAttribute(getContext(), reader);
87 }
88
89 LogicalResult writeAttribute(Attribute attr,
90 DialectBytecodeWriter &writer) const override {
91 return ::writeAttribute(attr, writer);
92 }
93
94 //===--------------------------------------------------------------------===//
95 // Types
96
97 Type readType(DialectBytecodeReader &reader) const override {
98 return ::readType(getContext(), reader);
99 }
100
101 LogicalResult writeType(Type type,
102 DialectBytecodeWriter &writer) const override {
103 return ::writeType(type, writer);
104 }
105
106 void writeVersion(DialectBytecodeWriter &writer) const final {
107 // TODO: Populate.
108 }
109
110 std::unique_ptr<DialectVersion>
111 readVersion(DialectBytecodeReader &reader) const final {
112 // TODO: Populate
113 reader.emitError(msg: "Dialect does not support versioning");
114 return nullptr;
115 }
116
117 LogicalResult upgradeFromVersion(Operation *topLevelOp,
118 const DialectVersion &version) const final {
119 return success();
120 }
121};
122
123} // namespace
124
125//===----------------------------------------------------------------------===//
126// TOSA control flow support.
127//===----------------------------------------------------------------------===//
128
129/// Returns the while loop body.
130SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131 return {&getBodyGraph()};
132}
133
134//===----------------------------------------------------------------------===//
135// TOSA variable operator support.
136//===----------------------------------------------------------------------===//
137
138static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
139 return to_vector(Range: llvm::map_range(C&: shape, F: [](int64_t dim) {
140 return dim == -1 ? ShapedType::kDynamic : dim;
141 }));
142}
143
144// returns type of variable op
145RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
146 Type elementType = variableOp.getType();
147 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
148 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
149 return RankedTensorType::get(shape, elementType);
150}
151
152//===----------------------------------------------------------------------===//
153// Tosa dialect initialization.
154//===----------------------------------------------------------------------===//
155
156void TosaDialect::initialize() {
157 addTypes<
158#define GET_TYPEDEF_LIST
159#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
160 >();
161 addOperations<
162#define GET_OP_LIST
163#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
164 >();
165 addAttributes<
166#define GET_ATTRDEF_LIST
167#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
168 >();
169 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
170 declarePromisedInterfaces<
171 mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
172 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
173 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
174 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
175 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
176 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
177 GreaterEqualOp, MatMulOp>();
178}
179
180Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
181 Type type, Location loc) {
182 // Tosa dialect constants only support ElementsAttr unlike standard dialect
183 // constant which supports all attributes.
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return builder.create<tosa::ConstShapeOp>(
186 loc, type, llvm::cast<DenseIntElementsAttr>(value));
187 }
188 if (llvm::isa<ElementsAttr>(value))
189 return builder.create<tosa::ConstOp>(loc, type,
190 llvm::cast<ElementsAttr>(value));
191 return nullptr;
192}
193
194//===----------------------------------------------------------------------===//
195// Parsers and printers
196//===----------------------------------------------------------------------===//
197
198namespace {
199
200ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
201 DenseElementsAttr &varShapeAttr,
202 TypeAttr &typeAttr) {
203 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
205 return parser.emitError(loc: parser.getCurrentLocation())
206 << "expected ranked type";
207
208 auto elementType = shapedType.getElementType();
209 typeAttr = TypeAttr::get(elementType);
210 ArrayRef<int64_t> shape = shapedType.getShape();
211 Builder builder(parser.getContext());
212 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
213 return success();
214 }
215 return parser.emitError(loc: parser.getCurrentLocation())
216 << "expected shaped type";
217}
218
219} // namespace
220
221// parses the optional initial value or type for a tosa variable
222// with initial value:
223// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224//
225// without initial value:
226// tosa.variable @name : tensor<1x8xf32>
227ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
228 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229 Attribute &initialValueAttr) {
230 if (succeeded(Result: parser.parseOptionalEqual())) {
231 if (failed(Result: parser.parseAttribute(result&: initialValueAttr))) {
232 return parser.emitError(loc: parser.getCurrentLocation())
233 << "expected attribute";
234 }
235 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237 typeAttr);
238 }
239 return parser.emitError(loc: parser.getCurrentLocation())
240 << "expected Typed attr";
241 }
242
243 initialValueAttr = nullptr;
244 Type parsedType;
245 if (failed(Result: parser.parseColonType(result&: parsedType))) {
246 return parser.emitError(loc: parser.getCurrentLocation())
247 << "expected type after colon";
248 }
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
250}
251
252void mlir::tosa::printVariableOpTypeOrInitialValue(
253 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254 TypeAttr typeAttr, Attribute initialValueAttr) {
255 bool needsSpace = false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257 auto shape =
258 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
261 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
262 auto tensorTypeAttr = TypeAttr::get(tensorType);
263 p << ": ";
264 p.printAttribute(attr: tensorTypeAttr);
265 needsSpace = true; // subsequent attr value needs a space separator
266 }
267 if (initialValueAttr) {
268 if (needsSpace)
269 p << ' ';
270 p << "= ";
271 p.printAttribute(attr: initialValueAttr);
272 }
273}
274
275//===----------------------------------------------------------------------===//
276// Tosa utilities.
277//===----------------------------------------------------------------------===//
278
279std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
280 if (lhs % rhs != 0)
281 return std::nullopt;
282 return lhs / rhs;
283}
284
285Type getStorageElementTypeOrSelf(Type type) {
286 auto srcType = getElementTypeOrSelf(type);
287 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(Val&: srcType))
288 srcType = quantType.getStorageType();
289 return srcType;
290}
291
292Type getStorageElementTypeOrSelf(Value value) {
293 return getStorageElementTypeOrSelf(type: value.getType());
294}
295
296static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
297 Value valZp, StringRef name) {
298 Type eType = getStorageElementTypeOrSelf(type: val.getType());
299 Type eZpType = getStorageElementTypeOrSelf(type: valZp.getType());
300
301 bool bothInts =
302 mlir::isa<IntegerType>(Val: eType) && mlir::isa<IntegerType>(Val: eZpType);
303 bool sameBitWidth =
304 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
305
306 if (!bothInts || !sameBitWidth) {
307 return op->emitOpError()
308 << "expected " << name << " and " << name
309 << "_zp to both be integer of the same bitwidth, but got " << eType
310 << " vs. " << eZpType;
311 }
312 return success();
313}
314
315// Create a pad-const const tensor with value of `val` of required data-type
316Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
317 Value src, int32_t val) {
318 const auto srcType = getElementTypeOrSelf(val: src);
319 const auto srcElemType = getStorageElementTypeOrSelf(value: src);
320 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
321 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
322 const auto padConstAttr{
323 llvm::isa<FloatType>(Val: srcElemType)
324 ? DenseElementsAttr::get(padConstEType,
325 builder.getFloatAttr(srcElemType, val))
326 : DenseElementsAttr::get(padConstEType,
327 builder.getIntegerAttr(srcElemType, val))};
328 return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
329}
330
331//===----------------------------------------------------------------------===//
332// TOSA Operator Verifiers.
333//===----------------------------------------------------------------------===//
334
335template <typename T>
336static LogicalResult verifyConvOp(T op) {
337 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
338 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
339
340 auto inputEType = inputType.getElementType();
341 auto weightEType = weightType.getElementType();
342 auto biasEType =
343 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
344 auto resultEType =
345 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
346 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
347 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
348
349 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
350 inputEType = quantType.getStorageType();
351
352 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
353 weightEType = quantType.getStorageType();
354
355 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
356 biasEType = quantType.getStorageType();
357
358 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
359 resultEType = quantType.getStorageType();
360
361 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
362 // for now, only enforce bias element type == result element type for
363 // float types.
364 op.emitOpError(
365 "expect both bias and result to have same element type, got ")
366 << biasEType << " and " << resultEType;
367 return failure();
368 }
369
370 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
371 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
372 if (inputEType != weightEType) {
373 op.emitOpError(
374 "expect both input and weight to have same element type, got ")
375 << inputEType << " and " << weightEType;
376 return failure();
377 }
378 }
379
380 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
381 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
382
383 // Either both must be float or both non-float.
384 if (inputIsFloat != weightIsFloat) {
385 op.emitOpError(
386 "expect both input and weight to be float or not together, got ")
387 << inputEType << " and " << weightEType;
388 return failure();
389 }
390
391 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
392 if (inputEType != inputZpEType) {
393 return op.emitOpError("expect both input and its zero point are the same "
394 "element type, got ")
395 << inputEType << " and " << inputZpEType;
396 }
397
398 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
399 if (weightEType != weightZpEType) {
400 return op.emitOpError("expect both weight and its zero point are the same "
401 "element type, got ")
402 << weightEType << " and " << weightZpEType;
403 }
404
405 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
406 if (succeeded(Result: maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
407 return failure();
408
409 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
410 if (succeeded(Result: maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
411 return failure();
412
413 return success();
414}
415
416LogicalResult tosa::ConstOp::verify() {
417
418 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
419 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
420
421 if (!attrType || !outputType) {
422 emitOpError("expected tensors for attr/result type");
423 return failure();
424 }
425
426 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
427 outputType.getElementType())) {
428 if (result.getStorageType() == attrType.getElementType())
429 return success();
430 }
431
432 if (attrType.getElementType() != outputType.getElementType()) {
433 emitOpError("expected same attr/result element types");
434 return failure();
435 }
436
437 return success();
438}
439
440template <typename T>
441static LogicalResult verifyConvOpModes(T op) {
442 auto inputEType =
443 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
444
445 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
446 inputEType = quantType.getStorageType();
447
448 auto accType = op.getAccType();
449 if (inputEType.isInteger(8) && !accType.isInteger(32))
450 return op.emitOpError("accumulator type for i8 tensor is not i32");
451
452 if (inputEType.isInteger(16) && !accType.isInteger(48))
453 return op.emitOpError("accumulator type for i16 tensor is not i48");
454
455 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
456 return op.emitOpError("accumulator type for f8 tensor is not f16");
457
458 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
459 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
460
461 if (inputEType.isBF16() && !accType.isF32())
462 return op.emitOpError("accumulator type for bf16 tensor is not f32");
463
464 if (inputEType.isF32() && !accType.isF32())
465 return op.emitOpError("accumulator type for f32 tensor is not f32");
466
467 auto resultEType =
468 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
469
470 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
471 resultEType = quantType.getStorageType();
472
473 return success();
474}
475
476//===----------------------------------------------------------------------===//
477// ERROR_IF functions.
478// ERROR_IF is a predicate that must set an error if the condition holds.
479//===----------------------------------------------------------------------===//
480
481template <typename T>
482static LogicalResult verifyConvOpErrorIf(T op) {
483 llvm::ArrayRef<int64_t> padding = op.getPad();
484 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
485 return op.emitOpError("expect all padding values to be >= 0, got ")
486 << padding;
487
488 llvm::ArrayRef<int64_t> strides = op.getStride();
489 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
490 return op.emitOpError("expect all stride values to be >= 1, got ")
491 << strides;
492
493 llvm::ArrayRef<int64_t> dilations = op.getDilation();
494 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
495 return op.emitOpError("expect all dilation values to be >= 1, got ")
496 << dilations;
497
498 const RankedTensorType outputType =
499 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
500 if (!outputType)
501 // Skip following checks if output is not ranked
502 return success();
503
504 const RankedTensorType inputType =
505 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
506 const RankedTensorType weightType =
507 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
508
509 if (inputType && weightType) {
510 const auto verifyOutputSize =
511 [&op](const int64_t inputSize, const int64_t kernelSize,
512 const int64_t outputSize, const int64_t padBefore,
513 const int64_t padAfter, const int64_t stride,
514 const int64_t dilation, const llvm::StringRef dimName,
515 const llvm::StringRef dimAxis,
516 const llvm::StringRef padBeforeName,
517 const llvm::StringRef padAfterName) -> LogicalResult {
518 if (inputSize == ShapedType::kDynamic ||
519 kernelSize == ShapedType::kDynamic)
520 return success();
521
522 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
523
524 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
525 lhs: inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
526 rhs: stride);
527 if (!calculatedOutSizeMinusOne.has_value())
528 return op.emitOpError("expected input_")
529 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
530 << padAfterName << " - (kernel_" << dimName
531 << " - 1) * dilation_" << dimAxis
532 << " to be wholly divisible by stride_" << dimAxis << ", got ("
533 << inputSize << " - 1 + " << padBefore << " + " << padAfter
534 << " - (" << kernelSize << " - 1) * " << dilation << ") / "
535 << stride;
536
537 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
538 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
539 return op.emitOpError("calculated output ")
540 << dimName << " did not match expected: "
541 << "calculated=" << calculatedOutSize
542 << ", expected=" << outputSize;
543
544 return success();
545 };
546
547 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
548 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
549 if (failed(verifyOutputSize(
550 inputType.getDimSize(1), weightType.getDimSize(1),
551 outputType.getDimSize(1), padding[0], padding[1], strides[0],
552 dilations[0], "height", "y", "top", "bottom")))
553 return failure();
554
555 if (failed(verifyOutputSize(
556 inputType.getDimSize(2), weightType.getDimSize(2),
557 outputType.getDimSize(2), padding[2], padding[3], strides[1],
558 dilations[1], "width", "x", "left", "right")))
559 return failure();
560 }
561
562 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
563 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
564 if (failed(verifyOutputSize(
565 inputType.getDimSize(1), weightType.getDimSize(0),
566 outputType.getDimSize(1), padding[0], padding[1], strides[0],
567 dilations[0], "height", "y", "top", "bottom")))
568 return failure();
569
570 if (failed(verifyOutputSize(
571 inputType.getDimSize(2), weightType.getDimSize(1),
572 outputType.getDimSize(2), padding[2], padding[3], strides[1],
573 dilations[1], "width", "x", "left", "right")))
574 return failure();
575 }
576
577 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
578 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
579 if (failed(verifyOutputSize(
580 inputType.getDimSize(1), weightType.getDimSize(1),
581 outputType.getDimSize(1), padding[0], padding[1], strides[0],
582 dilations[0], "depth", "d", "front", "back")))
583 return failure();
584
585 if (failed(verifyOutputSize(
586 inputType.getDimSize(2), weightType.getDimSize(2),
587 outputType.getDimSize(2), padding[2], padding[3], strides[1],
588 dilations[1], "height", "y", "top", "bottom")))
589 return failure();
590
591 if (failed(verifyOutputSize(
592 inputType.getDimSize(3), weightType.getDimSize(3),
593 outputType.getDimSize(3), padding[4], padding[5], strides[2],
594 dilations[2], "width", "x", "left", "right")))
595 return failure();
596 }
597 }
598
599 const RankedTensorType biasType =
600 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
601 if (!biasType)
602 // Skip following checks if bias is not ranked
603 return success();
604
605 const int64_t biasChannels = biasType.getDimSize(0);
606 const int64_t outputChannels =
607 outputType.getDimSize(outputType.getRank() - 1);
608 if (biasChannels == ShapedType::kDynamic ||
609 outputChannels == ShapedType::kDynamic)
610 // Skip following checks if biasChannels or outputChannels is dynamic dim
611 return success();
612
613 if (biasChannels != outputChannels && biasChannels != 1)
614 return op.emitOpError(
615 "bias channels expected to be equal to output channels (")
616 << outputChannels << ") or 1, got " << biasChannels;
617
618 return success();
619}
620
621// Verify whether same type and shape of the given two types.
622static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
623 StringRef name1, Type type2,
624 StringRef name2) {
625 auto shapeType1 = dyn_cast<ShapedType>(type1);
626 auto shapeType2 = dyn_cast<ShapedType>(type2);
627 if (!shapeType1 || !shapeType2)
628 return failure();
629
630 auto elemType1 = shapeType1.getElementType();
631 auto elemType2 = shapeType2.getElementType();
632 if (elemType1 != elemType2)
633 return op->emitOpError()
634 << "require same element type for " << name1 << " (" << elemType1
635 << ") and " << name2 << " (" << elemType2 << ")";
636
637 if (failed(Result: verifyCompatibleShape(type1, type2)))
638 return op->emitOpError()
639 << "require same shapes for " << name1 << " (" << type1 << ") and "
640 << name2 << " (" << type2 << ")";
641
642 return success();
643}
644
645// Verify whether same length, type, and shape of the given two tensor lists.
646static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
647 StringRef name1,
648 ValueRange list2,
649 StringRef name2) {
650 if (list1.size() != list2.size())
651 return op->emitOpError()
652 << "require same number of values in " << name1 << " ("
653 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
654
655 for (auto [type1, type2] :
656 llvm::zip_equal(t: list1.getTypes(), u: list2.getTypes())) {
657 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
658 return failure();
659 }
660
661 return success();
662}
663
664static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
665 ShapeAdaptor shapeAdaptor(type);
666 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
667 return success();
668
669 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
670}
671
672// Returns the first declaration point prior to this operation or failure if
673// not found.
674static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
675 StringRef symName) {
676 ModuleOp module = op->getParentOfType<ModuleOp>();
677 tosa::VariableOp varOp = nullptr;
678
679 // TODO: Adopt SymbolTable trait to Varible ops.
680 // Currently, the variable's definition point is searched via walk(),
681 // starting from the top-level ModuleOp and stopping at the point of use. Once
682 // TOSA control flow and variable extensions reach the complete state, may
683 // leverage MLIR's Symbol Table functionality to look up symbol and enhance
684 // the search to a TOSA specific graph traversal over the IR structure.
685 module.walk([&](Operation *tempOp) {
686 // Reach this op itself.
687 if (tempOp == op) {
688 return WalkResult::interrupt();
689 }
690
691 if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
692 if (symName == tosaOp.getName()) {
693 varOp = tosaOp;
694 return WalkResult::interrupt();
695 }
696 }
697
698 return WalkResult::advance();
699 });
700
701 if (varOp)
702 return varOp;
703
704 return failure();
705}
706
707template <typename T>
708static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
709 StringRef symName = op.getName();
710 FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
711 if (failed(varOp))
712 return op->emitOpError("'")
713 << symName << "' has not been declared by 'tosa.variable'";
714
715 // Verify type and shape
716 auto variableType = getVariableType(varOp.value());
717 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
718 "the input tensor")
719 .failed())
720 return failure();
721
722 return success();
723}
724
725// verify that inType and outType have same element types
726template <typename T>
727static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
728 auto inputType = llvm::dyn_cast<TensorType>(Val&: inType);
729 auto outputType = llvm::dyn_cast<TensorType>(Val&: outType);
730 if (!inputType) {
731 op.emitOpError("expect shaped tensor for input, got ") << inType;
732 return failure();
733 }
734 if (!outputType) {
735 op.emitOpError("expect shaped tensor for output, got ") << outType;
736 return failure();
737 }
738 auto inputElementType = inputType.getElementType();
739 auto outputElementType = outputType.getElementType();
740 auto inputQuantType =
741 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: inputElementType);
742 auto outputQuantType =
743 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: outputElementType);
744 if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
745 (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
746 inputElementType != outputElementType) {
747 // only check if both element types are int/index/float/UniformQuantized
748 // eg, not sure how to check quant::QuantizedType
749 // this happens in test_conv2d_q_grouped_convolution in
750 // tfl-to-tosa-pipeline.mlir
751 op.emitOpError("expect input and output to have same element type, got ")
752 << inputElementType << " and " << outputElementType;
753 return failure();
754 }
755 return success();
756}
757
758LogicalResult tosa::ArgMaxOp::verify() {
759 const ShapedType resultType = llvm::cast<ShapedType>(getType());
760
761 // Ensure output is of 32-bit integer
762 if (const auto resultETy = resultType.getElementType();
763 !resultETy.isIntOrIndex())
764 return emitOpError("result tensor is not of integer type");
765
766 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
767 if (!inputType.hasRank())
768 return success();
769
770 // Ensure axis is within the tensor rank
771 const int64_t axis = getAxisAttr().getInt();
772 if (((axis < 0) || axis >= inputType.getRank()))
773 return emitOpError("specified axis is outside the rank of the tensor");
774
775 if (!resultType.hasRank())
776 return success();
777
778 const ArrayRef<int64_t> inputShape = inputType.getShape();
779 const ArrayRef<int64_t> outputShape = resultType.getShape();
780 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
781 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
782 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
783 return emitOpError("expected output shape '")
784 << expectedOutputShape << "', got '" << outputShape << "'";
785
786 return success();
787}
788
789template <typename T>
790static LogicalResult verifyPoolingOp(T op) {
791 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
792 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
793 return op.emitOpError("expect all kernel values to be >= 1, got ")
794 << kernel;
795
796 const llvm::ArrayRef<int64_t> strides = op.getStride();
797 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
798 return op.emitOpError("expect all stride values to be >= 1, got ")
799 << strides;
800
801 const llvm::ArrayRef<int64_t> padding = op.getPad();
802 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
803 return op.emitOpError("expect all padding values to be >= 0, got ")
804 << padding;
805
806 // Padding must be less than kernel size to avoid a divide-by-zero
807 const int64_t kernelX = kernel[1];
808 const int64_t padLeft = padding[2];
809 const int64_t padRight = padding[3];
810 if (padRight >= kernelX || padLeft >= kernelX)
811 return op.emitOpError("expected left/right padding to be less than the "
812 "width of the kernel, got pad_left=")
813 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
814
815 const int64_t kernelY = kernel[0];
816 const int64_t padTop = padding[0];
817 const int64_t padBottom = padding[1];
818 if (padTop >= kernelY || padBottom >= kernelY)
819 return op.emitOpError("expected top/bottom padding to be less than the "
820 "height of the kernel, got pad_top=")
821 << padTop << ", pad_bottom=" << padBottom
822 << ", kernel_y=" << kernelY;
823
824 const auto inputType =
825 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
826 const auto outputType =
827 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
828 if (!inputType || !outputType)
829 return success();
830
831 const auto verifyOutputSize =
832 [&op](const int64_t inputSize, const int64_t outputSize,
833 const int64_t kernelSize, const int64_t strideSize,
834 const int64_t padBefore, const int64_t padAfter,
835 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
836 const llvm::StringRef padBeforeName,
837 const llvm::StringRef padAfterName) -> LogicalResult {
838 if (ShapedType::isDynamic(inputSize))
839 return success();
840
841 const std::optional<int64_t> calculatedOutSizeMinusOne =
842 idivCheck(lhs: inputSize + padBefore + padAfter - kernelSize, rhs: strideSize);
843 if (!calculatedOutSizeMinusOne.has_value())
844 return op.emitOpError("expected input_")
845 << dimName << " + pad_" << padBeforeName << " + pad_"
846 << padAfterName << " - kernel_" << dimAxis
847 << " to be wholly divisible by stride_" << dimAxis << ", got ("
848 << inputSize << " + " << padBefore << " + " << padAfter << " - "
849 << kernelSize << ") / " << strideSize;
850
851 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
852 if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
853 return op.emitOpError("calculated output ")
854 << dimName << " did not match expected: "
855 << "calculated=" << calculatedOutSize
856 << ", expected=" << outputSize;
857
858 return success();
859 };
860
861 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
862 kernel[0], strides[0], padding[0], padding[1],
863 "height", "y", "top", "bottom")))
864 return failure();
865
866 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
867 kernel[1], strides[1], padding[2], padding[3],
868 "width", "x", "left", "right")))
869 return failure();
870
871 return success();
872}
873
874LogicalResult tosa::AvgPool2dOp::verify() {
875 if (failed(verifyPoolingOp(*this)))
876 return failure();
877
878 const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
879 const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
880 const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
881 const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
882
883 auto accType = getAccType();
884 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
885 return emitOpError("accumulator type for integer tensor is not i32");
886
887 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
888 return emitOpError("accumulator type for f16 tensor is not f16/f32");
889
890 if (inputETy.isBF16() && !accType.isF32())
891 return emitOpError("accumulator type for bf16 tensor is not f32");
892
893 if (inputETy.isF32() && !accType.isF32())
894 return emitOpError("accumulator type for f32 tensor is not f32");
895
896 if (inputETy != inputZpETy)
897 return emitOpError("expect both input and its zero point are the same "
898 "element type, got ")
899 << inputETy << " and " << inputZpETy;
900
901 if (resultETy != outputZpETy)
902 return emitOpError("expect both output and its zero point are the same "
903 "element type, got ")
904 << resultETy << " and " << outputZpETy;
905
906 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
907 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
908 return failure();
909
910 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
911 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
912 return failure();
913
914 return success();
915}
916
917LogicalResult tosa::ClampOp::verify() {
918 mlir::Type inputETy =
919 llvm::cast<ShapedType>(getInput().getType()).getElementType();
920 if (auto quantType =
921 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
922 inputETy = quantType.getStorageType();
923 }
924 mlir::Type outputETy =
925 llvm::cast<ShapedType>(getOutput().getType()).getElementType();
926 if (auto quantType =
927 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
928 outputETy = quantType.getStorageType();
929 }
930 if (inputETy != outputETy)
931 return emitOpError("input/output element types are incompatible.");
932
933 auto maxValAttr = getMaxValAttr();
934 auto minValAttr = getMinValAttr();
935
936 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
937
938 if (inputETy.isInteger(dataTypeBitWidth)) {
939 // if input datatype is integer, check that the min_val/max_val attributes
940 // are integer attributes, and that their type is the same as the input's
941 // datatype
942 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
943 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
944 if (!intMaxValAttr || !intMinValAttr ||
945 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
946 (intMaxValAttr.getType() != inputETy))
947 return emitOpError("min/max attributes types are incompatible with "
948 "input/output element types.");
949
950 const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
951 const APInt minVal = intMinValAttr.getValue();
952 const APInt maxVal = intMaxValAttr.getValue();
953 if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
954 return emitOpError("expected min_val <= max_val, got min_val=")
955 << minValAttr << ", max_val=" << maxValAttr;
956 } else {
957 // otherwise, input datatype is float, check that the min_val/max_val
958 // attributes share the same type and that their type is the same as the
959 // input's datatype
960 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
961 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
962 if (!floatMaxValAttr || !floatMinValAttr ||
963 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
964 (floatMaxValAttr.getType() != inputETy))
965 return emitOpError("min/max attributes types are incompatible with "
966 "input/output element types.");
967
968 const APFloat minVal = floatMinValAttr.getValue();
969 const APFloat maxVal = floatMaxValAttr.getValue();
970 if (minVal.isNaN() || maxVal.isNaN())
971 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
972 << minValAttr << ", max_val=" << maxValAttr;
973
974 if (maxVal < minVal)
975 return emitOpError("expected min_val <= max_val, got min_val=")
976 << minValAttr << ", max_val=" << maxValAttr;
977 }
978
979 return success();
980}
981
982//===----------------------------------------------------------------------===//
983// TOSA Operator Quantization Builders.
984//===----------------------------------------------------------------------===//
985
986/// This builder is called on all convolution operators except TransposeConv,
987/// which has specialized output shape semantics. The builder also defines the
988/// bitwidth of the output given the bit width of the input & weight content.
989static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
990 Type outputType, Value input, Value weight,
991 Value bias, DenseI64ArrayAttr pad,
992 DenseI64ArrayAttr stride,
993 DenseI64ArrayAttr dilation,
994 TypeAttr accType) {
995 auto zps = createZPsAsConst(builder, input, weight);
996 result.addOperands(newOperands: {input, weight, bias, zps.first, zps.second});
997 result.addAttribute("pad", pad);
998 result.addAttribute("stride", stride);
999 result.addAttribute("dilation", dilation);
1000 result.addAttribute("acc_type", accType);
1001 Type finalOutputType = outputType;
1002 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1003 if (quantAttr) {
1004 finalOutputType =
1005 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1006 }
1007 result.addTypes(newTypes: finalOutputType);
1008}
1009
1010/// Handles tosa.transpose_conv2d which has outpad and output shape
1011/// attributes.
1012static void
1013buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1014 Type outputType, Value input, Value weight,
1015 Value bias, DenseI64ArrayAttr outpad,
1016 DenseI64ArrayAttr stride, TypeAttr accType) {
1017 auto zps = createZPsAsConst(builder, input, weight);
1018 result.addOperands(newOperands: {input, weight, bias, zps.first, zps.second});
1019 result.addAttribute("out_pad", outpad);
1020 result.addAttribute("stride", stride);
1021 result.addAttribute("acc_type", accType);
1022 Type finalOutputType = outputType;
1023 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1024 if (quantAttr) {
1025 finalOutputType =
1026 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1027 }
1028 result.addTypes(newTypes: finalOutputType);
1029}
1030
1031/// The tosa.matmul op is also intended to be generated where a fully_connected
1032/// op must be constructed where the weight is not a constant. In this case,
1033/// the fully_connected op must be expressed using matmul.
1034/// TODO: Add link to the leglization document explaining this.
1035static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1036 OperationState &result, Type outputType,
1037 Value a, Value b) {
1038 auto zps = createZPsAsConst(builder, input: a, weight: b);
1039 result.addOperands(newOperands: {a, b, zps.first, zps.second});
1040
1041 Type finalOutputType{outputType};
1042 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1043 auto eType = getStorageElementTypeOrSelf(type: a.getType());
1044 auto inputBits = eType.getIntOrFloatBitWidth();
1045
1046 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1047 assert(outputShapedType && "Output must be a shaped type");
1048
1049 IntegerType accElementType;
1050 if (inputBits == 16)
1051 accElementType = builder.getIntegerType(48);
1052 else
1053 accElementType = builder.getI32Type();
1054
1055 finalOutputType = outputShapedType.clone(accElementType);
1056 }
1057 result.addTypes(newTypes: finalOutputType);
1058}
1059
1060/// Both the tosa.avg_pool2d and unary ops use the same
1061/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1062/// has additional parameters not part of the unary ops.
1063static void
1064buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1065 Type outputType, Value input,
1066 DenseArrayAttr kernel, DenseArrayAttr stride,
1067 DenseArrayAttr pad, TypeAttr accType) {
1068 const Location loc{result.location};
1069 int64_t inputZp{0};
1070 int64_t outputZp{0};
1071
1072 if (auto quantAttr =
1073 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1074 inputZp = quantAttr.getInputZp();
1075 outputZp = quantAttr.getOutputZp();
1076 }
1077 const std::optional<Value> inputZpOp =
1078 createZeroPointTensor(builder, loc, srcElemType: input.getType(), zp: inputZp);
1079 if (!inputZpOp) {
1080 (void)emitError(
1081 loc,
1082 message: "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1083 }
1084 const std::optional<Value> outputZpOp =
1085 createZeroPointTensor(builder, loc, srcElemType: outputType, zp: outputZp);
1086 if (!outputZpOp) {
1087 (void)emitError(loc, message: "Failed to create output zero point tensor for "
1088 "quantized AVG_POOL2D op");
1089 }
1090
1091 if (inputZpOp && outputZpOp) {
1092 result.addOperands(newOperands: {input, inputZpOp.value(), outputZpOp.value()});
1093 } else {
1094 // failed to create one or more zero points above: just add input as
1095 // operands this will trigger error in building the op because of missing
1096 // zero points
1097 result.addOperands(newOperands: {input});
1098 }
1099 result.addAttribute("kernel", kernel);
1100 result.addAttribute("stride", stride);
1101 result.addAttribute("pad", pad);
1102 result.addAttribute("acc_type", accType);
1103 result.types.push_back(Elt: outputType);
1104}
1105
1106/// This builder is called on single-parameter negate operator
1107/// to construct input and output zero points based on their
1108/// types.
1109static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1110 OperationState &result, Type outputType,
1111 Value input) {
1112 const Location loc{result.location};
1113 int64_t input1Zp{0};
1114 int64_t outputZp{0};
1115 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1116 if (quantAttr) {
1117 input1Zp = quantAttr.getInputZp();
1118 outputZp = quantAttr.getOutputZp();
1119 }
1120 const std::optional<Value> input1ZpOp =
1121 createZeroPointTensor(builder, loc, srcElemType: input.getType(), zp: input1Zp);
1122 if (!input1ZpOp) {
1123 (void)emitError(
1124 loc, message: "Failed to create input1 zero point for quantized NEGATE op");
1125 }
1126
1127 const std::optional<Value> outputZpOp =
1128 createZeroPointTensor(builder, loc, srcElemType: input.getType(), zp: outputZp);
1129 if (!outputZpOp) {
1130 (void)emitError(
1131 loc, message: "Failed to create output zero point for quantized NEGATE op");
1132 }
1133
1134 if (input1ZpOp && outputZpOp) {
1135 result.addOperands(newOperands: {input, input1ZpOp.value(), outputZpOp.value()});
1136 } else {
1137 // failed to create one or more zero points above: just add input as
1138 // operands. This will trigger error in building the op because of
1139 // missing zero points
1140 result.addOperands(newOperands: {input});
1141 }
1142
1143 result.types.push_back(Elt: outputType);
1144}
1145
1146/// This builder is called on TOSA pad operator that needs to create its own
1147/// OptionalAttr quantization_attr parameter to scale the padding values
1148/// correctly. No pad_const is interpreted as zero-padding.
1149static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1150 Type outputType, Value input,
1151 Value paddings) {
1152 const Location loc{result.location};
1153 int32_t zp{0};
1154 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1155 if (quantAttr) {
1156 zp = static_cast<int32_t>(quantAttr.getInputZp());
1157 }
1158 const auto padConstOp{createPadConstTensor(builder, loc, src: input, val: zp)};
1159 result.addOperands(newOperands: {input, paddings, padConstOp});
1160 result.types.push_back(Elt: outputType);
1161}
1162
1163static void buildVariableOp(OpBuilder &builder, OperationState &result,
1164 StringRef name, Type variableType,
1165 Attribute initialValue) {
1166 const Location loc{result.location};
1167 auto nameAttr = builder.getStringAttr(name);
1168
1169 auto shapedType = dyn_cast<ShapedType>(variableType);
1170 if (!shapedType) {
1171 (void)emitError(loc, message: "variable type must be a shaped type");
1172 return;
1173 }
1174 if (!shapedType.hasRank()) {
1175 (void)emitError(loc, message: "variable type must be a ranked type");
1176 return;
1177 }
1178
1179 auto elementType = shapedType.getElementType();
1180 auto elementTypeAttr = TypeAttr::get(elementType);
1181 ArrayRef<int64_t> shape = shapedType.getShape();
1182 auto varShapeAttr = builder.getIndexTensorAttr(values: convertFromMlirShape(shape));
1183
1184 result.addAttribute("name", nameAttr);
1185 result.addAttribute("var_shape", varShapeAttr);
1186 result.addAttribute("type", elementTypeAttr);
1187 result.addAttribute(name: "initial_value", attr: initialValue);
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// TOSA Operator Return Type Inference.
1192//===----------------------------------------------------------------------===//
1193
1194static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1195 SmallVector<int64_t> &outShape) {
1196 int64_t outRank = 0;
1197 for (int i = 0, e = operands.size(); i != e; ++i) {
1198 auto shape = operands.getShape(index: i);
1199 if (!shape.hasRank()) {
1200 // TODO(jennik): Update function to have better case handling for
1201 // invalid operands and for ranked tensors.
1202 return failure();
1203 }
1204 outRank = std::max<int64_t>(a: outRank, b: shape.getRank());
1205 }
1206
1207 outShape.resize(N: outRank, NV: 1);
1208
1209 for (int i = 0, e = operands.size(); i != e; ++i) {
1210 auto shape = operands.getShape(index: i);
1211 auto rankDiff = outShape.size() - shape.getRank();
1212
1213 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1214 auto dim1 = outShape[i + rankDiff];
1215 auto dim2 = shape.getDimSize(index: i);
1216 auto resolvedDim = dim1;
1217
1218 if (dim1 == 1) {
1219 resolvedDim = dim2;
1220 } else if (dim2 == 1) {
1221 resolvedDim = dim1;
1222 } else if (dim1 != dim2) {
1223 return failure();
1224 }
1225 outShape[i + rankDiff] = resolvedDim;
1226 }
1227 }
1228
1229 return success();
1230}
1231
1232LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1233 MLIRContext *context, ::std::optional<Location> location,
1234 ArgMaxOp::Adaptor adaptor,
1235 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1236 ShapeAdaptor inputShape(adaptor.getInput().getType());
1237 IntegerAttr axis = adaptor.getProperties().axis;
1238 int32_t axisVal = axis.getValue().getSExtValue();
1239
1240 if (!inputShape.hasRank()) {
1241 inferredReturnShapes.push_back(ShapedTypeComponents());
1242 return success();
1243 }
1244
1245 SmallVector<int64_t> outShape;
1246 outShape.reserve(inputShape.getRank() - 1);
1247 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1248 if (i == axisVal)
1249 continue;
1250 outShape.push_back(inputShape.getDimSize(i));
1251 }
1252
1253 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1254 return success();
1255}
1256
1257LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1258 MLIRContext *context, ::std::optional<Location> location,
1259 RFFT2dOp::Adaptor adaptor,
1260 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1261 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1262
1263 if (!inputShape.hasRank())
1264 return failure();
1265
1266 llvm::SmallVector<int64_t> outputShape;
1267 outputShape.resize(3, ShapedType::kDynamic);
1268 outputShape[0] = inputShape.getDimSize(0);
1269 outputShape[1] = inputShape.getDimSize(1);
1270 int64_t inWidth = inputShape.getDimSize(2);
1271
1272 // Note that we can support this calculation symbolically
1273 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1274 if (inWidth != ShapedType::kDynamic)
1275 outputShape[2] = inWidth / 2 + 1;
1276
1277 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1278 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1279
1280 return success();
1281}
1282
1283static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1284 const llvm::StringRef dimName) {
1285 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1286 if (!isPowerOfTwo)
1287 return op->emitOpError(message: "expected ")
1288 << dimName << " to be a power of two, got " << dimSize;
1289
1290 return success();
1291}
1292
1293LogicalResult tosa::RFFT2dOp::verify() {
1294 const auto outputTypes = getResultTypes();
1295 if (failed(verifyCompatibleShapes(outputTypes)))
1296 return emitOpError("expected output shapes to match, got ") << outputTypes;
1297
1298 const auto inputType =
1299 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1300 if (!inputType)
1301 return success();
1302
1303 const int64_t height = inputType.getDimSize(1);
1304 if (!ShapedType::isDynamic(height) &&
1305 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1306 return failure();
1307
1308 const int64_t width = inputType.getDimSize(2);
1309 if (!ShapedType::isDynamic(width) &&
1310 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1311 return failure();
1312
1313 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1314 if (!outputType)
1315 return success();
1316
1317 // Batch and height input/output dimensions should match
1318 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1319 outputType.getShape().drop_back())))
1320 return emitOpError("expected batch and height dimensions of input/output "
1321 "to match, got input=")
1322 << inputType << " output=" << outputType;
1323
1324 // Output width dimension expected to be input_width / 2 + 1
1325 const int64_t outputWidth = outputType.getDimSize(2);
1326 if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
1327 (outputWidth != (width / 2) + 1))
1328 return emitOpError(
1329 "expected output width to be equal to input_width / 2 + 1, got ")
1330 << outputWidth;
1331
1332 return success();
1333}
1334
1335LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1336 MLIRContext *context, ::std::optional<Location> location,
1337 FFT2dOp::Adaptor adaptor,
1338 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1339 inferredReturnShapes.push_back(
1340 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1341 inferredReturnShapes.push_back(
1342 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1343 return success();
1344}
1345
1346LogicalResult tosa::FFT2dOp::verify() {
1347 const auto inputRealType =
1348 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1349 const auto inputImagType =
1350 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1351 if (!inputRealType || !inputImagType)
1352 return success();
1353
1354 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1355 return ShapedType::isDynamic(a) ? a : b;
1356 };
1357
1358 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1359 inputImagType.getDimSize(1));
1360 if (!ShapedType::isDynamic(height) &&
1361 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1362 return failure();
1363
1364 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1365 inputImagType.getDimSize(2));
1366 if (!ShapedType::isDynamic(width) &&
1367 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1368 return failure();
1369
1370 return success();
1371}
1372
1373LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1374 MLIRContext *context, ::std::optional<Location> location,
1375 ConcatOp::Adaptor adaptor,
1376 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1377 // Infer all dimension sizes by reducing based on inputs.
1378 const Properties &prop = adaptor.getProperties();
1379 int32_t axis = prop.axis.getValue().getSExtValue();
1380 llvm::SmallVector<int64_t> outputShape;
1381 bool hasRankedInput = false;
1382 for (auto operand : adaptor.getOperands()) {
1383 ShapeAdaptor operandShape(operand.getType());
1384 if (!operandShape.hasRank())
1385 continue;
1386
1387 // Copy the Operand's rank.
1388 if (!hasRankedInput)
1389 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1390
1391 // Copy shapes until the dim is non-dynamic.
1392 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1393 if (i == axis || operandShape.isDynamicDim(i))
1394 continue;
1395 if (outputShape[i] == ShapedType::kDynamic)
1396 outputShape[i] = operandShape.getDimSize(i);
1397 if (outputShape[i] != operandShape.getDimSize(i))
1398 return emitOptionalError(location,
1399 "Cannot concat tensors with different sizes"
1400 " on the non-axis dimension ",
1401 i);
1402 }
1403
1404 hasRankedInput = true;
1405 }
1406
1407 if (adaptor.getInput1().empty())
1408 return failure();
1409
1410 Type inputType =
1411 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1412 if (!hasRankedInput) {
1413 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1414 return success();
1415 }
1416
1417 // Determine the dimension size along the concatenation axis.
1418 int64_t concatDimSize = 0;
1419 for (auto operand : adaptor.getOperands()) {
1420 ShapeAdaptor operandShape(operand.getType());
1421
1422 // We need to know the length of the concatenation axis of all inputs to
1423 // determine the dimension size of the output shape.
1424 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1425 concatDimSize = ShapedType::kDynamic;
1426 break;
1427 }
1428
1429 concatDimSize += operandShape.getDimSize(axis);
1430 }
1431
1432 outputShape[axis] = concatDimSize;
1433
1434 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1435 return success();
1436}
1437
1438LogicalResult tosa::ConcatOp::verify() {
1439 // check that each input has same element type as output
1440 auto outType = getOutput().getType();
1441 const Operation::operand_range inputList = getInput1();
1442
1443 // Check there is at least one input
1444 if (inputList.empty())
1445 return emitOpError("expect at least one input");
1446
1447 if (!llvm::all_of(inputList, [&](auto input) {
1448 return succeeded(verifySameElementTypes(
1449 *this, /* inType = */ input.getType(), outType));
1450 })) {
1451 return failure();
1452 }
1453
1454 const int32_t axis = getAxis();
1455 ShapeAdaptor firstRankedInputShape = nullptr;
1456 for (const auto &input : inputList) {
1457 const Type inputType = input.getType();
1458 ShapeAdaptor currShape(inputType);
1459 if (currShape.hasRank()) {
1460 firstRankedInputShape = currShape;
1461 // Check axis is in expected range
1462 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1463 return emitOpError("expect axis to be within range 0 < axis < "
1464 "rank(input1[firstRankedTensorIdx]), got ")
1465 << axis;
1466 break;
1467 }
1468 }
1469
1470 const auto allOperandsHasRank = [](const Value input) {
1471 return ShapeAdaptor(input.getType()).hasRank();
1472 };
1473 if (llvm::all_of(inputList, allOperandsHasRank)) {
1474 const int64_t firstInputRank = firstRankedInputShape.getRank();
1475
1476 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1477 const ShapeAdaptor inputShape(input.getType());
1478 const int64_t inputRank = inputShape.getRank();
1479 const size_t operandNum = index + 1;
1480
1481 // Check that each operand has the same rank
1482 if (inputRank != firstInputRank)
1483 return emitOpError(
1484 "expect all operands to have the same rank, but got ")
1485 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1486 << operandNum;
1487
1488 // Check non-axis dims match
1489 for (int i = 0; i < inputRank; i++) {
1490 const int64_t inputDim = inputShape.getDimSize(i);
1491 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1492 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1493 inputShape.isDynamicDim(i))
1494 continue;
1495 if (inputDim != firstInputDim)
1496 return emitOpError("expect all operand shapes to have the same sizes "
1497 "on non-axis dimensions, but got ")
1498 << inputDim << " vs " << firstInputDim << " at index " << i
1499 << " on operands 0 and " << operandNum;
1500 }
1501 }
1502
1503 // ERROR_IF(axis_sum != shape[axis]);
1504 int64_t axisSum = 0;
1505 for (const auto &input : inputList) {
1506 const ShapeAdaptor inputShape(input.getType());
1507 if (inputShape.isDynamicDim(axis)) {
1508 // make axisSum negative to indicate invalid value
1509 axisSum = -1;
1510 break;
1511 }
1512 axisSum += inputShape.getDimSize(axis);
1513 }
1514 const ShapeAdaptor outputShape(outType);
1515 if (axisSum >= 0 && outputShape.hasRank() &&
1516 !outputShape.isDynamicDim(axis) &&
1517 axisSum != outputShape.getDimSize(axis))
1518 return emitOpError("requires sum of axis dimensions of input1 "
1519 "equal to output axis dimension, got ")
1520 << axisSum << " and " << outputShape.getDimSize(axis);
1521 }
1522
1523 return success();
1524}
1525
1526LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1527 MLIRContext *context, ::std::optional<Location> location,
1528 ValueShapeRange operands, DictionaryAttr attributes,
1529 OpaqueProperties properties, RegionRange regions,
1530 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1531 auto elementType = IntegerType::get(context, /*width=*/1);
1532
1533 llvm::SmallVector<int64_t> outShape;
1534 if (resolveBroadcastShape(operands, outShape).failed()) {
1535 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1536 return success();
1537 }
1538
1539 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1540 return success();
1541}
1542
1543bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1544 if (l.size() != r.size() || l.size() != 1)
1545 return false;
1546 return succeeded(verifyCompatibleShape(l[0], r[0]));
1547}
1548
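// Worked example (illustrative shapes only): with A of shape [2, 3, 4] and B
// of shape [2, 4, 5], the inference below takes the batch and row dimensions
// from A and the column dimension from B, producing an output shape of
// [2, 3, 5].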
1549LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1550 MLIRContext *context, ::std::optional<Location> location,
1551 MatMulOp::Adaptor adaptor,
1552 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1553 ShapeAdaptor lhsShape(adaptor.getA().getType());
1554 ShapeAdaptor rhsShape(adaptor.getB().getType());
1555
  // Initialize all output dimensions to dynamic.
1557 SmallVector<int64_t> outShape;
1558 outShape.resize(3, ShapedType::kDynamic);
1559
1560 if (lhsShape.hasRank()) {
1561 outShape[0] = lhsShape.getDimSize(0);
1562 outShape[1] = lhsShape.getDimSize(1);
1563 }
1564
1565 if (rhsShape.hasRank()) {
1566 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1567 : outShape[0];
1568 outShape[2] = rhsShape.getDimSize(2);
1569 }
1570
1571 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1572 return success();
1573}
1574
1575LogicalResult MatMulOp::verify() {
1576 auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1577 auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1578
1579 // Must be shaped tensor types
1580 if (!aType)
1581 return emitOpError("expect a shaped tensor for input a, got ")
1582 << getA().getType();
1583
1584 if (!bType)
1585 return emitOpError("expect a shaped tensor for input b, got ")
1586 << getB().getType();
1587
1588 auto aElementType = aType.getElementType();
1589 auto bElementType = bType.getElementType();
1590
1591 auto aQuantizedEType =
1592 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1593 auto bQuantizedEType =
1594 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1595
1596 if (aQuantizedEType || bQuantizedEType) {
1597 if (!aQuantizedEType || !bQuantizedEType) {
1598 return emitOpError("expect operands to be both quantized or both not "
1599 "quantized, got ")
1600 << aElementType << " and " << bElementType;
1601 }
1602 // both a and b have quantized element types
1603 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1604 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1605 if (aQuantWidth != bQuantWidth) {
1606 return emitOpError("expect quantized operands to have same widths, got ")
1607 << aQuantWidth << " and " << bQuantWidth;
1608 }
1609 } else {
1610 // non-quantized element types
1611 if (aElementType != bElementType) {
1612 return emitOpError("expect same element type for inputs a and b, got ")
1613 << aElementType << " and " << bElementType;
1614 }
1615 }
1616
1617 // check a_zp and b_zp
1618 auto aEType = getStorageElementTypeOrSelf(aType);
1619 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1620 if (aEType != aZpEType) {
1621 return emitOpError("expect input a and a_zp have the same "
1622 "element type, got ")
1623 << aEType << " and " << aZpEType;
1624 }
1625
1626 auto bEType = getStorageElementTypeOrSelf(bType);
1627 auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1628 if (bEType != bZpEType) {
1629 return emitOpError("expect input b and b_zp have the same "
1630 "element type, got ")
1631 << bEType << " and " << bZpEType;
1632 }
1633
1634 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1635 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1636 return failure();
1637
1638 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1639 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1640 return failure();
1641
1642 return success();
1643}
1644
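// Worked example (illustrative values only): for an input of shape [4, 5]
// with constant padding values [1, 1, 2, 2] (front/back pairs per dimension),
// the inference below computes output dims 4 + 1 + 1 = 6 and 5 + 2 + 2 = 9,
// i.e. an output shape of [6, 9]. Negative padding entries leave the
// corresponding output dimension dynamic.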
1645LogicalResult tosa::PadOp::inferReturnTypeComponents(
1646 MLIRContext *context, ::std::optional<Location> location,
1647 PadOp::Adaptor adaptor,
1648 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1649 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1650 auto paddingRank =
1651 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1652 SmallVector<int64_t> outputShape;
1653
1654 // If the input rank is unknown, we can infer the output rank using the
1655 // padding shape's rank divided by 2.
1656 if (!inputShape.hasRank()) {
1657 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1658 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1659 return success();
1660 }
1661
1662 SmallVector<int64_t> paddingValues;
  // If the padding value is not a constant, all output dimensions must be
  // dynamic.
1664 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1665 paddingValues)) {
1666 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1667 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1668 return success();
1669 }
1670
1671 outputShape.reserve(inputShape.getRank());
1672 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1673 if (inputShape.isDynamicDim(i)) {
1674 outputShape.push_back(ShapedType::kDynamic);
1675 continue;
1676 }
1677 auto padFront = paddingValues[i * 2];
1678 auto padBack = paddingValues[i * 2 + 1];
1679 if (padFront < 0 || padBack < 0) {
1680 // if either padding for dim i is -1, output dim is unknown
1681 outputShape.push_back(ShapedType::kDynamic);
1682 continue;
1683 }
1684
1685 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1686 }
1687
1688 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1689 return success();
1690}
1691
1692LogicalResult tosa::PadOp::verify() {
1693 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1694 /* outType = */ getOutput().getType())
1695 .failed()) {
1696 return failure();
1697 }
1698
1699 if (auto padConst = getPadConst()) {
1700 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1701 /* outType = */ getOutput().getType())
1702 .failed()) {
1703 return failure();
1704 }
1705 }
1706
1707 RankedTensorType inputType =
1708 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1709 RankedTensorType outputType =
1710 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1711 if (!inputType || !outputType)
1712 return success();
1713
1714 auto inputRank = inputType.getRank();
1715 auto outputRank = outputType.getRank();
1716 if (inputRank != outputRank)
1717 return emitOpError() << "expect same input and output tensor rank, but got "
1718 << "inputRank: " << inputRank
1719 << ", outputRank: " << outputRank;
1720
1721 DenseIntElementsAttr paddingAttr;
1722 if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1723 return failure();
1724 }
1725
1726 auto paddingValues = paddingAttr.getValues<APInt>();
1727 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1728 return emitOpError() << "padding tensor must have " << inputRank
1729 << " * 2 = " << inputRank * 2 << " elements, but got "
1730 << paddingValues.size();
1731
1732 auto inputShape = inputType.getShape();
1733 auto outputShape = outputType.getShape();
1734
1735 for (int64_t i = 0; i < inputRank; ++i) {
1736 int64_t padStart = paddingValues[i * 2].getSExtValue();
1737 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1738
1739 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1740 return emitOpError()
1741 << "invalid padding values at dimension " << i
1742 << ": values must be non-negative or -1 for dynamic padding, got ["
1743 << padStart << ", " << padEnd << "]";
1744 }
1745
1746 // Skip shape verification for dynamic input/output
1747 if (inputShape[i] == ShapedType::kDynamic ||
1748 outputShape[i] == ShapedType::kDynamic)
1749 continue;
1750
1751 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1752 return emitOpError() << "mismatch in output shape at dimension " << i
1753 << ": expected " << inputShape[i] << " + "
1754 << padStart << " + " << padEnd << " = "
1755 << (inputShape[i] + padStart + padEnd)
1756 << ", but got " << outputShape[i];
1757 }
1758 }
1759
1760 return success();
1761}
1762
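// Worked example (illustrative values only): slicing an input of shape
// [10, 8] with start [2, 0] and size [3, -1] yields dimension 0 of size 3
// and, since size -1 means "all remaining elements", dimension 1 of size
// 8 - 0 = 8, giving an inferred output shape of [3, 8].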
1763LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1764 MLIRContext *context, ::std::optional<Location> location,
1765 SliceOp::Adaptor adaptor,
1766 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1767
1768 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1769 SmallVector<int64_t> start;
1770 SmallVector<int64_t> size;
1771
1772 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1773 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1774 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1775 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1776 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1777 return success();
1778 }
1779
1780 // if size[i] is -1, all remaining elements in dimension i are included
1781 // in the slice, similar to TF.
1782 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1783 // initialize outputShape to all unknown
1784 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1785 if (inputShape.hasRank()) {
1786 for (size_t i = 0; i < size.size(); i++) {
1787 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1788 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1789 start[i] < inputShape.getDimSize(i))) {
1790 // size[i] is not 0 and not < -1, and start[i] is in valid range
1791 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1792 // input shape has unknown dim[i] - only valid if size[i] > 0
1793 if (size[i] > 0) {
1794 outputShape[i] = size[i];
1795 }
1796 } else {
1797 // input shape has known dim[i]
1798 if (size[i] == -1) {
1799 outputShape[i] = inputShape.getDimSize(i) - start[i];
1800 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1801 // start[i] + size[i] is within bound of input shape's dim[i]
1802 outputShape[i] = size[i];
1803 }
1804 }
1805 }
1806 }
1807 } else {
1808 outputShape = convertToMlirShape(size);
1809 }
1810 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1811 return success();
1812}
1813
1814LogicalResult tosa::SliceOp::verify() {
1815 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1816 /* outType = */ getOutput().getType())
1817 .failed())
1818 return failure();
1819
1820 const ShapeAdaptor inputShape(getInput1().getType());
1821 if (inputShape.hasRank()) {
1822 const auto inputRank = inputShape.getRank();
1823 const ShapeAdaptor outputShape(getOutput().getType());
1824 if (outputShape.hasRank() && inputRank != outputShape.getRank())
1825 return emitOpError(
1826 "expect input1 and output to have the same ranks, got ")
1827 << inputRank << " and " << outputShape.getRank();
1828
1829 const auto startShapeRank =
1830 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1831 if (inputRank != startShapeRank)
1832 return emitOpError("length of start is not equal to rank of input shape");
1833
1834 const auto sizeShapeRank =
1835 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1836 if (inputRank != sizeShapeRank)
1837 return emitOpError("length of size is not equal to rank of input shape");
1838 }
1839
1840 return success();
1841}
1842
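// The inference below broadcasts input1 and input2 and ignores the shift
// operand. For example (illustrative shapes only), operands of shape
// [2, 1, 5] and [2, 4, 5] infer an output shape of [2, 4, 5].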
1843LogicalResult tosa::MulOp::inferReturnTypeComponents(
1844 MLIRContext *context, ::std::optional<Location> location,
1845 ValueShapeRange operands, DictionaryAttr attributes,
1846 OpaqueProperties properties, RegionRange regions,
1847 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // The mul op's output shape depends only on input1 and input2, not on
  // shift.
1849 ValueShapeRange twoInputs = operands.drop_back();
1850 llvm::SmallVector<int64_t> outShape;
1851 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1852 inferredReturnShapes.push_back(ShapedTypeComponents());
1853 } else {
1854 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1855 }
1856 return success();
1857}
1858
1859LogicalResult tosa::MulOp::verify() {
1860 const Value output = getOutput();
1861 auto resElemType = getElementTypeOrSelf(output);
1862
1863 // Verify if the element type among operands and result match tosa
1864 // specification.
1865 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1866 IntegerType lhsIntType =
1867 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1868 IntegerType rhsIntType =
1869 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1870 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
1871 return emitOpError("requires the same element type for all operands");
1872
1873 // Though the spec requires the element type of result to be i32, a more
1874 // relaxed way is provided at dialect level for easier cooperating with
1875 // other dialects.
1876 if (lhsIntType.getWidth() > resIntType.getWidth())
1877 return emitOpError("invalid data type size for operands or result");
1878
  } else {
    // For the other supported types, the spec requires the same element type
    // for all operands (excluding the `shift` operand) and results.
1882 for (int i = 0; i < 2; ++i) {
1883 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1884 return emitOpError(
1885 "requires the same element type for all operands and results");
1886 }
1887
1888 // verify shift has value 0 for non-integer types
1889 ElementsAttr shift_elem;
1890 if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1891 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1892 if (shift != 0) {
1893 return emitOpError() << "require shift to be 0 for float type";
1894 }
1895 }
1896 }
1897
  // Verify that the main operands and the result have the same rank, when the
  // ranks are known. Extra operands such as the shift operand of mul are
  // excluded, which is the only difference from the built-in
  // `SameOperandsAndResultRank` trait.
1901 TypeRange operandTypes = getOperandTypes();
1902 ShapedType aType = cast<ShapedType>(operandTypes[0]);
1903 ShapedType bType = cast<ShapedType>(operandTypes[1]);
1904
1905 const bool aHasRank = aType.hasRank();
1906 const bool bHasRank = bType.hasRank();
1907 if (aHasRank && bHasRank) {
1908 const int64_t aRank = aType.getRank();
1909 const int64_t bRank = bType.getRank();
1910 if (aRank != bRank)
1911 return emitOpError("a and b operands don't have matching ranks, got ")
1912 << aRank << " and " << bRank;
1913
1914 // check for broadcast compatible shapes
1915 SmallVector<int64_t> resultShape;
1916 if (!mlir::OpTrait::util::getBroadcastedShape(
1917 aType.getShape(), bType.getShape(), resultShape))
1918 return emitOpError("a and b operands don't have broadcast-compatible "
1919 "shapes, got ")
1920 << aType << " and " << bType;
1921 }
1922
1923 ShapedType resultType = cast<ShapedType>(output.getType());
1924 if (!resultType.hasRank())
1925 return success();
1926
1927 const int64_t resultRank = resultType.getRank();
1928 if (aHasRank && resultRank != aType.getRank())
1929 return emitOpError("result type has different rank than a, got ")
1930 << resultRank << " vs " << aType.getRank();
1931 if (bHasRank && resultRank != bType.getRank())
1932 return emitOpError("result type has different rank than b, got ")
1933 << resultRank << " vs " << bType.getRank();
1934
1935 return success();
1936}
1937
1938LogicalResult tosa::TableOp::inferReturnTypeComponents(
1939 MLIRContext *context, ::std::optional<Location> location,
1940 TableOp::Adaptor adaptor,
1941 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1942 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1943
1944 if (!inputShape.hasRank()) {
1945 inferredReturnShapes.push_back(ShapedTypeComponents());
1946 return success();
1947 }
1948
1949 inferredReturnShapes.resize(1);
1950 inputShape.getDims(inferredReturnShapes[0]);
1951 return success();
1952}
1953
1954LogicalResult tosa::TableOp::verify() {
1955 TensorType inputType = getInput1().getType();
1956 TensorType outputType = getOutput().getType();
1957
1958 if (inputType.hasRank() && outputType.hasRank() &&
1959 inputType.getRank() != outputType.getRank())
1960 return emitOpError()
1961 << "expected input tensor rank to equal result tensor rank";
1962
1963 auto inputDims = inputType.getShape();
1964 auto outputDims = outputType.getShape();
1965 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1966 int64_t dim = it.index();
1967 auto [inputDim, outputDim] = it.value();
1968 if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1969 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1970 << " doesn't match dim(input, " << dim
1971 << ") = " << inputDim;
1972 }
1973 }
1974 return success();
1975}
1976
1977LogicalResult
1978tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1979 // Multiples must be constants.
1980 DenseIntElementsAttr multiplesAttr;
1981 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1982 return failure();
1983 multiples = llvm::to_vector(
1984 llvm::map_range(multiplesAttr.getValues<APInt>(),
1985 [](const APInt &val) { return val.getSExtValue(); }));
1986 return success();
1987}
1988
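// Worked example (illustrative values only): tiling an input of shape [2, 3]
// with constant multiples [3, 1] multiplies each static dimension, inferring
// an output shape of [6, 3]. A dynamic multiple or input dimension keeps the
// corresponding output dimension dynamic.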
1989LogicalResult tosa::TileOp::inferReturnTypeComponents(
1990 MLIRContext *context, ::std::optional<Location> location,
1991 TileOp::Adaptor adaptor,
1992 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1993 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1994 SmallVector<int64_t> multiples;
1995 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
1996 multiples)) {
1997 auto rank =
1998 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
1999 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2000 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2001 return success();
2002 } else {
2003 multiples = convertToMlirShape(multiples);
2004 }
2005
2006 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2007 SmallVector<int64_t> outputShape;
2008 if (!inputShape.hasRank()) {
2009 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2010 inferredReturnShapes.push_back(
2011 ShapedTypeComponents(outputShape, inputType));
2012 return success();
2013 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2014 return failure();
2015
  // Any non-dynamic dimension can be multiplied to a known size.
2017 outputShape.reserve(multiples.size());
2018 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2019 if (multiples[i] == ShapedType::kDynamic) {
2020 outputShape.push_back(ShapedType::kDynamic);
2021 } else {
2022 int64_t dim = inputShape.getDimSize(i);
2023 if (dim != ShapedType::kDynamic)
2024 dim *= multiples[i];
2025 outputShape.push_back(dim);
2026 }
2027 }
2028
2029 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2030 return success();
2031}
2032
2033LogicalResult tosa::TileOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2035 /* outType = */ getOutput().getType())
2036 .failed()) {
2037 return failure();
2038 }
2039 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2040 ShapedType outputType = llvm::cast<ShapedType>(getType());
2041
2042 shapeType multiplesType =
2043 llvm::cast<tosa::shapeType>(getMultiples().getType());
2044
2045 auto multiplesRank = multiplesType.getRank();
2046
2047 if (inputType.hasRank()) {
2048 if (inputType.getRank() != multiplesRank)
2049 return emitOpError("expect 'multiples' to have rank ")
2050 << inputType.getRank() << " but got " << multiplesRank << ".";
2051 if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2052 return emitOpError("expect same input and output tensor rank.");
2053 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2054 return emitOpError("expect 'multiples' array to have length ")
2055 << outputType.getRank() << " but got " << multiplesRank << ".";
2056
2057 SmallVector<int64_t> multiples;
2058 if (getConstantMultiples(multiples).succeeded() &&
2059 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2060 return emitOpError(
2061 "expect element of 'multiples' to be positive integer or -1.");
2062
2063 return success();
2064}
2065
2066bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2067 if (l.size() != r.size() || l.size() != 1)
2068 return false;
2069 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2070}
2071
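// Worked example (illustrative values only): reshaping a static input of
// shape [2, 3, 4] (24 elements) with a constant new shape of [4, -1] infers
// the dynamic dimension as 24 / 4 = 6, so the result shape is [4, 6].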
2072LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2073 MLIRContext *context, ::std::optional<Location> location,
2074 ReshapeOp::Adaptor adaptor,
2075 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2076 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2077 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2078 llvm::SmallVector<int64_t> newShapeValue;
2079 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2080 newShapeValue)) {
2081 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2082 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2083 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2084 return success();
2085 } else {
2086 newShapeValue = convertToMlirShape(newShapeValue);
2087 }
2088
2089 // We cannot infer from the total number of elements so we must take the
2090 // shape attribute as exact.
2091 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2092 inferredReturnShapes.push_back(
2093 ShapedTypeComponents(newShapeValue, inputType));
2094 return success();
2095 }
2096
2097 // Determine the number of elements covered by the slice of all static
2098 // dimensions. This allows us to infer the length of the remaining dynamic
2099 // dimension.
2100 int64_t numElements = inputShape.getNumElements();
2101 int64_t staticMul = 1;
2102 for (auto val : newShapeValue) {
2103 if (!ShapedType::isDynamic(val)) {
2104 staticMul *= val;
2105 }
2106 }
2107
2108 // Determine the length of the dynamic dimension.
2109 for (auto &val : newShapeValue) {
2110 if (ShapedType::isDynamic(val))
2111 val = numElements / staticMul;
2112 }
2113
2114 inferredReturnShapes.push_back(
2115 ShapedTypeComponents(newShapeValue, inputType));
2116 return success();
2117}
2118
2119llvm::LogicalResult tosa::ReshapeOp::verify() {
2120 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2121 /* outType = */ getOutput().getType())
2122 .failed()) {
2123 return failure();
2124 }
2125 TensorType inputType = getInput1().getType();
2126
2127 SmallVector<int64_t> shapeValues;
2128 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2129 // skip following checks if shape is not constant
2130 return mlir::success();
2131 }
2132
2133 int missingDims = llvm::count(shapeValues, -1);
2134 if (missingDims > 1)
2135 return emitOpError() << "expected at most one target dimension to be -1";
2136
2137 const auto outputType = dyn_cast<RankedTensorType>(getType());
2138 if (!outputType)
2139 return success();
2140
2141 if ((int64_t)shapeValues.size() != outputType.getRank())
2142 return emitOpError() << "new shape does not match result rank";
2143
2144 for (auto [newShapeDim, outputShapeDim] :
2145 zip(shapeValues, outputType.getShape())) {
2146 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2147 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2148 return emitOpError() << "new shape is inconsistent with result shape";
2149
2150 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2151 return emitOpError() << "new shape has invalid tensor dimension size "
2152 << newShapeDim;
2153 }
2154
2155 if (inputType.hasStaticShape()) {
2156 int64_t inputElementsNum = inputType.getNumElements();
2157 if (outputType.hasStaticShape()) {
2158 int64_t outputElementsNum = outputType.getNumElements();
2159 if (inputElementsNum != outputElementsNum) {
2160 return emitOpError() << "cannot reshape " << inputElementsNum
2161 << " elements into " << outputElementsNum;
2162 }
2163 }
2164
2165 int64_t newShapeElementsNum = std::accumulate(
2166 shapeValues.begin(), shapeValues.end(), 1LL,
2167 [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2168 bool isStaticNewShape =
2169 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2170 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2171 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2172 return emitOpError() << "cannot reshape " << inputElementsNum
2173 << " elements into " << newShapeElementsNum;
2174 }
2175 }
2176
2177 return mlir::success();
2178}
2179
// Return failure if val is not a constant.
// Set zp to -1 if val is a non-zero float or is neither an integer nor a
// float; otherwise set zp to val's constant value.
2183static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2184 ElementsAttr zpAttr;
2185 if (!matchPattern(val, m_Constant(&zpAttr))) {
2186 return failure();
2187 }
2188
2189 Type zpElemType = zpAttr.getElementType();
2190
  if (llvm::isa<FloatType>(zpElemType)) {
2192 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2193 return 0;
2194 }
2195 // return non-zero value to trigger error check
2196 return -1;
2197 }
2198
  if (llvm::isa<IntegerType>(zpElemType)) {
2200 if (signExtend)
2201 return zpAttr.getValues<APInt>()[0].getSExtValue();
2202 else
2203 return zpAttr.getValues<APInt>()[0].getZExtValue();
2204 }
2205
2206 // return non-zero value to trigger error check
2207 return -1;
2208}
2209
2210template <typename T>
2211static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2212 const std::string &operand) {
2213 Type zpElemType = getElementTypeOrSelf(val);
2214
  if (!zpElemType.isInteger(8) && zp != 0) {
    // convert operand to lower case for error message
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2219 return op.emitOpError()
2220 << lower << " zero point must be zero for non-int8 integer types";
2221 }
2222
2223 return success();
2224}
2225
2226static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2227 const int64_t &zp,
2228 const std::string &operand) {
2229 bool isInputZp = (operand == "Input");
2230
2231 bool tensorUnsigned =
2232 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2233 StringRef tensorName = isInputZp ? "input" : "output";
2234
  Type zpElemType = getElementTypeOrSelf(zpVal);
2236
2237 if (zp != 0) {
    if (!zpElemType.isInteger(8) &&
        !(zpElemType.isInteger(16) && tensorUnsigned)) {
2240 return op.emitOpError()
2241 << "expect " << tensorName << "_zp of 0, got " << zp;
2242 }
    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2244 return op.emitOpError() << "expect " << tensorName
2245 << "_zp of 0 or 32768 for unsigned int16 "
2246 << tensorName << ", got " << zp;
2247 }
2248 }
2249
2250 return success();
2251}
2252
2253#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2254 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2255 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2256 } \
2257 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2258 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2259 }
2260
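// For reference, ZERO_POINT_HELPER(MatMulOp, A, true) below expands to
// roughly the following pair of member functions (a sketch of the macro
// expansion, not additional definitions):
//
//   FailureOr<int64_t> tosa::MatMulOp::getAZeroPoint() {
//     return getZeroPoint(getAZp(), true);
//   }
//   LogicalResult tosa::MatMulOp::verifyAZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getAZp(), zp, "A");
//   }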
2261ZERO_POINT_HELPER(Conv2DOp, Input, true)
2262ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2263ZERO_POINT_HELPER(Conv3DOp, Input, true)
2264ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2265ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2266ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2267ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2268ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2269ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2270ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2271ZERO_POINT_HELPER(MatMulOp, A, true)
2272ZERO_POINT_HELPER(MatMulOp, B, true)
2273ZERO_POINT_HELPER(NegateOp, Input1, true)
2274ZERO_POINT_HELPER(NegateOp, Output, true)
2275ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2276ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2277#undef ZERO_POINT_HELPER
2278
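// Worked example (illustrative values only): for an input of shape [1, 2, 3]
// and perms [2, 0, 1], the inference below maps output dimension i to input
// dimension perms[i], producing an output shape of [3, 1, 2].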
2279LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2280 MLIRContext *context, ::std::optional<Location> location,
2281 TransposeOp::Adaptor adaptor,
2282 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2283 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2284
  // If the input rank is unknown, the output rank is also unknown.
2287 if (!inputShape.hasRank()) {
2288 inferredReturnShapes.push_back(ShapedTypeComponents());
2289 return success();
2290 }
2291
2292 const auto inputRank = inputShape.getRank();
2293
  // This would imply the number of permutations does not match the rank of
  // the input, which is illegal.
2296 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2297 return failure();
2298 }
2299
2300 SmallVector<int64_t> outputShape;
2301 // Rank-0 means no permutations matter.
2302 if (inputRank == 0) {
2303 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2304 return success();
2305 }
2306
2307 // Check whether the input dimensions are all the same.
2308 bool allTheSame = true;
2309 for (int i = 1, s = inputRank; i < s; i++) {
2310 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2311 allTheSame = false;
2312 break;
2313 }
2314 }
2315
2316 // If all of the input dimensions are the same we don't care about the
2317 // permutation.
2318 if (allTheSame) {
2319 outputShape.resize(inputRank, inputShape.getDimSize(0));
2320 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2321 return success();
2322 }
2323
2324 outputShape.resize(inputRank, ShapedType::kDynamic);
2325
2326 // Constant permutation values must be within the input rank.
2327 if (llvm::any_of(adaptor.getPerms(),
2328 [inputRank](const auto i) { return i >= inputRank; }))
2329 return failure();
2330
2331 outputShape.reserve(inputRank);
2332 for (int i = 0, s = inputRank; i < s; i++) {
2333 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2334 }
2335
2336 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2337 return success();
2338}
2339
2340LogicalResult tosa::TransposeOp::verify() {
2341 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2342 /* outType = */ getOutput().getType())
2343 .failed()) {
2344 return failure();
2345 }
2346
2347 const ShapeAdaptor inputShape(getInput1().getType());
2348 const ShapeAdaptor outputShape(getOutput().getType());
2349
2350 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2351
2352 if (inputShape.hasRank() &&
2353 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2354 return emitOpError() << "expected perms attribute to have size "
2355 << inputShape.getRank()
2356 << " (input rank) but got size "
2357 << constantPerms.size();
2358
2359 if (inputShape.hasRank() && outputShape.hasRank() &&
2360 inputShape.getRank() != outputShape.getRank())
2361 return emitOpError()
2362 << "expected input tensor rank to equal result tensor rank";
2363
2364 if (outputShape.hasRank() &&
2365 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2366 return emitOpError() << "expected perms attribute to have size "
2367 << outputShape.getRank()
2368 << " (output rank) but got size "
2369 << constantPerms.size();
2370
2371 if (!llvm::all_of(constantPerms,
2372 [&constantPerms](int32_t s) {
2373 return s >= 0 &&
2374 static_cast<size_t>(s) < constantPerms.size();
2375 }) ||
2376 !isPermutationVector(llvm::to_vector(llvm::map_range(
2377 constantPerms, [](int32_t v) -> int64_t { return v; }))))
2378 return emitOpError() << "expected valid permutation indices";
2379
2380 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2381 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2382 inputShape.getNumElements() != outputShape.getNumElements())
2383 return emitOpError() << "expected input1 and output to have same numbers "
2384 "of elements, got "
2385 << inputShape.getNumElements() << " and "
2386 << outputShape.getNumElements();
2387
2388 // Verify that the types of the input and output tensors are properly
2389 // permuted.
2390 if (inputShape.hasRank() && outputShape.hasRank()) {
2391 for (auto i = 0; i < outputShape.getRank(); i++) {
2392 if (inputShape.isDynamicDim(constantPerms[i]) ||
2393 outputShape.isDynamicDim(i))
2394 continue;
2395
2396 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2397 return emitOpError()
2398 << "expected output tensor dim " << i << " to match "
2399 << "input dim " << constantPerms[i] << " with value of "
2400 << inputShape.getDimSize(constantPerms[i]);
2401 }
2402 }
2403
2404 return success();
2405}
2406
2407LogicalResult TransposeOp::reifyResultShapes(
2408 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2409
2410 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2411
2412 Value input = getInput1();
2413 auto inputType = cast<TensorType>(input.getType());
2414
2415 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2416 for (auto dim : transposePerms) {
2417 int32_t dimInInput = transposePerms[dim];
2418 if (inputType.isDynamicDim(dimInInput))
2419 returnedDims[dim] =
2420 builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2421 .getResult();
2422 else
2423 returnedDims[dim] =
2424 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2425 }
2426
2427 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2428 return success();
2429}
2430
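// Shape convention used below (illustrative shapes only): values has shape
// [N, K, C] and indices has shape [N, W]; the inferred output shape is
// [N, W, C]. For example, values of shape [2, 10, 4] with indices of shape
// [2, 5] infer an output shape of [2, 5, 4]. The middle dimension of values
// (K) is not constrained by this inference.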
2431LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2432 MLIRContext *context, ::std::optional<Location> location,
2433 GatherOp::Adaptor adaptor,
2434 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2435 llvm::SmallVector<int64_t> outputShape;
2436 outputShape.resize(3, ShapedType::kDynamic);
2437
2438 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2439 if (valuesShape.hasRank()) {
2440 outputShape[0] = valuesShape.getDimSize(0);
2441 outputShape[2] = valuesShape.getDimSize(2);
2442 }
2443
2444 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2445 if (indicesShape.hasRank()) {
2446 if (outputShape[0] == ShapedType::kDynamic)
2447 outputShape[0] = indicesShape.getDimSize(0);
2448 if (outputShape[1] == ShapedType::kDynamic)
2449 outputShape[1] = indicesShape.getDimSize(1);
2450 }
2451
2452 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2453 return success();
2454}
2455
2456LogicalResult tosa::GatherOp::verify() {
2457 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2458 /* outType = */ getOutput().getType())
2459 .failed()) {
2460 return failure();
2461 }
2462
2463 const ShapeAdaptor valuesShape(getValues().getType());
2464 const ShapeAdaptor indicesShape(getIndices().getType());
2465 const ShapeAdaptor outputShape(getOutput().getType());
2466
2467 int64_t N = ShapedType::kDynamic;
2468 int64_t W = ShapedType::kDynamic;
2469 int64_t C = ShapedType::kDynamic;
2470
2471 if (valuesShape.hasRank()) {
2472 N = valuesShape.getDimSize(0);
2473 C = valuesShape.getDimSize(2);
2474 }
2475 if (indicesShape.hasRank()) {
2476 const int64_t indicesN = indicesShape.getDimSize(0);
2477 W = indicesShape.getDimSize(1);
2478 if (N == ShapedType::kDynamic)
2479 N = indicesN;
2480 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2481 return emitOpError() << "requires indices dimension 0 to have size " << N
2482 << ", got " << indicesN;
2483 }
2484 if (outputShape.hasRank()) {
2485 const int64_t outputN = outputShape.getDimSize(0);
2486 const int64_t outputW = outputShape.getDimSize(1);
2487 const int64_t outputC = outputShape.getDimSize(2);
2488 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2489 N != outputN)
2490 return emitOpError() << "requires output dimension 0 to have size " << N
2491 << ", got " << outputN;
2492
2493 if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2494 W != outputW)
2495 return emitOpError() << "requires output dimension 1 to have size " << W
2496 << ", got " << outputW;
2497 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2498 C != outputC)
2499 return emitOpError() << "requires output dimension 2 to have size " << C
2500 << ", got " << outputC;
2501 }
2502 return success();
2503}
2504
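// Worked example for the output-size formula below (illustrative values
// only): with input height 4, scale_y = 2/1, offset_y = 0 and border_y = 0,
// the inferred output height is ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7; the width
// is computed analogously from scale_x, offset_x and border_x.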
2505LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2506 MLIRContext *context, ::std::optional<Location> location,
2507 ResizeOp::Adaptor adaptor,
2508 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2509 llvm::SmallVector<int64_t, 4> outputShape;
2510 outputShape.resize(4, ShapedType::kDynamic);
2511
2512 ShapeAdaptor inputShape(adaptor.getInput().getType());
2513 if (!inputShape.hasRank())
2514 return failure();
2515
2516 outputShape[0] = inputShape.getDimSize(0);
2517 outputShape[3] = inputShape.getDimSize(3);
2518 int64_t inputHeight = inputShape.getDimSize(1);
2519 int64_t inputWidth = inputShape.getDimSize(2);
2520
2521 if ((inputHeight == ShapedType::kDynamic) ||
2522 (inputWidth == ShapedType::kDynamic))
2523 return failure();
2524
2525 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2526 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2527 scaleInt) ||
2528 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2529 offsetInt) ||
2530 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2531 borderInt)) {
2532 return failure();
2533 }
2534
2535 // Compute the output shape based on attributes: scale, offset, and border.
2536 outputShape[1] =
2537 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2538 scaleInt[1]) +
2539 1;
2540
2541 outputShape[2] =
2542 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2543 scaleInt[3]) +
2544 1;
2545
2546 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2547 return success();
2548}
2549
2550LogicalResult tosa::ResizeOp::verify() {
2551 const Value input = getInput();
2552 const Value output = getOutput();
2553 const RankedTensorType inputType =
2554 llvm::dyn_cast<RankedTensorType>(input.getType());
2555 const RankedTensorType outputType =
2556 llvm::dyn_cast<RankedTensorType>(output.getType());
2557
2558 SmallVector<int64_t> scaleValues;
2559 SmallVector<int64_t> offsetValues;
2560 SmallVector<int64_t> borderValues;
2561 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2562 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2563 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2564 // Skip following checks if shape is not constant
2565 return success();
2566 }
2567
2568 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2569 return emitOpError("expect all scale values to be > 0, got ")
2570 << scaleValues;
2571
2572 const int64_t scaleYN = scaleValues[0];
2573 const int64_t scaleYD = scaleValues[1];
2574 const int64_t scaleXN = scaleValues[2];
2575 const int64_t scaleXD = scaleValues[3];
2576
2577 const int64_t offsetY = offsetValues[0];
2578 const int64_t offsetX = offsetValues[1];
2579
2580 const int64_t borderY = borderValues[0];
2581 const int64_t borderX = borderValues[1];
2582
2583 if (!inputType)
2584 return success();
2585 if (!outputType)
2586 return success();
2587
2588 const int64_t oh = outputType.getDimSize(1);
2589 const int64_t ow = outputType.getDimSize(2);
2590 const int64_t ih = inputType.getDimSize(1);
2591 const int64_t iw = inputType.getDimSize(2);
2592
  // Don't check the input height when it could be a broadcast (i.e. ih == 1),
  // since Linalg, a consumer of TOSA, expects broadcasting support in resize
  // to be available. Taking the cautious approach for now; we can consider
  // removing support for broadcasting later.
2597 if (ih != ShapedType::kDynamic && ih != 1) {
2598 const std::optional<int64_t> calculatedOutHeightMinusOne =
2599 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2600 if (!calculatedOutHeightMinusOne.has_value())
2601 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2602 "border_y ")
2603 << "to be wholly divisible by scale_y_d, got ((" << ih
2604 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2605 << ") / " << scaleYD;
2606 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2607 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2608 return emitOpError("calculated output height did not match expected: ")
2609 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2610 }
2611
  // Don't check the input width when it could be a broadcast (i.e. iw == 1),
  // since Linalg, a consumer of TOSA, expects broadcasting support in resize
  // to be available. Taking the cautious approach for now; we can consider
  // removing support for broadcasting later.
2616 if (iw != ShapedType::kDynamic && iw != 1) {
2617 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2618 const std::optional<int64_t> calculatedOutWidthMinusOne =
2619 idivCheck(scaledInWidth, scaleXD);
2620 if (!calculatedOutWidthMinusOne.has_value())
2621 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2622 "border_x ")
2623 << "to be wholly divisible by scale_x_d, got ((" << iw
2624 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2625 << ") / " << scaleXD;
2626 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2627 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2628 return emitOpError("calculated output width did not match expected: ")
2629 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2630 }
2631
2632 return success();
2633}
2634
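// Shape convention used below (illustrative shapes only): values_in has shape
// [N, K, C], indices has shape [N, W] and input has shape [N, W, C]; the
// inferred values_out shape is [N, K, C], and the verifier additionally
// requires K >= W. For example, values_in of shape [2, 6, 4], indices of
// shape [2, 3] and input of shape [2, 3, 4] give values_out of shape
// [2, 6, 4].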
2635LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2636 MLIRContext *context, ::std::optional<Location> location,
2637 ScatterOp::Adaptor adaptor,
2638 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2639 llvm::SmallVector<int64_t> outputShape;
2640 outputShape.resize(3, ShapedType::kDynamic);
2641
2642 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2643 if (valuesInShape.hasRank()) {
2644 outputShape[0] = valuesInShape.getDimSize(0);
2645 outputShape[1] = valuesInShape.getDimSize(1);
2646 outputShape[2] = valuesInShape.getDimSize(2);
2647 }
2648
2649 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2650 if (indicesShape.hasRank()) {
2651 if (outputShape[0] == ShapedType::kDynamic)
2652 outputShape[0] = indicesShape.getDimSize(0);
2653 }
2654
2655 ShapeAdaptor inputShape(adaptor.getInput().getType());
2656 if (inputShape.hasRank()) {
2657 if (outputShape[0] == ShapedType::kDynamic)
2658 outputShape[0] = inputShape.getDimSize(0);
2659 if (outputShape[2] == ShapedType::kDynamic)
2660 outputShape[2] = inputShape.getDimSize(2);
2661 }
2662
2663 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2664 return success();
2665}
2666
2667LogicalResult tosa::ScatterOp::verify() {
2668 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2669 /* outType = */ getValuesOut().getType())
2670 .failed() ||
2671 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2672 /* outType = */ getValuesOut().getType())
2673 .failed()) {
2674 return failure();
2675 }
2676
2677 const ShapeAdaptor valuesInShape(getValuesIn().getType());
2678 const ShapeAdaptor indicesShape(getIndices().getType());
2679 const ShapeAdaptor inputShape(getInput().getType());
2680 const ShapeAdaptor outputShape(getValuesOut().getType());
2681
2682 int64_t N = ShapedType::kDynamic;
2683 int64_t K = ShapedType::kDynamic;
2684 int64_t W = ShapedType::kDynamic;
2685 int64_t C = ShapedType::kDynamic;
2686 if (valuesInShape.hasRank()) {
2687 N = valuesInShape.getDimSize(0);
2688 K = valuesInShape.getDimSize(1);
2689 C = valuesInShape.getDimSize(2);
2690 }
2691 if (indicesShape.hasRank()) {
2692 const int64_t indicesN = indicesShape.getDimSize(0);
2693 W = indicesShape.getDimSize(1);
2694 if (N == ShapedType::kDynamic)
2695 N = indicesN;
2696 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2697 return emitOpError() << "requires indices dimension 0 to have size " << N
2698 << ", got " << indicesN;
2699 }
2700 if (inputShape.hasRank()) {
2701 const int64_t inputN = inputShape.getDimSize(0);
2702 const int64_t inputW = inputShape.getDimSize(1);
2703 const int64_t inputC = inputShape.getDimSize(2);
2704 if (N == ShapedType::kDynamic)
2705 N = inputN;
2706 else if (inputN != ShapedType::kDynamic && N != inputN)
2707 return emitOpError() << "requires input dimension 0 to have size " << N
2708 << ", got " << inputN;
2709 if (W == ShapedType::kDynamic)
2710 W = inputW;
2711 else if (inputW != ShapedType::kDynamic && W != inputW)
2712 return emitOpError() << "requires input dimension 1 to have size " << W
2713 << ", got " << inputW;
2714
2715 if (C == ShapedType::kDynamic)
2716 C = inputC;
2717 else if (inputC != ShapedType::kDynamic && C != inputC)
2718 return emitOpError() << "requires input dimension 2 to have size " << C
2719 << ", got " << inputC;
2720 }
2721 if (outputShape.hasRank()) {
2722 const int64_t outputN = outputShape.getDimSize(0);
2723 const int64_t outputK = outputShape.getDimSize(1);
2724 const int64_t outputC = outputShape.getDimSize(2);
2725 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2726 N != outputN)
2727 return emitOpError() << "requires values_out dimension 0 to have size "
2728 << N << ", got " << outputN;
2729 if (K == ShapedType::kDynamic)
2730 K = outputK;
2731 else if (outputK != ShapedType::kDynamic && K != outputK)
2732 return emitOpError() << "requires values_out dimension 1 to have size "
2733 << K << ", got " << outputK;
2734 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2735 C != outputC)
2736 return emitOpError() << "requires values_out dimension 2 to have size "
2737 << C << ", got " << outputC;
2738 }
2739 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2740 return emitOpError() << "requires dimensions K >= W, got K=" << K
2741 << " and W=" << W;
2742
2743 return success();
2744}
2745
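// Worked example (illustrative values only): reducing an input of shape
// [2, 3, 4] along axis 1 keeps the other dimensions and sets the reduced
// dimension to 1, inferring an output shape of [2, 1, 4].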
2746static LogicalResult ReduceInferReturnTypes(
2747 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2748 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2749 int64_t axisVal = axis.getValue().getSExtValue();
2750 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2752 return success();
2753 }
2754
2755 SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2759 return success();
2760}
2761
2762#define COMPATIBLE_RETURN_TYPES(OP) \
2763 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2764 if (l.size() != r.size() || l.size() != 1) \
2765 return false; \
2766 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2767 return false; \
2768 return succeeded(verifyCompatibleShape(l[0], r[0])); \
2769 }
2770
2771#define REDUCE_SHAPE_INFER(OP) \
2772 LogicalResult OP::inferReturnTypeComponents( \
2773 MLIRContext *context, ::std::optional<Location> location, \
2774 OP::Adaptor adaptor, \
2775 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2776 Type inputType = \
2777 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2778 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2779 const Properties &prop = adaptor.getProperties(); \
2780 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2781 inferredReturnShapes); \
2782 } \
2783 COMPATIBLE_RETURN_TYPES(OP)
2784
2785REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2786REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2787REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2788REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2789REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2790REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2791#undef REDUCE_SHAPE_INFER
2792COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2793#undef COMPATIBLE_RETURN_TYPES
2794
2795template <typename T>
2796static LogicalResult verifyReduceOp(T op) {
2797 // All TOSA reduce Ops have input, output and axis.
2798 TensorType inputType = op.getInput().getType();
2799 TensorType outputType = op.getOutput().getType();
2800 int32_t reduceAxis = op.getAxis();
2801
2802 if (reduceAxis < 0) {
2803 op.emitOpError("reduce axis must not be negative");
2804 return failure();
2805 }
2806 if (inputType.hasRank()) {
2807 int64_t inputRank = inputType.getRank();
2808 // We allow for a special case where the input/output shape has rank 0 and
2809 // axis is also 0.
2810 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2811 op.emitOpError("expect input tensor rank (")
2812 << inputRank << ") to be larger than reduce axis (" << reduceAxis
2813 << ")";
2814 return failure();
2815 }
2816 }
2817 if (outputType.hasRank()) {
2818 int64_t outputRank = outputType.getRank();
2819 if (inputType.hasRank() && outputRank != inputType.getRank()) {
2820 op.emitOpError(
2821 "expect output tensor rank to be equal to input tensor rank");
2822 return failure();
2823 }
2824 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2825 op.emitOpError("expect output tensor rank (")
2826 << outputRank << ") to be larger than reduce axis (" << reduceAxis
2827 << ")";
2828 return failure();
2829 }
2830 // We can only verify the reduced dimension size to be 1 if this is not
2831 // the special case of output rank == 0.
2832 if (outputRank != 0) {
2833 auto outputShape = outputType.getShape();
2834 if (!outputType.isDynamicDim(reduceAxis) &&
2835 outputShape[reduceAxis] != 1) {
2836 op.emitOpError("expect reduced dimension size to be 1, got ")
2837 << outputShape[reduceAxis];
2838 return failure();
2839 }
2840 }
2841 }
2842 return success();
2843}
2844
2845LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2846LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2847LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2848LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2849LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2850LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2851
2852static LogicalResult NAryInferReturnTypes(
2853 const ValueShapeRange &operands,
2854 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2855 llvm::SmallVector<int64_t> outShape;
2856 if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2860 }
2861 return success();
2862}
2863
2864#define NARY_SHAPE_INFER(OP) \
2865 LogicalResult OP::inferReturnTypeComponents( \
2866 MLIRContext *context, ::std::optional<Location> location, \
2867 ValueShapeRange operands, DictionaryAttr attributes, \
2868 OpaqueProperties properties, RegionRange regions, \
2869 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2870 return NAryInferReturnTypes(operands, inferredReturnShapes); \
2871 }
2872
2873NARY_SHAPE_INFER(tosa::AbsOp)
2874NARY_SHAPE_INFER(tosa::AddOp)
2875NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2876NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2877NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2878NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2879NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2880NARY_SHAPE_INFER(tosa::CastOp)
2881NARY_SHAPE_INFER(tosa::CeilOp)
2882NARY_SHAPE_INFER(tosa::ClampOp)
2883NARY_SHAPE_INFER(tosa::ClzOp)
2884NARY_SHAPE_INFER(tosa::CosOp)
2885NARY_SHAPE_INFER(tosa::ExpOp)
2886NARY_SHAPE_INFER(tosa::FloorOp)
2887NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2888NARY_SHAPE_INFER(tosa::GreaterOp)
2889NARY_SHAPE_INFER(tosa::IdentityOp)
2890NARY_SHAPE_INFER(tosa::IntDivOp)
2891NARY_SHAPE_INFER(tosa::LogOp)
2892NARY_SHAPE_INFER(tosa::LogicalAndOp)
2893NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2894NARY_SHAPE_INFER(tosa::LogicalNotOp)
2895NARY_SHAPE_INFER(tosa::LogicalOrOp)
2896NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2897NARY_SHAPE_INFER(tosa::LogicalXorOp)
2898NARY_SHAPE_INFER(tosa::MaximumOp)
2899NARY_SHAPE_INFER(tosa::MinimumOp)
2900NARY_SHAPE_INFER(tosa::PowOp)
2901NARY_SHAPE_INFER(tosa::ReciprocalOp)
2902NARY_SHAPE_INFER(tosa::ReverseOp)
2903NARY_SHAPE_INFER(tosa::RsqrtOp)
2904NARY_SHAPE_INFER(tosa::SinOp)
2905NARY_SHAPE_INFER(tosa::SelectOp)
2906NARY_SHAPE_INFER(tosa::SubOp)
2907NARY_SHAPE_INFER(tosa::TanhOp)
2908NARY_SHAPE_INFER(tosa::ErfOp)
2909NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER
2911
2912LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2913 MLIRContext *context, ::std::optional<Location> location,
2914 NegateOp::Adaptor adaptor,
2915 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2916 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2917 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2918 return success();
2919}
2920
2921LogicalResult tosa::NegateOp::verify() {
2922 // Verify same element type
2923 const Type input1Type = getInput1().getType();
2924 const Type outputType = getOutput().getType();
2925 if (verifySameElementTypes(*this, input1Type, outputType).failed())
2926 return failure();
2927
2928 // Verify same shape
2929 const SmallVector<Type, 2> types = {input1Type, outputType};
2930 if (failed(verifyCompatibleShapes(types)))
2931 return emitOpError() << "requires the same shape for input1 and output";
2932
2933 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2934 const Type input1ZpEType =
2935 getStorageElementTypeOrSelf(getInput1Zp().getType());
2936 if (input1EType != input1ZpEType) {
2937 return emitOpError("expect both input1 and its zero point are the same "
2938 "element type, got ")
2939 << input1EType << " and " << input1ZpEType;
2940 }
2941 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2942 const Type outputZpEType =
2943 getStorageElementTypeOrSelf(getOutputZp().getType());
2944 if (outputEType != outputZpEType) {
2945 return emitOpError("expect both output and its zero point are the same "
2946 "element type, got ")
2947 << outputEType << " and " << outputZpEType;
2948 }
2949
2950 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2951 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2952 return failure();
2953
2954 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2955 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2956 return failure();
2957
2958 return success();
2959}
2960
2961static LogicalResult poolingInferReturnTypes(
2962 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2963 ArrayRef<int64_t> pad,
2964 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2965 llvm::SmallVector<int64_t> outputShape;
2966 outputShape.resize(4, ShapedType::kDynamic);
2967
2968  // If the input type is unranked, all we know is that the output has rank 4.
2969  if (!inputShape) {
2970    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2971 return success();
2972 }
2973
2974  // The batch size and channel count are unchanged by pooling.
2975  outputShape[0] = inputShape.getDimSize(0);
2976  outputShape[3] = inputShape.getDimSize(3);
2977
2978  int64_t height = inputShape.getDimSize(1);
2979  int64_t width = inputShape.getDimSize(2);
2980
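  // Each spatial output dimension is computed as
  //   out = (in + pad_front + pad_back - kernel) / stride + 1
  // and only when the corresponding input dimension is static.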
2981 if (!ShapedType::isDynamic(height)) {
2982 int64_t padded = height + pad[0] + pad[1] - kernel[0];
2983 outputShape[1] = padded / stride[0] + 1;
2984 }
2985
2986 if (!ShapedType::isDynamic(width)) {
2987 int64_t padded = width + pad[2] + pad[3] - kernel[1];
2988 outputShape[2] = padded / stride[1] + 1;
2989 }
2990
2991  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2992 return success();
2993}
2994
2995LogicalResult Conv2DOp::inferReturnTypeComponents(
2996 MLIRContext *context, ::std::optional<Location> location,
2997 Conv2DOp::Adaptor adaptor,
2998 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2999 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3000
3001 int64_t inputWidth = ShapedType::kDynamic;
3002 int64_t inputHeight = ShapedType::kDynamic;
3003 int64_t weightWidth = ShapedType::kDynamic;
3004 int64_t weightHeight = ShapedType::kDynamic;
3005
3006 // Input shape describes input width/height and batch.
3007
3008 ShapeAdaptor inputShape(adaptor.getInput().getType());
3009 if (inputShape.hasRank()) {
3010 outputShape[0] = inputShape.getDimSize(0);
3011 inputHeight = inputShape.getDimSize(1);
3012 inputWidth = inputShape.getDimSize(2);
3013 }
3014
3015  // The weight shape describes the filter width/height and the output channels.
3016 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3017 if (weightShape.hasRank()) {
3018 outputShape[3] = weightShape.getDimSize(0);
3019 weightHeight = weightShape.getDimSize(1);
3020 weightWidth = weightShape.getDimSize(2);
3021 }
3022
3023 // Bias shape can describe the output channels.
3024 ShapeAdaptor biasShape(adaptor.getBias().getType());
3025 if (biasShape.hasRank()) {
3026 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3027 ? biasShape.getDimSize(0)
3028 : outputShape[3];
3029 }
3030
3031 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3032 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3033 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3034
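  // Each spatial output dimension follows the dilated-convolution formula
  //   out = (in + pad_front + pad_back - ((kernel - 1) * dilation + 1)) / stride + 1
  // evaluated only when both the input and weight extents are static.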
3035 if (!ShapedType::isDynamic(inputHeight) &&
3036 !ShapedType::isDynamic(weightHeight)) {
3037 int64_t inputSize = inputHeight + padding[0] + padding[1];
3038 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3039 int64_t unstridedResult = inputSize - filterSize + 1;
3040 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3041 }
3042
3043 if (!ShapedType::isDynamic(inputWidth) &&
3044 !ShapedType::isDynamic(weightWidth)) {
3045 int64_t inputSize = inputWidth + padding[2] + padding[3];
3046 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3047 int64_t unstridedResult = inputSize - filterSize + 1;
3048 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3049 }
3050
3051 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3052 return success();
3053}
3054
3055LogicalResult Conv2DOp::verify() {
3056 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3057 verifyConvOpErrorIf(*this).failed())
3058 return failure();
3059 return success();
3060}
3061
3062LogicalResult Conv3DOp::inferReturnTypeComponents(
3063 MLIRContext *context, ::std::optional<Location> location,
3064 Conv3DOp::Adaptor adaptor,
3065 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3066 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3067
3068 int64_t inputWidth = ShapedType::kDynamic;
3069 int64_t inputHeight = ShapedType::kDynamic;
3070 int64_t inputDepth = ShapedType::kDynamic;
3071
3072 int64_t weightWidth = ShapedType::kDynamic;
3073 int64_t weightHeight = ShapedType::kDynamic;
3074 int64_t weightDepth = ShapedType::kDynamic;
3075
3076  // Input shape describes the batch size and the input depth/height/width.
3077 ShapeAdaptor inputShape(adaptor.getInput().getType());
3078 if (inputShape.hasRank()) {
3079 outputShape[0] = inputShape.getDimSize(0);
3080 inputDepth = inputShape.getDimSize(1);
3081 inputHeight = inputShape.getDimSize(2);
3082 inputWidth = inputShape.getDimSize(3);
3083 }
3084
3085  // The weight shape describes the filter depth/height/width and the output channels.
3086 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3087 if (weightShape.hasRank()) {
3088 outputShape[4] = weightShape.getDimSize(0);
3089 weightDepth = weightShape.getDimSize(1);
3090 weightHeight = weightShape.getDimSize(2);
3091 weightWidth = weightShape.getDimSize(3);
3092 }
3093
3094 // Bias shape can describe the output channels.
3095 ShapeAdaptor biasShape(adaptor.getBias().getType());
3096 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3097 outputShape[4] = biasShape.getDimSize(0);
3098 }
3099
3100 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3101 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3102 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3103
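  // Depth, height and width use the same dilated-convolution formula as
  // Conv2D, with pad pairs [0,1]/[2,3]/[4,5] and per-axis stride/dilation.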
3104 if (!ShapedType::isDynamic(inputDepth) &&
3105 !ShapedType::isDynamic(weightDepth)) {
3106    int64_t inputSize = inputDepth + pad[0] + pad[1];
3107    int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3108    int64_t unstridedResult = inputSize - filterSize + 1;
3109 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3110 }
3111
3112 if (!ShapedType::isDynamic(inputHeight) &&
3113 !ShapedType::isDynamic(weightHeight)) {
3114    int64_t inputSize = inputHeight + pad[2] + pad[3];
3115    int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3116    int64_t unstridedResult = inputSize - filterSize + 1;
3117 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3118 }
3119
3120 if (!ShapedType::isDynamic(inputWidth) &&
3121 !ShapedType::isDynamic(weightWidth)) {
3122    int64_t inputSize = inputWidth + pad[4] + pad[5];
3123    int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3124    int64_t unstridedResult = inputSize - filterSize + 1;
3125 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3126 }
3127
3128 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3129 return success();
3130}
3131
3132LogicalResult Conv3DOp::verify() {
3133 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3134 verifyConvOpErrorIf(*this).failed())
3135 return failure();
3136 return success();
3137}
3138
3139LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3140 MLIRContext *context, ::std::optional<Location> location,
3141 AvgPool2dOp::Adaptor adaptor,
3142 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3143 ShapeAdaptor inputShape(adaptor.getInput().getType());
3144 const Properties &prop = adaptor.getProperties();
3145 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3146 inferredReturnShapes);
3147}
3148
3149LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3150 MLIRContext *context, ::std::optional<Location> location,
3151 MaxPool2dOp::Adaptor adaptor,
3152 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3153 ShapeAdaptor inputShape(adaptor.getInput().getType());
3154 const Properties &prop = adaptor.getProperties();
3155 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3156 inferredReturnShapes);
3157}
3158
3159LogicalResult MaxPool2dOp::verify() {
3160  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3161 /* outType = */ getOutput().getType())))
3162 return failure();
3163
3164 if (failed(verifyPoolingOp(*this)))
3165 return failure();
3166
3167 return success();
3168}
3169
3170LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3171 MLIRContext *context, ::std::optional<Location> location,
3172 DepthwiseConv2DOp::Adaptor adaptor,
3173 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3174 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3175
3176 int64_t inputWidth = ShapedType::kDynamic;
3177 int64_t inputHeight = ShapedType::kDynamic;
3178 int64_t inputChannels = ShapedType::kDynamic;
3179
3180 int64_t weightWidth = ShapedType::kDynamic;
3181 int64_t weightHeight = ShapedType::kDynamic;
3182 int64_t depthChannels = ShapedType::kDynamic;
3183
3184  // Input shape describes the batch size, input height/width and channels.
3185 ShapeAdaptor inputShape(adaptor.getInput().getType());
3186 if (inputShape.hasRank()) {
3187 outputShape[0] = inputShape.getDimSize(0);
3188 inputHeight = inputShape.getDimSize(1);
3189 inputWidth = inputShape.getDimSize(2);
3190 inputChannels = inputShape.getDimSize(3);
3191 }
3192
3193  // The weight shape describes the filter height/width and the channel multiplier.
3194 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3195 if (weightShape.hasRank()) {
3196 weightHeight = weightShape.getDimSize(0);
3197 weightWidth = weightShape.getDimSize(1);
3198 inputChannels = ShapedType::isDynamic(inputChannels)
3199 ? weightShape.getDimSize(2)
3200 : inputChannels;
3201 depthChannels = weightShape.getDimSize(3);
3202 }
3203
3204  // If both inputChannels and depthChannels are available, we can determine
3205 // the output channels.
3206 if (!ShapedType::isDynamic(inputChannels) &&
3207 !ShapedType::isDynamic(depthChannels)) {
3208 outputShape[3] = inputChannels * depthChannels;
3209 }
3210
3211 // Bias shape can describe the output channels.
3212 ShapeAdaptor biasShape(adaptor.getBias().getType());
3213 if (biasShape.hasRank()) {
3214 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3215 ? biasShape.getDimSize(0)
3216 : outputShape[3];
3217 }
3218
3219 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3220 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3221 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3222
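  // Height and width follow the same dilated-convolution arithmetic as Conv2D;
  // only the channel dimension differs (inputChannels * depthChannels).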
3223 if (!ShapedType::isDynamic(inputHeight) &&
3224 !ShapedType::isDynamic(weightHeight)) {
3225 int64_t inputSize = inputHeight + padding[0] + padding[1];
3226 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3227 int64_t unstridedResult = inputSize - filterSize + 1;
3228 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3229 }
3230
3231 if (!ShapedType::isDynamic(inputWidth) &&
3232 !ShapedType::isDynamic(weightWidth)) {
3233 int64_t inputSize = inputWidth + padding[2] + padding[3];
3234 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3235 int64_t unstridedResult = inputSize - filterSize + 1;
3236 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3237 }
3238
3239 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3240 return success();
3241}
3242
3243LogicalResult DepthwiseConv2DOp::verify() {
3244 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3245 verifyConvOpErrorIf(*this).failed())
3246 return failure();
3247 return success();
3248}
3249
3250LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3251 MLIRContext *context, ::std::optional<Location> location,
3252 TransposeConv2DOp::Adaptor adaptor,
3253 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3254 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3255
3256 int64_t inputWidth = ShapedType::kDynamic;
3257 int64_t inputHeight = ShapedType::kDynamic;
3258 int64_t weightWidth = ShapedType::kDynamic;
3259 int64_t weightHeight = ShapedType::kDynamic;
3260
3261 // Input shape describes input width/height and batch.
3262 ShapeAdaptor inputShape(adaptor.getInput().getType());
3263 if (inputShape.hasRank()) {
3264 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3265 ? inputShape.getDimSize(0)
3266 : outputShape[0];
3267 inputHeight = inputShape.getDimSize(1);
3268 inputWidth = inputShape.getDimSize(2);
3269 }
3270
3271  // The weight shape describes the filter width/height and the output channels.
3272 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3273 if (weightShape.hasRank()) {
3274 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3275 ? weightShape.getDimSize(0)
3276 : outputShape[3];
3277 weightHeight = weightShape.getDimSize(1);
3278 weightWidth = weightShape.getDimSize(2);
3279 }
3280
3281 // Bias shape can describe the output channels.
3282  ShapeAdaptor biasShape(adaptor.getBias().getType());
3283 if (biasShape.hasRank()) {
3284 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3285 ? biasShape.getDimSize(0)
3286 : outputShape[3];
3287 }
3288
3289 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3290 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3291
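  // Transpose convolution grows each spatial dimension:
  //   out = (in - 1) * stride + out_pad_front + out_pad_back + kernel
  // evaluated only when both the input and weight extents are static.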
3292 if (!ShapedType::isDynamic(inputHeight) &&
3293 !ShapedType::isDynamic(weightHeight)) {
3294 int64_t calculateSize =
3295 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3296 outputShape[1] =
3297 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3298 }
3299
3300 if (!ShapedType::isDynamic(inputWidth) &&
3301 !ShapedType::isDynamic(weightWidth)) {
3302 int64_t calculateSize =
3303 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3304 outputShape[2] =
3305 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3306 }
3307
3308 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3309 return success();
3310}
3311
3312LogicalResult TransposeConv2DOp::verify() {
3313 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3314 return failure();
3315
3316 const llvm::ArrayRef<int64_t> strides = getStride();
3317 const int64_t strideY = strides[0];
3318 const int64_t strideX = strides[1];
3319
3320 if (strideY < 1 || strideX < 1)
3321 return emitOpError("expect all stride values to be >= 1, got [")
3322 << strides << "]";
3323
3324 const auto checkPadAgainstKernelDim =
3325 [this](int64_t pad_value, int64_t kernel_dim_size,
3326 llvm::StringRef pad_name,
3327 llvm::StringRef kernel_dim_name) -> LogicalResult {
3328 if (pad_value <= -kernel_dim_size)
3329 return emitOpError("expected ")
3330 << pad_name << " > -" << kernel_dim_name
3331 << ", but got: " << pad_name << "=" << pad_value << " and "
3332 << kernel_dim_name << "=" << kernel_dim_size;
3333 return success();
3334 };
3335
3336 const llvm::ArrayRef<int64_t> padding = getOutPad();
3337 const int64_t outPadTop = padding[0];
3338 const int64_t outPadBottom = padding[1];
3339 const int64_t outPadLeft = padding[2];
3340 const int64_t outPadRight = padding[3];
3341
3342 const auto weightType =
3343 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3344
3345 if (weightType) {
3346 const int64_t kernelHeight = weightType.getDimSize(1);
3347 if (!ShapedType::isDynamic(kernelHeight)) {
3348 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3349 "out_pad_top", "KH")))
3350 return failure();
3351
3352 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3353 "out_pad_bottom", "KH")))
3354 return failure();
3355 }
3356
3357 const int64_t kernelWidth = weightType.getDimSize(2);
3358 if (!ShapedType::isDynamic(kernelWidth)) {
3359 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3360 "out_pad_left", "KW")))
3361 return failure();
3362
3363 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3364 "out_pad_right", "KW")))
3365 return failure();
3366 }
3367 }
3368
3369  // The rest of the checks depend on the output type being a RankedTensorType.
3370 const auto outputType =
3371 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3372 if (!outputType)
3373 return success();
3374
3375 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3376 if (inputType && weightType) {
3377 const int64_t inputHeight = inputType.getDimSize(1);
3378 const int64_t kernelHeight = weightType.getDimSize(1);
3379 const int64_t outputHeight = outputType.getDimSize(1);
3380
3381 if (!ShapedType::isDynamic(inputHeight) &&
3382 !ShapedType::isDynamic(outputHeight)) {
3383 if (outputHeight !=
3384 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3385 return emitOpError(
3386 "dimension mismatch: expected OH == (IH - 1) * stride_y "
3387 "+ out_pad_top + out_pad_bottom + KH, but got ")
3388 << outputHeight << " != (" << inputHeight << " - 1) * "
3389 << strideY << " + " << outPadTop << " + " << outPadBottom
3390 << " + " << kernelHeight;
3391 }
3392
3393 const int64_t inputWidth = inputType.getDimSize(2);
3394 const int64_t kernelWidth = weightType.getDimSize(2);
3395 const int64_t outputWidth = outputType.getDimSize(2);
3396
3397 if (!ShapedType::isDynamic(inputWidth) &&
3398 !ShapedType::isDynamic(outputWidth)) {
3399 if (outputWidth !=
3400 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3401 return emitOpError(
3402 "dimension mismatch: expected OW == (IW - 1) * stride_x "
3403 "+ out_pad_left + out_pad_right + KW, but got ")
3404 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3405 << " + " << outPadLeft << " + " << outPadRight << " + "
3406 << kernelWidth;
3407 }
3408 }
3409
3410 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3411
3412 if (!biasType)
3413 return success();
3414
3415 const int64_t biasChannels = biasType.getDimSize(0);
3416
3417 // Skip further checks if bias is dynamic
3418 if (biasChannels == ShapedType::kDynamic)
3419 return success();
3420
3421 const int64_t outputChannels = outputType.getDimSize(3);
3422 if (biasChannels != outputChannels && biasChannels != 1)
3423 return emitOpError(
3424 "bias channels expected to be equal to output channels (")
3425 << outputChannels << ") or 1, got " << biasChannels;
3426
3427 return success();
3428}
3429
3430LogicalResult RescaleOp::verify() {
3431 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3432 if (!inputType) {
3433 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3434 return failure();
3435 }
3436
3437 auto inputElementType =
3438 getStorageElementTypeOrSelf(inputType.getElementType());
3439 if (!mlir::isa<IntegerType>(inputElementType)) {
3440 emitOpError("expect input to have integer element type, got ")
3441 << inputElementType;
3442 return failure();
3443 }
3444
3445 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3446 if (!outputType) {
3447 emitOpError("expect shaped tensor for output, got ")
3448 << getOutput().getType();
3449 return failure();
3450 }
3451
3452 auto outputElementType =
3453 getStorageElementTypeOrSelf(outputType.getElementType());
3454 if (!mlir::isa<IntegerType>(outputElementType)) {
3455 emitOpError("expect output to have integer element type, got ")
3456 << outputElementType;
3457 return failure();
3458 }
3459
3460 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3461 .failed())
3462 return failure();
3463
3464 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3465 .failed())
3466 return failure();
3467
3468 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3469 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3470 return failure();
3471
3472 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3473 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3474 return failure();
3475
3476 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3477 if (!multiplierType) {
3478 emitOpError("expect shaped tensor for multiplier, got ")
3479 << getMultiplier().getType();
3480 return failure();
3481 }
3482
3483 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3484 if (!shiftType) {
3485 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3486 return failure();
3487 }
3488
3489 // multiplier element type must be i32 for scale32 = true
3490 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3491 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3492 << multiplierType.getElementType();
3493 return failure();
3494 }
3495
3496 // multiplier element type must be i16 for scale32 = false
3497 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3498 emitOpError(
3499 "expect i16 element type for multiplier for scale32=false, got ")
3500 << multiplierType.getElementType();
3501 return failure();
3502 }
3503
3504 if (!inputType.hasRank())
3505 return success();
3506
3507  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3508  // if per_channel = false and otherwise the size of the last axis of the
3509  // input shape.
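  // For example (hypothetical shapes): per_channel = true with an input of type
  // tensor<1x4x4x8xi8> requires multiplier and shift to have shape { 8 }.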
3510 int64_t numChannels = 1;
3511 if (getPerChannel()) {
3512 if (inputType.getRank() < 1) {
3513 emitOpError("requires input to be at least rank 1 when per_channel is "
3514 "true, but got rank ")
3515 << inputType.getRank();
3516 return failure();
3517 }
3518 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3519 }
3520
3521 if (!multiplierType.hasRank())
3522 return success();
3523
3524 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3525 // multiplier input has rank 1 by dialect definition
3526 if (multiplierShape[0] != ShapedType::kDynamic &&
3527 multiplierShape[0] != numChannels) {
3528 emitOpError("expect shape of { ")
3529 << numChannels << " } for multiplier input, got { "
3530 << multiplierShape[0] << " }";
3531 return failure();
3532 }
3533
3534 if (!shiftType.hasRank())
3535 return success();
3536
3537 ArrayRef<int64_t> shiftShape = shiftType.getShape();
3538 // shift input has rank 1 by dialect definition
3539 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3540 emitOpError("expect shape of { ")
3541 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3542 return failure();
3543 }
3544
3545 return success();
3546}
3547
3548LogicalResult RescaleOp::inferReturnTypeComponents(
3549 MLIRContext *context, ::std::optional<Location> location,
3550 RescaleOp::Adaptor adaptor,
3551 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3552 ShapeAdaptor inputShape(adaptor.getInput().getType());
3553 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3554 return success();
3555}
3556
3557LogicalResult IfOp::inferReturnTypeComponents(
3558 MLIRContext *context, ::std::optional<Location> location,
3559 IfOp::Adaptor adaptor,
3560 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3561 llvm::SmallVector<tosa::YieldOp> yieldOps;
3562 for (Region *region : adaptor.getRegions()) {
3563 for (auto &block : *region)
3564 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3565 yieldOps.push_back(returnOp);
3566 }
3567
3568 if (yieldOps.empty())
3569 return failure();
3570
3571 // Get the initial type information for the yield op.
3572 llvm::SmallVector<ValueKnowledge> resultKnowledge;
3573 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3574 for (auto operand : yieldOps.front().getOperands()) {
3575 resultKnowledge.push_back(
3576 ValueKnowledge::getKnowledgeFromType(operand.getType()));
3577 }
3578
3579 for (auto yieldOp : yieldOps) {
3580 if (resultKnowledge.size() != yieldOp.getNumOperands())
3581 return failure();
3582
3583 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3584 int32_t index = it.index();
3585 auto meet = ValueKnowledge::meet(
3586 resultKnowledge[index],
3587 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3588 if (!meet)
3589 continue;
3590 resultKnowledge[index] = meet;
3591 }
3592 }
3593
3594 for (const ValueKnowledge &result : resultKnowledge) {
3595 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3596 }
3597
3598 return success();
3599}
3600
3601LogicalResult WhileOp::inferReturnTypeComponents(
3602 MLIRContext *context, ::std::optional<Location> location,
3603 WhileOp::Adaptor adaptor,
3604 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3605 llvm::SmallVector<tosa::YieldOp> yieldOps;
3606 for (auto &block : adaptor.getBodyGraph())
3607 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3608 yieldOps.push_back(returnOp);
3609
3610  // TOSA's while must have a tosa.yield as its terminator. If none is found,
3611  // this tosa.while is invalid.
3612 if (yieldOps.empty())
3613 return failure();
3614
3615 // Get the initial type information from the operand types.
3616 llvm::SmallVector<ValueKnowledge> resultKnowledge;
3617 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3618 for (auto operand : yieldOps.front().getOperands()) {
3619 resultKnowledge.push_back(
3620 ValueKnowledge::getKnowledgeFromType(operand.getType()));
3621 }
3622
3623 for (auto yieldOp : yieldOps) {
3624 if (resultKnowledge.size() != yieldOp.getNumOperands())
3625 return failure();
3626
3627 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3628 int32_t index = it.index();
3629 if (auto meet = ValueKnowledge::meet(
3630 resultKnowledge[index],
3631 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3632 resultKnowledge[index] = meet;
3633 }
3634 }
3635 }
3636
3637 for (const ValueKnowledge &result : resultKnowledge) {
3638 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3639 }
3640
3641 return success();
3642}
3643
3644std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3645 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3646 return llvm::to_vector<4>(vt.getShape());
3647 return std::nullopt;
3648}
3649
3650// The parse and print methods of IfOp follow the implementation of the SCF dialect.
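// With this custom form the op prints roughly as (op mnemonic assumed):
//   tosa.cond_if %condition -> (result-types) {
//     ...
//   } else {
//     ...
//   }
// where the "-> (...)" list and the 'else' region are printed only when present.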
3651ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3652  // Create the 'then' and 'else' regions.
3653 result.regions.reserve(2);
3654 Region *thenRegion = result.addRegion();
3655 Region *elseRegion = result.addRegion();
3656
3657 auto &builder = parser.getBuilder();
3658 OpAsmParser::UnresolvedOperand cond;
3659  // Create an i1 tensor type for the boolean condition.
3660 Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3661 if (parser.parseOperand(cond) ||
3662 parser.resolveOperand(cond, i1Type, result.operands))
3663 return failure();
3664 // Parse optional results type list.
3665 if (parser.parseOptionalArrowTypeList(result.types))
3666 return failure();
3667 // Parse the 'then' region.
3668 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3669 return failure();
3670
3671 // If we find an 'else' keyword then parse the 'else' region.
3672 if (!parser.parseOptionalKeyword("else")) {
3673 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3674 return failure();
3675 }
3676
3677 // Parse the optional attribute list.
3678 if (parser.parseOptionalAttrDict(result.attributes))
3679 return failure();
3680 return success();
3681}
3682
3683void IfOp::print(OpAsmPrinter &p) {
3684 bool printBlockTerminators = false;
3685
3686 p << " " << getCondition();
3687 if (!getResults().empty()) {
3688 p << " -> (" << getResultTypes() << ")";
3689 // Print yield explicitly if the op defines values.
3690 printBlockTerminators = true;
3691 }
3692 p << ' ';
3693 p.printRegion(getThenGraph(),
3694 /*printEntryBlockArgs=*/false,
3695 /*printBlockTerminators=*/printBlockTerminators);
3696
3697  // Print the 'else' region if it exists and has a block.
3698 auto &elseRegion = getElseGraph();
3699 if (!elseRegion.empty()) {
3700 p << " else ";
3701 p.printRegion(elseRegion,
3702 /*printEntryBlockArgs=*/false,
3703 /*printBlockTerminators=*/printBlockTerminators);
3704 }
3705
3706 p.printOptionalAttrDict((*this)->getAttrs());
3707}
3708
3709LogicalResult IfOp::verify() {
3710 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3711 "'then_graph' arguments", getInputList(),
3712 "'input_list'")
3713 .failed())
3714 return failure();
3715
3716 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3717 "'else_graph' arguments", getInputList(),
3718 "'input_list'")
3719 .failed())
3720 return failure();
3721
3722 auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3723 if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3724 "'then_graph' results", getOutputList(),
3725 "'output_list'")
3726 .failed())
3727 return failure();
3728
3729 auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3730 if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3731 "'else_graph' results", getOutputList(),
3732 "'output_list'")
3733 .failed())
3734 return failure();
3735
3736 auto condType = getCondition().getType();
3737 if (errorIfShapeNotSizeOne(*this, condType).failed())
3738 return emitOpError() << "'condition' must be a size 1 tensor, got "
3739 << condType;
3740
3741 return success();
3742}
3743
3744LogicalResult WhileOp::verify() {
3745 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3746 getOutputList(), "'output_list'")
3747 .failed())
3748 return failure();
3749
3750 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3751 "'cond_graph' arguments", getInputList(),
3752 "'input_list'")
3753 .failed())
3754 return failure();
3755
3756 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3757 "'body_graph' arguments", getInputList(),
3758 "'input_list'")
3759 .failed())
3760 return failure();
3761
3762 auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3763 if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3764 "'body_graph' results", getInputList(),
3765 "'input_list'")
3766 .failed())
3767 return failure();
3768
3769 // Condition block output must be a single element tensor with a single bool
3770 // value.
3771 auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3772 if (condYield.getInputs().size() != 1)
3773 return emitOpError() << "require 'cond_graph' only have one result";
3774
3775 auto condOutType = condYield.getInputs()[0].getType();
3776 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3777 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3778 << condOutType;
3779
3780 if (!getElementTypeOrSelf(condOutType).isInteger(1))
3781 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3782 << condOutType;
3783
3784 return success();
3785}
3786
3787LogicalResult ReverseOp::verify() {
3788 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3789 /* outType = */ getOutput().getType())
3790 .failed())
3791 return failure();
3792 TensorType inputType = getInput1().getType();
3793 TensorType outputType = getOutput().getType();
3794 int32_t reverseAxis = getAxis();
3795
3796 if (reverseAxis < 0)
3797 return emitOpError("expected non-negative reverse axis");
3798 if (inputType.hasRank()) {
3799 int64_t inputRank = inputType.getRank();
3800 // We allow for a special case where the input/output shape has rank 0 and
3801 // axis is also 0.
3802 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3803 return emitOpError("expect input tensor rank (")
3804 << inputRank << ") to be larger than reverse axis (" << reverseAxis
3805 << ")";
3806 }
3807 if (outputType.hasRank()) {
3808 int64_t outputRank = outputType.getRank();
3809 if (inputType.hasRank() && outputRank != inputType.getRank())
3810 return emitOpError(
3811 "expect output tensor rank to be equal to input tensor rank");
3812 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3813 return emitOpError("expect output tensor rank (")
3814 << outputRank << ") to be larger than reverse axis ("
3815 << reverseAxis << ")";
3816 }
3817 return success();
3818}
3819
3820LogicalResult tosa::SelectOp::verify() {
3821 // verify input2 and input3 have same element type as output
3822 if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
3823 /* outType = */ getOutput().getType())
3824 .failed() ||
3825 verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
3826 /* outType = */ getOutput().getType())
3827 .failed()) {
3828 return failure();
3829 }
3830 // verify input1 has element type of bool
3831 auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3832 if (!predicateType) {
3833 return emitOpError("expect shaped tensor for input1, got ")
3834 << getInput1().getType();
3835 }
3836 auto predicateElementType = predicateType.getElementType();
3837 if (!predicateElementType.isInteger(1)) {
3838 return emitOpError("expect element type of bool for input1, got ")
3839 << predicateElementType;
3840 }
3841
3842 return success();
3843}
3844
3845LogicalResult tosa::VariableOp::verify() {
3846 StringRef symName = getName();
3847 FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3848 if (succeeded(varOp))
3849 return emitOpError("illegal to have multiple declaration of '")
3850 << symName << "'";
3851
3852 return success();
3853}
3854
3855LogicalResult tosa::VariableReadOp::verify() {
3856 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3857 .failed())
3858 return failure();
3859
3860 return success();
3861}
3862
3863LogicalResult tosa::VariableWriteOp::verify() {
3864 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3865 .failed())
3866 return failure();
3867
3868 return success();
3869}
3870
3871// The parse and print methods of WhileOp follow the implementation of the SCF dialect.
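// With this custom form the op prints roughly as (op mnemonic assumed):
//   tosa.while_loop (%arg0 = %init, ...) : (operand-types) -> (result-types) {
//     ...  // cond_graph
//   } do {
//     ...  // body_graph
//   }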
3872ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3873 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3874 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3875 Region *cond = result.addRegion();
3876 Region *body = result.addRegion();
3877
3878 OptionalParseResult listResult =
3879 parser.parseOptionalAssignmentList(regionArgs, operands);
3880 if (listResult.has_value() && failed(listResult.value()))
3881 return failure();
3882
3883 FunctionType functionType;
3884 SMLoc typeLoc = parser.getCurrentLocation();
3885 if (failed(parser.parseColonType(functionType)))
3886 return failure();
3887
3888 result.addTypes(functionType.getResults());
3889
3890 if (functionType.getNumInputs() != operands.size()) {
3891 return parser.emitError(typeLoc)
3892 << "expected as many input types as operands "
3893 << "(expected " << operands.size() << " got "
3894 << functionType.getNumInputs() << ")";
3895 }
3896
3897 // Resolve input operands.
3898 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3899 parser.getCurrentLocation(),
3900 result.operands)))
3901 return failure();
3902
3903 // Propagate the types into the region arguments.
3904 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3905 regionArgs[i].type = functionType.getInput(i);
3906
3907 return failure(parser.parseRegion(*cond, regionArgs) ||
3908 parser.parseKeyword("do") || parser.parseRegion(*body) ||
3909 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3910}
3911
3912static void printInitializationList(OpAsmPrinter &parser,
3913 Block::BlockArgListType blocksArgs,
3914 ValueRange initializers,
3915 StringRef prefix = "") {
3916 assert(blocksArgs.size() == initializers.size() &&
3917 "expected same length of arguments and initializers");
3918 if (initializers.empty())
3919 return;
3920
3921 parser << prefix << '(';
3922 llvm::interleaveComma(
3923      llvm::zip(blocksArgs, initializers), parser,
3924      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3925 parser << ")";
3926}
3927
3928void WhileOp::print(OpAsmPrinter &parser) {
3929 printInitializationList(parser, getCondGraph().front().getArguments(),
3930 getInputList(), " ");
3931 parser << " : ";
3932 parser.printFunctionalType(getInputList().getTypes(),
3933 getResults().getTypes());
3934 parser << ' ';
3935 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3936 parser << " do ";
3937 parser.printRegion(getBodyGraph());
3938 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3939}
3940
3941// Create a rank-1 const tensor holding the zero point of the source tensor.
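// Float and integer element types are materialized as splat tosa.const values;
// any other element type is rejected and std::nullopt is returned.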
3942std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3943 Location loc,
3944 Type srcElemType,
3945 int64_t zp) {
3946  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3947  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3948  if (llvm::isa<FloatType>(srcElemType)) {
3949    auto zpAttr = DenseElementsAttr::get(
3950        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3951    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3952  }
3953  if (llvm::isa<IntegerType>(srcElemType)) {
3954 auto zpAttr =
3955 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3956 return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3957 }
3958 llvm::errs() << "zero point is not allowed for unsupported data types\n";
3959 return std::nullopt;
3960}
3961
3962//===----------------------------------------------------------------------===//
3963// TOSA Shape and Shape Operators Helper functions.
3964//===----------------------------------------------------------------------===//
3965
3966bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3967 return mlir::isa<tosa::shapeType>(t);
3968}
3969
3970LogicalResult
3971mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3972 int rank) {
3973 if (rank < 0)
3974 return emitError() << "invalid rank (must be >= 0): " << rank;
3975 return success();
3976}
3977
3978LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
3979 for (auto v : op->getOperands()) {
3980 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3981 Operation *definingOp = v.getDefiningOp();
3982 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3983        return op->emitOpError("shape operand is not compile time resolvable");
3984 }
3985 }
3986 }
3987 return success();
3988}
3989
3990LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
3991 for (auto type : op->getOperandTypes()) {
3992 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3993      return op->emitOpError("must have operands with tosa shape type");
3994 }
3995 }
3996 for (auto type : op->getResultTypes()) {
3997 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3998      return op->emitOpError("must have result with tosa shape type");
3999 }
4000 }
4001 return success();
4002}
4003
4004LogicalResult
4005OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4006  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4007      failed(verifyTosaShapeOperator(op)))
4008 return failure();
4009
4010  // Delegate that returns the rank of a tosa shape type.
4011 auto getRank = [](const Type type) {
4012 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4013 };
4014 auto operandTypes = op->getOperandTypes();
4015 auto resultTypes = op->getResultTypes();
4016
4017 auto rank = getRank(*op->getOperandTypes().begin());
4018 for (auto type : operandTypes) {
4019 if (getRank(type) != rank) {
4020      return op->emitOpError("operands don't have matching ranks");
4021 }
4022 }
4023 for (auto type : resultTypes) {
4024 if (getRank(type) != rank) {
4025      return op->emitOpError("result shape has different rank than operands");
4026 }
4027 }
4028 return success();
4029}
4030
4031//===----------------------------------------------------------------------===//
4032// TOSA Shape Operators verify functions.
4033//===----------------------------------------------------------------------===//
4034
4035LogicalResult tosa::ConstShapeOp::verify() {
4036  // Check that the values attribute is one-dimensional (rank 1).
4037 auto valuesRank = getValues().getType().getRank();
4038 if (valuesRank != 1)
4039 return emitOpError("expect elements in attribute values with rank 1");
4040  // Check that the number of elements in the values attribute equals the rank of the result shape.
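  // For example (hypothetical): a result of type !tosa.shape<3> requires the
  // values attribute to hold exactly 3 elements.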
4041 auto count = getValues().getNumElements();
4042 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4043 if (!(count == rank || (count == 1 && rank == 0))) {
4044 return emitOpError("expect number of elements in attribute values (")
4045 << count << ") to be equal to the rank (" << rank
4046 << ") for the result shape type";
4047 }
4048 return success();
4049}
4050
4051//===----------------------------------------------------------------------===//
4052// TOSA Attribute Definitions.
4053//===----------------------------------------------------------------------===//
4054
4055#define GET_ATTRDEF_CLASSES
4056#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4057
4058//===----------------------------------------------------------------------===//
4059// TOSA Type Definitions.
4060//===----------------------------------------------------------------------===//
4061#define GET_TYPEDEF_CLASSES
4062#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4063
4064//===----------------------------------------------------------------------===//
4065// TOSA Operator Definitions.
4066//===----------------------------------------------------------------------===//
4067
4068#define GET_OP_CLASSES
4069#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
4070
