//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr unlike standard dialect
  // constant which supports all attributes.
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                        Attribute &attr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(attr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
      typeAttr = TypeAttr::get(typedAttr.getType());
    }
    return success();
  }

  Type type;
  if (failed(parser.parseColonType(type))) {
    return parser.emitError(parser.getCurrentLocation()) << "expected type";
  }
  typeAttr = TypeAttr::get(type);

  return success();
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                                 Attribute attr) {
  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
    p << ": ";
    p.printAttribute(type);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (attr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(attr);
  }
}
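
// Illustrative example (not normative): for an op using this custom
// directive, the printed form is either `: tensor<4xi32>` when only a type is
// present, or `= dense<0> : tensor<4xi32>` when a typed value attribute is
// given, in which case the type is recovered from the attribute on parsing.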

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

static bool hasZeroDimension(ShapedType shapedType) {
  if (!shapedType.hasRank())
    return false;

  auto rank = shapedType.getRank();

  for (int i = 0; i < rank; i++) {
    if (shapedType.isDynamicDim(i))
      continue;
    if (shapedType.getDimSize(i) == 0)
      return true;
  }

  return false;
}
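
// For example (illustrative only): tensor<2x0x3xf32> has a zero dimension,
// while tensor<?x3xf32> does not, since dynamic dimensions are skipped.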

template <typename T> static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Must be ranked tensor types
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
    return failure();
  }

  if (hasZeroDimension(inputType))
    return op.emitOpError() << "tensor has a dimension with size zero. Each "
                               "dimension of a tensor must have size >= 1";

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must carry the quantization_info attribute, and an
  // unquantized (float) type must not have one.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }

  return success();
}

LogicalResult tosa::ArgMaxOp::verify() {
  // Ensure output is of 32-bit integer
  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  // Ensure axis is within the tensor rank
  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (hasZeroDimension(inputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  auto inputETy = inputType.getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if ((inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16)))
    return success();

  return emitOpError("input/output element types are incompatible.");
}

LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  // If the input datatype is float, check that the two min/max_fp attributes
  // share the same type and that their type is either the same as the input's
  // datatype or a float type whose bitwidth is greater than the input's.
  if (!inputETy.isInteger(dataTypeBitWidth)) {
    if (((maxFpType != minFpType) ||
         (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
                                       inputETy.getIntOrFloatBitWidth())))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void buildTransConvOpWithQuantInfo(
    OpBuilder &builder, OperationState &result, Type outputType, Value input,
    Value weight, Value bias, DenseI64ArrayAttr outpad,
    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also used where a fully_connected op would be
/// constructed but the weight is not a constant; in that case the
/// fully_connected op must be expressed as a matmul.
/// TODO: Add a link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
        inputType.getElementType());
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator that needs to create its
/// own OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. The absence of pad_const is interpreted as zero padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_attr.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for invalid
      // operands and for ranked tensors.
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}
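
// Illustrative example (shapes chosen arbitrarily): operand shapes [2, 1, 3]
// and [4, 1] are right-aligned and resolve to [2, 4, 3]; shapes [2, 3] and
// [4, 3] fail to broadcast because 2 and 4 differ and neither is 1.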

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());

  if (!inputShape.hasRank())
    return failure();

  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

  return success();
}
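
// Illustrative example: an input of shape [8, 16, 32] infers two outputs
// (real and imaginary parts) of shape [8, 16, 17], since 32 / 2 + 1 == 17.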

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }
  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
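
// Illustrative example: concatenating shapes [2, 3] and [2, 4] on axis 1
// infers [2, 7]; if any operand's axis dimension is dynamic, e.g. [2, 3] and
// [2, ?], the axis dimension of the result is inferred as dynamic: [2, ?].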

LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto elementType = IntegerType::get(context, /*width=*/1);

  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
    return success();
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
  return success();
}

bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1)
    return false;
  return succeeded(verifyCompatibleShape(l[0], r[0]));
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FullyConnectedOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  ShapeAdaptor biasShape(adaptor.getBias().getType());

  // Initialize all output dimensions to dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamic);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }

  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }

  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
                                                      : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape(adaptor.getA().getType());
  ShapeAdaptor rhsShape(adaptor.getB().getType());

  // Initialize all output dimensions to dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
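
// Illustrative example: operands of shape [2, 3, 4] and [2, 4, 5] infer an
// output shape of [2, 3, 5]; if the first operand's batch dim is dynamic, it
// is taken from the second operand when available.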

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank using the
  // padding shape's first dim.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
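
// Illustrative example: an input of shape [2, 3] with a constant padding of
// [[1, 1], [2, 0]] (flattened to [1, 1, 2, 0]) infers an output shape of
// [4, 5]: 2 + 1 + 1 = 4 and 3 + 2 + 0 = 5.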

static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));
}

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
  return success();
}

LogicalResult tosa::SliceOp::verify() {
  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (!inputType)
    return success();

  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
    return emitOpError(
        "length of start attribute is not equal rank of input shape");

  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
    return emitOpError(
        "length of size attribute is not equal rank of input shape");

  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayRef<int64_t> multiples = adaptor.getMultiples();
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
    return failure();

  // Any non-dynamic dimension can be multiplied to a known size.
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int64_t dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamic)
      dim *= multiples[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
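
// Illustrative example: an input of shape [2, ?, 3] tiled with multiples
// [3, 2, 1] infers an output shape of [6, ?, 3]; dynamic dims stay dynamic.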

LogicalResult tosa::TileOp::verify() {
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());
  auto multiples = getMultiples();

  if (inputType.hasRank()) {
    if (static_cast<size_t>(inputType.getRank()) != multiples.size())
      return emitOpError("expect 'multiples' array to have length ")
             << inputType.getRank() << " but got " << multiples.size() << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() &&
             static_cast<size_t>(outputType.getRank()) != multiples.size())
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiples.size() << ".";

  return success();
}

bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1)
    return false;
  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
  llvm::SmallVector<int64_t> newShapeValue =
      convertToMlirShape(adaptor.getNewShape());

  // We cannot infer from the total number of elements so we must take the
  // shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(
        ShapedTypeComponents(newShapeValue, inputType));
    return success();
  }

  // Determine the number of elements covered by the slice of all static
  // dimensions. This allows us to infer the length of the remaining dynamic
  // dimension.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(newShapeValue, inputType));
  return success();
}
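
// Illustrative example: reshaping a tensor<2x3x4xf32> (24 elements) with
// new_shape = [-1, 6] infers the dynamic dimension as 24 / 6 = 4, giving an
// output shape of [4, 6].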

mlir::LogicalResult tosa::ReshapeOp::verify() {
  TensorType inputType = getInput1().getType();
  RankedTensorType outputType = getType();

  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  if ((int64_t)getNewShape().size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(getNewShape(), outputType.getShape()))
    if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
        newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    int64_t outputElementsNum = outputType.getNumElements();
    if (inputElementsNum != outputElementsNum) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << outputElementsNum;
    }
  }

  int missingDims = llvm::count(getNewShape(), -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  return mlir::success();
}

LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
  // Perms must be constants.
  DenseIntElementsAttr permsAttr;
  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
    return failure();

  perms = llvm::to_vector(
      llvm::map_range(permsAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

  return success();
}

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor permsShape(adaptor.getPerms().getType());

  // We cannot infer anything from a rank-0 "permutation" tensor.
  if (permsShape.hasRank() && permsShape.getRank() == 0)
    return failure();

  // If input rank and permutation length are unknown, the output rank is
  // unknown.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // This would imply the number of permutations does not match the rank of the
  // input which is illegal.
  if (permsShape.getDimSize(0) != inputShape.getRank()) {
    return failure();
  }

  SmallVector<int64_t> outputShape;
  // Rank-0 means no permutations matter.
  if (inputShape.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
  // If the permutations are a constant we can directly determine the output
  // shape.
  DenseIntElementsAttr attr;
  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
      attr.getType().getRank() == 1) {
    ShapeAdaptor permShape = attr;
    // Constant permutation must be the same length as the input rank.
    if (inputShape.getRank() != permShape.getRank())
      return emitOptionalError(location,
                               "constant permutation must be the same length"
                               " as the input rank");

    // Constant permutation values must be within the input rank.
    for (int i = 0, e = inputShape.getRank(); i < e; i++) {
      if (inputShape.getRank() <= permShape.getDimSize(i))
        return failure();
    }

    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TransposeOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType permType = getPerms().getType();
  TensorType outputType = getOutput().getType();

  if (permType.hasRank() && permType.getRank() != 1)
    return emitOpError()
           << "expected permutation tensor to be rank 1 but got rank "
           << permType.getRank();
  if (inputType.hasRank() && permType.hasRank())
    if (!permType.isDynamicDim(0) &&
        permType.getDimSize(0) != inputType.getRank())
      return emitOpError() << "expected permutation tensor dim 0 to have size "
                           << inputType.getRank()
                           << " (input rank) but got size "
                           << permType.getDimSize(0);
  if (inputType.hasRank() && outputType.hasRank() &&
      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";
  if (outputType.hasRank() && permType.hasRank())
    if (!permType.isDynamicDim(0) &&
        permType.getDimSize(0) != outputType.getRank())
      return emitOpError() << "expected permutation tensor dim 0 to have size "
                           << outputType.getRank()
                           << " (output rank) but got size "
                           << permType.getDimSize(0);

  SmallVector<int64_t> constantPerms;
  if (succeeded(getConstantPerms(constantPerms))) {
    // Assert that the permutation tensor has a rank, which means that the rank
    // has been verified above.
    assert(permType.hasRank() &&
           "Unexpectedly found permutation tensor without rank");
    if (!isPermutationVector(constantPerms))
      return emitOpError() << "expected valid permutation tensor";
  }
  return success();
}

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();

  // Compute the output shape based on attributes: scale, offset, and border.
  outputShape[1] =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;

  outputShape[2] =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
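
// Illustrative example: an input of shape [1, 8, 8, 3] with scale
// [2, 1, 2, 1], offset [0, 0], and border [0, 0] infers
// ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15 per spatial dim, i.e. [1, 15, 15, 3].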

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}

#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES

template <typename T>
static LogicalResult verifyReduceOp(T op) {
  // All TOSA reduce Ops have input, output and axis.
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // We can only verify the reduced dimension size to be 1 if this is not the
    // special case of output rank == 0.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }
  return success();
}

LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }

static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::CosOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  // Without input shape information all we can infer is a rank-4 output with
  // dynamic dimensions.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for pooling layers.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
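
// Illustrative example: an NHWC input of shape [1, 32, 32, 8] with kernel
// [2, 2], stride [2, 2], and pad [0, 0, 0, 0] infers
// (32 + 0 + 0 - 2) / 2 + 1 = 16 per spatial dim, i.e. [1, 16, 16, 8].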

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
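
// Illustrative example: an NHWC input of shape [1, 16, 16, 3] with an OHWI
// weight of shape [8, 3, 3, 3], pad [0, 0, 0, 0], stride [1, 1], and dilation
// [1, 1] infers an output shape of [1, 14, 14, 8].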
1482
1483LogicalResult Conv3DOp::inferReturnTypeComponents(
1484 MLIRContext *context, ::std::optional<Location> location,
1485 Conv3DOp::Adaptor adaptor,
1486 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1487 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1488
1489 int64_t inputWidth = ShapedType::kDynamic;
1490 int64_t inputHeight = ShapedType::kDynamic;
1491 int64_t inputDepth = ShapedType::kDynamic;
1492
1493 int64_t weightWidth = ShapedType::kDynamic;
1494 int64_t weightHeight = ShapedType::kDynamic;
1495 int64_t weightDepth = ShapedType::kDynamic;
1496
  // Input shape describes the batch and the input depth/height/width.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/height/width and the output
  // channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();

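  // The depth/height/width outputs use the same dilated-convolution
  // output-size formula as Conv2D, applied once per spatial dimension.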
  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int64_t inputSize = inputDepth + pad[0] + pad[1];
    int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + pad[2] + pad[3];
    int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + pad[4] + pad[5];
    int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes the batch, input height/width, and input channels.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter height/width, the input channels, and
  // the channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

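  // Height and width follow the same dilated-convolution output-size formula
  // as Conv2D; only the channel handling above differs.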
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // outputShape is mutable.
  llvm::SmallVector<int64_t> outputShape =
      convertToMlirShape(adaptor.getOutShape());

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter height/width and the output channels.
  ShapeAdaptor weightShape(adaptor.getFilter().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

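  // For transpose convolution the spatial output grows with the stride:
  //   out = (in - 1) * stride + out_pad_before + out_pad_after + kernel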
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

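  // Refine the result knowledge by meeting it with the operand types of every
  // yield; if a meet fails, the existing knowledge for that result is kept.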
  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBody())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If none is found,
  // this tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

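// Report the vector result shape so that vector unrolling patterns can split
// a vectorized apply_scale into smaller operations; non-vector results return
// std::nullopt and are left untouched.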
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

// The parse and print methods of IfOp follow the implementation of the SCF
// dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then' and 'else'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  // Create an i1 tensor type for the boolean condition.
  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, i1Type, result.operands))
    return failure();
  // Parse the optional result type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // If we find an 'else' keyword, parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}


void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCond();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenBranch(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseBranch();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
      return emitOpError("expect input tensor rank (")
             << inputRank << ") to be larger than reverse axis (" << reverseAxis
             << ")";
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank())
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
      return emitOpError("expect output tensor rank (")
             << outputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }
  return success();
}

// The parse and print methods of WhileOp follow the implementation of the SCF
// dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}

static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
                          " ");
  parser << " : ";
  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
  parser << ' ';
  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
  parser << " do ";
  parser.printRegion(getBody());
  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"

