1 | //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // \file |
10 | // This file implements the TOSA Specification: |
11 | // https://developer.mlplatform.org/w/tosa/ |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
16 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
17 | #include "mlir/Dialect/Quant/QuantOps.h" |
18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
19 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
20 | #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" |
21 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/DialectImplementation.h" |
24 | #include "mlir/IR/Matchers.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/IR/TypeUtilities.h" |
27 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
28 | #include "mlir/Transforms/InliningUtils.h" |
29 | #include "llvm/ADT/APFloat.h" |
30 | #include "llvm/ADT/DenseMap.h" |
31 | #include "llvm/ADT/TypeSwitch.h" |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::tosa; |
35 | |
36 | #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc" |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Tosa dialect interface includes. |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" |
43 | |
44 | namespace { |
45 | #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc" |
46 | |
47 | //===----------------------------------------------------------------------===// |
48 | // Dialect Function Inliner Interface. |
49 | //===----------------------------------------------------------------------===// |
50 | struct TosaInlinerInterface : public DialectInlinerInterface { |
51 | using DialectInlinerInterface::DialectInlinerInterface; |
52 | |
53 | //===--------------------------------------------------------------------===// |
54 | // Analysis Hooks. |
55 | //===--------------------------------------------------------------------===// |
56 | |
57 | /// All operations can be inlined by default. |
58 | bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, |
59 | IRMapping &map) const final { |
60 | return true; |
61 | } |
62 | |
63 | /// All regions with If and While parent operators can be inlined. |
64 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
65 | IRMapping &map) const final { |
66 | return (isa<tosa::IfOp>(dest->getParentOp()) || |
67 | isa<tosa::WhileOp>(dest->getParentOp())); |
68 | } |
69 | }; |
70 | |
71 | /// This class implements the bytecode interface for the Tosa dialect. |
72 | struct TosaDialectBytecodeInterface : public BytecodeDialectInterface { |
73 | TosaDialectBytecodeInterface(Dialect *dialect) |
74 | : BytecodeDialectInterface(dialect) {} |
75 | |
76 | //===--------------------------------------------------------------------===// |
77 | // Attributes |
78 | |
79 | Attribute readAttribute(DialectBytecodeReader &reader) const override { |
80 | return ::readAttribute(getContext(), reader); |
81 | } |
82 | |
83 | LogicalResult writeAttribute(Attribute attr, |
84 | DialectBytecodeWriter &writer) const override { |
85 | return ::writeAttribute(attr, writer); |
86 | } |
87 | |
88 | //===--------------------------------------------------------------------===// |
89 | // Types |
90 | |
91 | Type readType(DialectBytecodeReader &reader) const override { |
92 | return ::readType(getContext(), reader); |
93 | } |
94 | |
95 | LogicalResult writeType(Type type, |
96 | DialectBytecodeWriter &writer) const override { |
97 | return ::writeType(type, writer); |
98 | } |
99 | |
100 | void writeVersion(DialectBytecodeWriter &writer) const final { |
101 | // TODO: Populate. |
102 | } |
103 | |
104 | std::unique_ptr<DialectVersion> |
105 | readVersion(DialectBytecodeReader &reader) const final { |
106 | // TODO: Populate |
107 | reader.emitError("Dialect does not support versioning"); |
108 | return nullptr; |
109 | } |
110 | |
111 | LogicalResult upgradeFromVersion(Operation *topLevelOp, |
112 | const DialectVersion &version) const final { |
113 | return success(); |
114 | } |
115 | }; |
116 | |
117 | } // namespace |
118 | |
119 | //===----------------------------------------------------------------------===// |
120 | // TOSA control flow support. |
121 | //===----------------------------------------------------------------------===// |
122 | |
123 | /// Returns the while loop body. |
124 | SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; } |
125 | |
126 | //===----------------------------------------------------------------------===// |
127 | // Tosa dialect initialization. |
128 | //===----------------------------------------------------------------------===// |
129 | |
130 | void TosaDialect::initialize() { |
131 | addOperations< |
132 | #define GET_OP_LIST |
133 | #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" |
134 | >(); |
135 | addAttributes< |
136 | #define GET_ATTRDEF_LIST |
137 | #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" |
138 | >(); |
139 | addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>(); |
140 | declarePromisedInterfaces< |
141 | mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, |
142 | ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, |
143 | LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, |
144 | LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, |
145 | BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp, |
146 | NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp, |
147 | GreaterEqualOp, MatMulOp>(); |
148 | } |
149 | |
150 | Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, |
151 | Type type, Location loc) { |
152 | // Tosa dialect constants only support ElementsAttr unlike standard dialect |
153 | // constant which supports all attributes. |
154 | if (llvm::isa<ElementsAttr>(value)) |
155 | return builder.create<tosa::ConstOp>(loc, type, |
156 | llvm::cast<ElementsAttr>(value)); |
157 | return nullptr; |
158 | } |
159 | |
160 | //===----------------------------------------------------------------------===// |
161 | // Parsers and printers |
162 | //===----------------------------------------------------------------------===// |
163 | |
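// parseTypeOrAttr accepts either "= <attribute>" (e.g. "= dense<0> : tensor<4xi32>",
// where a typed attribute also supplies the type) or a bare ": <type>".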
164 | ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, |
165 | Attribute &attr) { |
166 | if (succeeded(parser.parseOptionalEqual())) { |
167 | if (failed(parser.parseAttribute(attr))) { |
168 | return parser.emitError(parser.getCurrentLocation()) |
169 | << "expected attribute"; |
170 | } |
171 | if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { |
172 | typeAttr = TypeAttr::get(typedAttr.getType()); |
173 | } |
174 | return success(); |
175 | } |
176 | |
177 | Type type; |
178 | if (failed(parser.parseColonType(type))) { |
179 | return parser.emitError(parser.getCurrentLocation()) << "expected type"; |
180 | } |
181 | typeAttr = TypeAttr::get(type); |
182 | |
183 | return success(); |
184 | } |
185 | |
186 | void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, |
187 | Attribute attr) { |
188 | bool needsSpace = false; |
189 | auto typedAttr = dyn_cast_or_null<TypedAttr>(attr); |
190 | if (!typedAttr || typedAttr.getType() != type.getValue()) { |
191 | p << ": " ; |
192 | p.printAttribute(attr: type); |
193 | needsSpace = true; // subsequent attr value needs a space separator |
194 | } |
195 | if (attr) { |
196 | if (needsSpace) |
197 | p << ' '; |
198 | p << "= "; |
199 | p.printAttribute(attr); |
200 | } |
201 | } |
202 | |
203 | //===----------------------------------------------------------------------===// |
204 | // TOSA Operator Verifiers. |
205 | //===----------------------------------------------------------------------===// |
206 | |
207 | static bool hasZeroDimension(ShapedType shapedType) { |
208 | if (!shapedType.hasRank()) |
209 | return false; |
210 | |
211 | auto rank = shapedType.getRank(); |
212 | |
213 | for (int i = 0; i < rank; i++) { |
214 | if (shapedType.isDynamicDim(i)) |
215 | continue; |
216 | if (shapedType.getDimSize(i) == 0) |
217 | return true; |
218 | } |
219 | |
220 | return false; |
221 | } |
222 | |
223 | template <typename T> static LogicalResult verifyConvOp(T op) { |
224 | // All TOSA conv ops have an input() and weight(). |
225 | auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); |
226 | auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); |
227 | |
228 | // Must be ranked tensor types |
229 | if (!inputType) { |
230 | op.emitOpError("expect a ranked tensor for input, got ") << op.getInput(); |
231 | return failure(); |
232 | } |
233 | if (!weightType) { |
234 | op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight(); |
235 | return failure(); |
236 | } |
237 | |
238 | if (hasZeroDimension(inputType)) |
239 | return op.emitOpError() << "tensor has a dimension with size zero. Each " |
240 | "dimension of a tensor must have size >= 1"; |
241 | |
242 | auto inputEType = inputType.getElementType(); |
243 | auto weightEType = weightType.getElementType(); |
244 | |
245 | bool inputIsQuant = !llvm::isa<FloatType>(inputEType); |
246 | bool weightIsQuant = !llvm::isa<FloatType>(weightEType); |
247 | |
248 | // Either both must be quantized or both unquantized. |
249 | if (inputIsQuant != weightIsQuant) { |
250 | op.emitOpError( |
251 | "expect both input and weight to be float or not together, got ") |
252 | << inputEType << " and " << weightEType; |
253 | return failure(); |
254 | } |
255 | |
256 | // Quantized type must have constructed the quantizationattr, and unquantized |
257 | // types should not have a quantizationattr. |
258 | if ((inputIsQuant && !op.getQuantizationInfo()) || |
259 | (!inputIsQuant && op.getQuantizationInfo())) { |
260 | op.emitOpError("quantizationattr is required for quantized type, and not " |
261 | "allowed for float type"); |
262 | return failure(); |
263 | } |
264 | |
265 | return success(); |
266 | } |
267 | |
268 | LogicalResult tosa::ArgMaxOp::verify() { |
269 | // Ensure output is of 32-bit integer |
270 | const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); |
271 | if (!resultETy.isIntOrIndex()) |
272 | return emitOpError("result tensor is not of integer type"); |
273 | |
274 | // Ensure axis is within the tensor rank |
275 | const auto inputType = llvm::cast<ShapedType>(getInput().getType()); |
276 | const int64_t axis = getAxisAttr().getInt(); |
277 | if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank())) |
278 | return emitOpError("specified axis is outside the rank of the tensor"); |
279 | |
280 | return success(); |
281 | } |
282 | |
283 | LogicalResult tosa::AvgPool2dOp::verify() { |
284 | auto inputType = llvm::cast<ShapedType>(getInput().getType()); |
285 | if (hasZeroDimension(inputType)) |
286 | return emitOpError() << "tensor has a dimension with size zero. Each " |
287 | "dimension of a tensor must have size >= 1"; |
288 | |
289 | auto inputETy = inputType.getElementType(); |
290 | auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); |
291 | |
292 | if (auto quantType = |
293 | llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) |
294 | inputETy = quantType.getStorageType(); |
295 | |
296 | if (auto quantType = |
297 | llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy)) |
298 | resultETy = quantType.getStorageType(); |
299 | |
300 | auto accType = getAccType(); |
301 | if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32)) |
302 | return emitOpError("accumulator type for integer tensor is not i32"); |
303 | |
304 | if (inputETy.isF16() && !(accType.isF16() || accType.isF32())) |
305 | return emitOpError("accumulator type for f16 tensor is not f16/f32"); |
306 | |
307 | if (inputETy.isBF16() && !accType.isF32()) |
308 | return emitOpError("accumulator type for bf16 tensor is not f32"); |
309 | |
310 | if (inputETy.isF32() && !accType.isF32()) |
311 | return emitOpError("accumulator type for f32 tensor is not f32"); |
312 | |
313 | if ((inputETy.isF32() && resultETy.isF32()) || |
314 | (inputETy.isF16() && resultETy.isF16()) || |
315 | (inputETy.isBF16() && resultETy.isBF16()) || |
316 | (inputETy.isInteger(8) && resultETy.isInteger(8)) || |
317 | (inputETy.isInteger(16) && resultETy.isInteger(16))) |
318 | return success(); |
319 | |
320 | return emitOpError("input/output element types are incompatible."); |
321 | } |
322 | |
323 | LogicalResult tosa::ClampOp::verify() { |
324 | mlir::Type inputETy = |
325 | llvm::cast<ShapedType>(getInput().getType()).getElementType(); |
326 | if (auto quantType = |
327 | llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { |
328 | inputETy = quantType.getStorageType(); |
329 | } |
330 | mlir::Type maxFpType = getMaxFpAttr().getType(); |
331 | mlir::Type minFpType = getMinFpAttr().getType(); |
332 | mlir::Type outputETy = |
333 | llvm::cast<ShapedType>(getOutput().getType()).getElementType(); |
334 | if (auto quantType = |
335 | llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { |
336 | outputETy = quantType.getStorageType(); |
337 | } |
338 | unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth(); |
339 | |
340 | if (inputETy != outputETy) |
341 | return emitOpError("input/output element types are incompatible."); |
342 | |
343 | // If the input datatype is float, check that the two min/max_fp attributes |
344 | // share the same type and that their type is either the same as the input's |
345 | // datatype, or a float type whose bitwidth > the input datatype's bitwidth. |
346 | if (!inputETy.isInteger(dataTypeBitWidth)) { |
347 | if (((maxFpType != minFpType) || |
348 | (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <= |
349 | inputETy.getIntOrFloatBitWidth()))) |
350 | return emitOpError("min/max attributes types are incompatible with " |
351 | "input/output element types."); |
352 | } |
353 | |
354 | return success(); |
355 | } |
356 | |
357 | //===----------------------------------------------------------------------===// |
358 | // TOSA Operator Quantization Builders. |
359 | //===----------------------------------------------------------------------===// |
360 | |
361 | /// This builder is called on all convolution operators except TransposeConv, |
362 | /// which has specialized output shape semantics. The builder also defines the |
363 | /// bitwidth of the output given the bit width of the input & weight content. |
364 | static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, |
365 | Type outputType, Value input, Value weight, |
366 | Value bias, DenseI64ArrayAttr pad, |
367 | DenseI64ArrayAttr stride, |
368 | DenseI64ArrayAttr dilation) { |
369 | |
370 | result.addOperands({input, weight, bias}); |
371 | result.addAttribute("pad", pad); |
372 | result.addAttribute("stride", stride); |
373 | result.addAttribute("dilation", dilation); |
374 | |
375 | auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); |
376 | if (quantAttr) { |
377 | result.addAttribute("quantization_info", quantAttr); |
378 | result.addTypes( |
379 | buildConvOpResultTypeInfo(builder, outputType, input, weight)); |
380 | } else { |
381 | result.addTypes(outputType); |
382 | } |
383 | } |
384 | |
385 | /// Handles tosa.transpose_conv2d which has outpad and output shape attributes. |
386 | static void buildTransConvOpWithQuantInfo( |
387 | OpBuilder &builder, OperationState &result, Type outputType, Value input, |
388 | Value weight, Value bias, DenseI64ArrayAttr outpad, |
389 | DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) { |
390 | result.addOperands({input, weight, bias}); |
391 | result.addAttribute("out_pad", outpad); |
392 | result.addAttribute("stride", stride); |
393 | result.addAttribute("out_shape", outputShape); |
394 | auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); |
395 | |
396 | if (quantAttr) { |
397 | result.addAttribute("quantization_info", quantAttr); |
398 | result.addTypes( |
399 | buildConvOpResultTypeInfo(builder, outputType, input, weight)); |
400 | } else { |
401 | result.addTypes(outputType); |
402 | } |
403 | } |
404 | |
405 | /// The tosa.fully_connected op has its own builder as it does not have |
406 | /// strides/dilation/padding. |
407 | static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, |
408 | Type outputType, Value input, Value weight, |
409 | Value bias) { |
410 | |
411 | result.addOperands({input, weight, bias}); |
412 | auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); |
413 | if (quantAttr) { |
414 | result.addAttribute("quantization_info", quantAttr); |
415 | result.addTypes( |
416 | buildConvOpResultTypeInfo(builder, outputType, input, weight)); |
417 | } else { |
418 | result.addTypes(outputType); |
419 | } |
420 | } |
421 | |
422 | /// The tosa.matmul op is also intended to be generated where a fully_connected |
423 | /// op must be constructed but the weight is not a constant. In that case, |
424 | /// the fully_connected op must be expressed using matmul. |
425 | /// TODO: Add a link to the legalization document explaining this. |
426 | static void buildMatMulOpWithQuantInfo(OpBuilder &builder, |
427 | OperationState &result, Type outputType, |
428 | Value a, Value b) { |
429 | result.addOperands({a, b}); |
430 | auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b); |
431 | |
432 | if (quantAttr) { |
433 | result.addAttribute("quantization_info", quantAttr); |
434 | |
435 | auto inputType = llvm::dyn_cast<ShapedType>(a.getType()); |
436 | assert(inputType && "Input must be a shaped tensor type!"); |
437 | |
438 | auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( |
439 | inputType.getElementType()); |
440 | assert(inputQType && "Tensor must have quantized datatype!"); |
441 | |
442 | unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); |
443 | |
444 | auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType); |
445 | assert(outputShapedType && "Output must be a shaped type"); |
446 | |
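// Pick the accumulator width for the quantized result: 16-bit quantized
// operands accumulate into i48, narrower operands into i32.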
447 | IntegerType accElementType; |
448 | if (inputBits == 16) |
449 | accElementType = builder.getIntegerType(48); |
450 | else |
451 | accElementType = builder.getI32Type(); |
452 | auto accType = outputShapedType.clone(accElementType); |
453 | result.addTypes(accType); |
454 | } else { |
455 | result.addTypes(outputType); |
456 | } |
457 | } |
458 | |
459 | /// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr |
460 | /// but the avg_pool2d operator has its own builder as it has additional |
461 | /// parameters not part of the unary ops. |
462 | static void |
463 | buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, |
464 | Type outputType, Value input, |
465 | DenseArrayAttr kernel, DenseArrayAttr stride, |
466 | DenseArrayAttr pad, TypeAttr accType) { |
467 | result.addOperands(input); |
468 | result.addAttribute("kernel", kernel); |
469 | result.addAttribute("stride", stride); |
470 | result.addAttribute("pad", pad); |
471 | result.addAttribute("acc_type", accType); |
472 | auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); |
473 | if (quantAttr) |
474 | result.addAttribute("quantization_info", quantAttr); |
475 | result.types.push_back(outputType); |
476 | } |
477 | |
478 | /// This builder is called on single-parameter unary operators that have a |
479 | /// scale relationship between their input and output, expressed by the |
480 | /// UnaryOpQuantizationAttr. |
481 | static void buildUnaryOpWithQuantInfo(OpBuilder &builder, |
482 | OperationState &result, Type outputType, |
483 | Value input) { |
484 | result.addOperands(input); |
485 | auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); |
486 | if (quantAttr) |
487 | result.addAttribute("quantization_info", quantAttr); |
488 | result.types.push_back(outputType); |
489 | } |
490 | |
491 | /// This builder is called on the TOSA pad operator that needs to create its |
492 | /// own OptionalAttr quantization_attr parameter to scale the padding values |
493 | /// correctly. If no pad_const is supplied, zero padding is assumed. |
494 | static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, |
495 | Type outputType, Value input, |
496 | Value paddings) { |
497 | result.addOperands({input, paddings}); |
498 | auto quantAttr = buildPadOpQuantizationAttr(builder, input); |
499 | if (quantAttr) |
500 | result.addAttribute("quantization_info", quantAttr); |
501 | result.types.push_back(outputType); |
502 | } |
503 | |
504 | /// This builder is called on the TOSA pad operator when an explicit pad_const |
505 | /// value is passed in. It also optionally constructs quantization_attr. |
506 | static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, |
507 | OperationState &result, |
508 | Type outputType, Value input, |
509 | Value paddings, |
510 | Value padConst) { |
511 | result.addOperands({input, paddings, padConst}); |
512 | auto quantAttr = buildPadOpQuantizationAttr(builder, input); |
513 | if (quantAttr) |
514 | result.addAttribute("quantization_info", quantAttr); |
515 | result.types.push_back(outputType); |
516 | } |
517 | |
518 | //===----------------------------------------------------------------------===// |
519 | // TOSA Operator Return Type Inference. |
520 | //===----------------------------------------------------------------------===// |
521 | |
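// resolveBroadcastShape implements the elementwise broadcasting rule: shapes
// are right-aligned, missing leading dimensions act as 1, a dimension of size
// 1 stretches to the other operand's size, and any other mismatch fails.
// For example, 2x1x3 broadcast with 4x3 resolves to 2x4x3.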
522 | static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, |
523 | SmallVector<int64_t> &outShape) { |
524 | int64_t outRank = 0; |
525 | for (int i = 0, e = operands.size(); i != e; ++i) { |
526 | auto shape = operands.getShape(i); |
527 | if (!shape.hasRank()) { |
528 | // TODO(jennik): Update function to have better case handling for invalid |
529 | // operands and for ranked tensors. |
530 | return failure(); |
531 | } |
532 | outRank = std::max<int64_t>(outRank, shape.getRank()); |
533 | } |
534 | |
535 | outShape.resize(outRank, 1); |
536 | |
537 | for (int i = 0, e = operands.size(); i != e; ++i) { |
538 | auto shape = operands.getShape(i); |
539 | auto rankDiff = outShape.size() - shape.getRank(); |
540 | |
541 | for (size_t i = 0, e = shape.getRank(); i < e; ++i) { |
542 | auto dim1 = outShape[i + rankDiff]; |
543 | auto dim2 = shape.getDimSize(i); |
544 | auto resolvedDim = dim1; |
545 | |
546 | if (dim1 == 1) { |
547 | resolvedDim = dim2; |
548 | } else if (dim2 == 1) { |
549 | resolvedDim = dim1; |
550 | } else if (dim1 != dim2) { |
551 | return failure(); |
552 | } |
553 | outShape[i + rankDiff] = resolvedDim; |
554 | } |
555 | } |
556 | |
557 | return success(); |
558 | } |
559 | |
560 | LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( |
561 | MLIRContext *context, ::std::optional<Location> location, |
562 | ArgMaxOp::Adaptor adaptor, |
563 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
564 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
565 | IntegerAttr axis = adaptor.getProperties().axis; |
566 | int32_t axisVal = axis.getValue().getSExtValue(); |
567 | |
568 | if (!inputShape.hasRank()) { |
569 | inferredReturnShapes.push_back(ShapedTypeComponents()); |
570 | return success(); |
571 | } |
572 | |
573 | SmallVector<int64_t> outShape; |
574 | outShape.reserve(inputShape.getRank() - 1); |
575 | for (int i = 0, s = inputShape.getRank(); i < s; i++) { |
576 | if (i == axisVal) |
577 | continue; |
578 | outShape.push_back(inputShape.getDimSize(i)); |
579 | } |
580 | |
581 | inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); |
582 | return success(); |
583 | } |
584 | |
585 | LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents( |
586 | MLIRContext *context, ::std::optional<Location> location, |
587 | RFFT2dOp::Adaptor adaptor, |
588 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
589 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
590 | |
591 | if (!inputShape.hasRank()) |
592 | return failure(); |
593 | |
594 | llvm::SmallVector<int64_t> outputShape; |
595 | outputShape.resize(3, ShapedType::kDynamic); |
596 | outputShape[0] = inputShape.getDimSize(0); |
597 | outputShape[1] = inputShape.getDimSize(1); |
598 | int64_t inWidth = inputShape.getDimSize(2); |
599 | |
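// A real-input FFT only produces the non-negative frequency bins, so the
// innermost output dimension is inWidth / 2 + 1.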
600 | // Note that we can support this calculation symbolically |
601 | // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1] |
602 | if (inWidth != ShapedType::kDynamic) |
603 | outputShape[2] = inWidth / 2 + 1; |
604 | |
605 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
606 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
607 | |
608 | return success(); |
609 | } |
610 | |
611 | LogicalResult tosa::FFT2dOp::inferReturnTypeComponents( |
612 | MLIRContext *context, ::std::optional<Location> location, |
613 | FFT2dOp::Adaptor adaptor, |
614 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
615 | inferredReturnShapes.push_back( |
616 | ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType()))); |
617 | inferredReturnShapes.push_back( |
618 | ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType()))); |
619 | return success(); |
620 | } |
621 | |
622 | LogicalResult tosa::ConcatOp::inferReturnTypeComponents( |
623 | MLIRContext *context, ::std::optional<Location> location, |
624 | ConcatOp::Adaptor adaptor, |
625 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
626 | // Infer all dimension sizes by reducing based on inputs. |
627 | const Properties &prop = adaptor.getProperties(); |
628 | int32_t axis = prop.axis.getValue().getSExtValue(); |
629 | llvm::SmallVector<int64_t> outputShape; |
630 | bool hasRankedInput = false; |
631 | for (auto operand : adaptor.getOperands()) { |
632 | ShapeAdaptor operandShape(operand.getType()); |
633 | if (!operandShape.hasRank()) |
634 | continue; |
635 | |
636 | // Copy the Operand's rank. |
637 | if (!hasRankedInput) |
638 | outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); |
639 | |
640 | // Copy static sizes of all non-axis dimensions and check they agree. |
641 | for (int i = 0, s = operandShape.getRank(); i < s; i++) { |
642 | if (i == axis || operandShape.isDynamicDim(i)) |
643 | continue; |
644 | if (outputShape[i] == ShapedType::kDynamic) |
645 | outputShape[i] = operandShape.getDimSize(i); |
646 | if (outputShape[i] != operandShape.getDimSize(i)) |
647 | return emitOptionalError(location, |
648 | "Cannot concat tensors with different sizes" |
649 | " on the non-axis dimension ", |
650 | i); |
651 | } |
652 | |
653 | hasRankedInput = true; |
654 | } |
655 | Type inputType = |
656 | llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType(); |
657 | if (!hasRankedInput) { |
658 | inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); |
659 | return success(); |
660 | } |
661 | |
662 | // Determine the dimension size along the concatenation axis. |
663 | int64_t concatDimSize = 0; |
664 | for (auto operand : adaptor.getOperands()) { |
665 | ShapeAdaptor operandShape(operand.getType()); |
666 | |
667 | // We need to know the length of the concatenation axis of all inputs to |
668 | // determine the dimension size of the output shape. |
669 | if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { |
670 | concatDimSize = ShapedType::kDynamic; |
671 | break; |
672 | } |
673 | |
674 | concatDimSize += operandShape.getDimSize(axis); |
675 | } |
676 | |
677 | outputShape[axis] = concatDimSize; |
678 | |
679 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); |
680 | return success(); |
681 | } |
682 | |
683 | LogicalResult tosa::EqualOp::inferReturnTypeComponents( |
684 | MLIRContext *context, ::std::optional<Location> location, |
685 | ValueShapeRange operands, DictionaryAttr attributes, |
686 | OpaqueProperties properties, RegionRange regions, |
687 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
688 | auto elementType = IntegerType::get(context, /*width=*/1); |
689 | |
690 | llvm::SmallVector<int64_t> outShape; |
691 | if (resolveBroadcastShape(operands, outShape).failed()) { |
692 | inferredReturnShapes.push_back(ShapedTypeComponents(elementType)); |
693 | return success(); |
694 | } |
695 | |
696 | inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType)); |
697 | return success(); |
698 | } |
699 | |
700 | bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { |
701 | if (l.size() != r.size() || l.size() != 1) |
702 | return false; |
703 | return succeeded(verifyCompatibleShape(l[0], r[0])); |
704 | } |
705 | |
706 | LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( |
707 | MLIRContext *context, ::std::optional<Location> location, |
708 | FullyConnectedOp::Adaptor adaptor, |
709 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
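// tosa.fully_connected maps input [N, IC] and weight [OC, IC] to output
// [N, OC]; the bias has shape [OC].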
710 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
711 | ShapeAdaptor weightShape(adaptor.getWeight().getType()); |
712 | ShapeAdaptor biasShape(adaptor.getBias().getType()); |
713 | |
714 | // Start with all output dimensions dynamic. |
715 | SmallVector<int64_t> outShape; |
716 | outShape.resize(2, ShapedType::kDynamic); |
717 | |
718 | if (inputShape.hasRank()) { |
719 | outShape[0] = inputShape.getDimSize(0); |
720 | } |
721 | |
722 | if (weightShape.hasRank()) { |
723 | outShape[1] = weightShape.getDimSize(0); |
724 | } |
725 | |
726 | if (biasShape.hasRank()) { |
727 | outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0) |
728 | : outShape[1]; |
729 | } |
730 | |
731 | inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); |
732 | return success(); |
733 | } |
734 | |
735 | LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } |
736 | |
737 | LogicalResult tosa::MatMulOp::inferReturnTypeComponents( |
738 | MLIRContext *context, ::std::optional<Location> location, |
739 | MatMulOp::Adaptor adaptor, |
740 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
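// tosa.matmul multiplies A [N, H, C] by B [N, C, W] to produce [N, H, W].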
741 | ShapeAdaptor lhsShape(adaptor.getA().getType()); |
742 | ShapeAdaptor rhsShape(adaptor.getB().getType()); |
743 | |
744 | // Start with all output dimensions dynamic. |
745 | SmallVector<int64_t> outShape; |
746 | outShape.resize(3, ShapedType::kDynamic); |
747 | |
748 | if (lhsShape.hasRank()) { |
749 | outShape[0] = lhsShape.getDimSize(0); |
750 | outShape[1] = lhsShape.getDimSize(1); |
751 | } |
752 | |
753 | if (rhsShape.hasRank()) { |
754 | outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0) |
755 | : outShape[0]; |
756 | outShape[2] = rhsShape.getDimSize(2); |
757 | } |
758 | |
759 | inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); |
760 | return success(); |
761 | } |
762 | |
763 | LogicalResult tosa::PadOp::inferReturnTypeComponents( |
764 | MLIRContext *context, ::std::optional<Location> location, |
765 | PadOp::Adaptor adaptor, |
766 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
767 | ShapeAdaptor inputShape(adaptor.getInput1().getType()); |
768 | ShapeAdaptor paddingShape(adaptor.getPadding().getType()); |
769 | SmallVector<int64_t> outputShape; |
770 | |
771 | // If both inputs have unknown shape, we cannot determine the shape of the |
772 | // output. |
773 | if (!inputShape.hasRank() && !paddingShape.hasRank()) { |
774 | inferredReturnShapes.push_back(ShapedTypeComponents()); |
775 | return success(); |
776 | } |
777 | |
778 | // If the input rank is unknown we can infer the output rank using the |
779 | // padding shape's first dim. |
780 | if (!inputShape.hasRank()) { |
781 | if (paddingShape.isDynamicDim(0)) { |
782 | inferredReturnShapes.push_back(ShapedTypeComponents()); |
783 | return success(); |
784 | } |
785 | |
786 | outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic); |
787 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
788 | return success(); |
789 | } |
790 | |
791 | DenseIntElementsAttr paddings; |
792 | // If the paddings value is not a constant, all dimensions must be dynamic. |
793 | if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) { |
794 | outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); |
795 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
796 | return success(); |
797 | } |
798 | |
799 | SmallVector<int64_t> paddingValues; |
800 | for (auto val : paddings) { |
801 | paddingValues.push_back(val.getSExtValue()); |
802 | } |
803 | |
804 | outputShape.reserve(inputShape.getRank()); |
805 | for (int i = 0, s = inputShape.getRank(); i < s; i++) { |
806 | if (inputShape.isDynamicDim(i)) { |
807 | outputShape.push_back(ShapedType::kDynamic); |
808 | continue; |
809 | } |
810 | |
811 | outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] + |
812 | paddingValues[i * 2 + 1]); |
813 | } |
814 | |
815 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
816 | return success(); |
817 | } |
818 | |
819 | static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) { |
820 | return to_vector(llvm::map_range(shape, [](int64_t dim) { |
821 | return dim == -1 ? ShapedType::kDynamic : dim; |
822 | })); |
823 | } |
824 | |
825 | LogicalResult tosa::SliceOp::inferReturnTypeComponents( |
826 | MLIRContext *context, ::std::optional<Location> location, |
827 | SliceOp::Adaptor adaptor, |
828 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
829 | inferredReturnShapes.push_back( |
830 | ShapedTypeComponents(convertToMlirShape(adaptor.getSize()))); |
831 | return success(); |
832 | } |
833 | |
834 | LogicalResult tosa::SliceOp::verify() { |
835 | auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType()); |
836 | if (!inputType) |
837 | return success(); |
838 | |
839 | if (static_cast<size_t>(inputType.getRank()) != getStart().size()) |
840 | return emitOpError( |
841 | "length of start attribute is not equal rank of input shape"); |
842 | |
843 | if (static_cast<size_t>(inputType.getRank()) != getSize().size()) |
844 | return emitOpError( |
845 | "length of size attribute is not equal rank of input shape"); |
846 | |
847 | return success(); |
848 | } |
849 | |
850 | LogicalResult tosa::TableOp::inferReturnTypeComponents( |
851 | MLIRContext *context, ::std::optional<Location> location, |
852 | TableOp::Adaptor adaptor, |
853 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
854 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
855 | |
856 | if (!inputShape.hasRank()) { |
857 | inferredReturnShapes.push_back(ShapedTypeComponents()); |
858 | return success(); |
859 | } |
860 | |
861 | inferredReturnShapes.resize(1); |
862 | inputShape.getDims(inferredReturnShapes[0]); |
863 | return success(); |
864 | } |
865 | |
866 | LogicalResult tosa::TileOp::inferReturnTypeComponents( |
867 | MLIRContext *context, ::std::optional<Location> location, |
868 | TileOp::Adaptor adaptor, |
869 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
870 | ArrayRef<int64_t> multiples = adaptor.getMultiples(); |
871 | ShapeAdaptor inputShape(adaptor.getInput1().getType()); |
872 | SmallVector<int64_t> outputShape; |
873 | if (!inputShape.hasRank()) { |
874 | outputShape.resize(multiples.size(), ShapedType::kDynamic); |
875 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
876 | return success(); |
877 | } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size()) |
878 | return failure(); |
879 | |
880 | // Any non-dynamic dimension is multiplied by its multiple to a known size. |
881 | outputShape.reserve(multiples.size()); |
882 | for (int i = 0, s = inputShape.getRank(); i < s; i++) { |
883 | int64_t dim = inputShape.getDimSize(i); |
884 | if (dim != ShapedType::kDynamic) |
885 | dim *= multiples[i]; |
886 | outputShape.push_back(dim); |
887 | } |
888 | |
889 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
890 | return success(); |
891 | } |
892 | |
893 | LogicalResult tosa::TileOp::verify() { |
894 | ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType()); |
895 | ShapedType outputType = llvm::cast<ShapedType>(getType()); |
896 | auto multiples = getMultiples(); |
897 | |
898 | if (inputType.hasRank()) { |
899 | if (static_cast<size_t>(inputType.getRank()) != multiples.size()) |
900 | return emitOpError("expect 'multiples' array to have length ") |
901 | << inputType.getRank() << " but got " << multiples.size() << "."; |
902 | if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) |
903 | return emitOpError("expect same input and output tensor rank."); |
904 | } else if (outputType.hasRank() && |
905 | static_cast<size_t>(outputType.getRank()) != multiples.size()) |
906 | return emitOpError("expect 'multiples' array to have length ") |
907 | << outputType.getRank() << " but got " << multiples.size() << "."; |
908 | |
909 | return success(); |
910 | } |
911 | |
912 | bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { |
913 | if (l.size() != r.size() || l.size() != 1) |
914 | return false; |
915 | return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); |
916 | } |
917 | |
918 | LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( |
919 | MLIRContext *context, ::std::optional<Location> location, |
920 | ReshapeOp::Adaptor adaptor, |
921 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
922 | ShapeAdaptor inputShape(adaptor.getInput1().getType()); |
923 | Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); |
924 | llvm::SmallVector<int64_t> newShapeValue = |
925 | convertToMlirShape(adaptor.getNewShape()); |
926 | |
927 | // We cannot infer from the total number of elements so we must take the |
928 | // shape attribute as exact. |
929 | if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { |
930 | inferredReturnShapes.push_back( |
931 | ShapedTypeComponents(newShapeValue, inputType)); |
932 | return success(); |
933 | } |
934 | |
935 | // Determine the number of elements covered by the slice of all static |
936 | // dimensions. This allows us to infer the length of the remaining dynamic |
937 | // dimension. |
938 | int64_t numElements = inputShape.getNumElements(); |
939 | int64_t staticMul = 1; |
940 | for (auto val : newShapeValue) { |
941 | if (!ShapedType::isDynamic(val)) { |
942 | staticMul *= val; |
943 | } |
944 | } |
945 | |
946 | // Determine the length of the dynamic dimension. |
947 | for (auto &val : newShapeValue) { |
948 | if (ShapedType::isDynamic(val)) |
949 | val = numElements / staticMul; |
950 | } |
951 | |
952 | inferredReturnShapes.push_back( |
953 | ShapedTypeComponents(newShapeValue, inputType)); |
954 | return success(); |
955 | } |
956 | |
957 | mlir::LogicalResult tosa::ReshapeOp::verify() { |
958 | TensorType inputType = getInput1().getType(); |
959 | RankedTensorType outputType = getType(); |
960 | |
961 | if (hasZeroDimension(inputType) || hasZeroDimension(outputType)) |
962 | return emitOpError() << "tensor has a dimension with size zero. Each " |
963 | "dimension of a tensor must have size >= 1"; |
964 | |
965 | if ((int64_t) getNewShape().size() != outputType.getRank()) |
966 | return emitOpError() << "new shape does not match result rank"; |
967 | |
968 | for (auto [newShapeDim, outputShapeDim] : |
969 | zip(getNewShape(), outputType.getShape())) |
970 | if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic && |
971 | newShapeDim != outputShapeDim) |
972 | return emitOpError() << "new shape is inconsistent with result shape"; |
973 | |
974 | if (inputType.hasStaticShape() && outputType.hasStaticShape()) { |
975 | int64_t inputElementsNum = inputType.getNumElements(); |
976 | int64_t outputElementsNum = outputType.getNumElements(); |
977 | if (inputElementsNum != outputElementsNum) { |
978 | return emitOpError() << "cannot reshape " << inputElementsNum |
979 | << " elements into " << outputElementsNum; |
980 | } |
981 | } |
982 | |
983 | int missingDims = llvm::count(getNewShape(), -1); |
984 | if (missingDims > 1) |
985 | return emitOpError() << "expected at most one target dimension to be -1"; |
986 | |
987 | return mlir::success(); |
988 | } |
989 | |
990 | LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) { |
991 | // Perms must be constants. |
992 | DenseIntElementsAttr permsAttr; |
993 | if (!matchPattern(getPerms(), m_Constant(&permsAttr))) |
994 | return failure(); |
995 | |
996 | // Extract the permutation values from the constant attribute. |
997 | perms = llvm::to_vector( |
998 | llvm::map_range(permsAttr.getValues<APInt>(), |
999 | [](const APInt &val) { return val.getSExtValue(); })); |
1000 | |
1001 | return success(); |
1002 | } |
1003 | |
1004 | LogicalResult tosa::TransposeOp::inferReturnTypeComponents( |
1005 | MLIRContext *context, ::std::optional<Location> location, |
1006 | TransposeOp::Adaptor adaptor, |
1007 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1008 | ShapeAdaptor inputShape(adaptor.getInput1().getType()); |
1009 | ShapeAdaptor permsShape(adaptor.getPerms().getType()); |
1010 | |
1011 | // We cannot infer anything from a rank-0 "permutation" tensor. |
1012 | if (permsShape.hasRank() && permsShape.getRank() == 0) |
1013 | return failure(); |
1014 | |
1015 | // If the input rank or the permutation length is unknown, the output rank is |
1016 | // unknown. |
1017 | if (!inputShape.hasRank() || !permsShape.hasRank() || |
1018 | permsShape.isDynamicDim(0)) { |
1019 | inferredReturnShapes.push_back(ShapedTypeComponents()); |
1020 | return success(); |
1021 | } |
1022 | |
1023 | // This would imply the number of permutations does not match the rank of the |
1024 | // input which is illegal. |
1025 | if (permsShape.getDimSize(0) != inputShape.getRank()) { |
1026 | return failure(); |
1027 | } |
1028 | |
1029 | SmallVector<int64_t> outputShape; |
1030 | // Rank-0 means no permutations matter. |
1031 | if (inputShape.getRank() == 0) { |
1032 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1033 | return success(); |
1034 | } |
1035 | |
1036 | // Check whether the input dimensions are all the same. |
1037 | bool allTheSame = true; |
1038 | for (int i = 1, s = inputShape.getRank(); i < s; i++) { |
1039 | if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { |
1040 | allTheSame = false; |
1041 | break; |
1042 | } |
1043 | } |
1044 | |
1045 | // If all of the input dimensions are the same we don't care about the |
1046 | // permutation. |
1047 | if (allTheSame) { |
1048 | outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); |
1049 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1050 | return success(); |
1051 | } |
1052 | |
1053 | outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); |
1054 | // If the permutations are a constant we can directly determine the output |
1055 | // shape. |
1056 | DenseIntElementsAttr attr; |
1057 | if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) && |
1058 | attr.getType().getRank() == 1) { |
1059 | ShapeAdaptor permShape = attr; |
1060 | // Constant permutation must be the same length as the input rank. |
1061 | if (inputShape.getRank() != permShape.getRank()) |
1062 | return emitOptionalError(location, |
1063 | "constant permutation must be the same length" |
1064 | " as the input rank"); |
1065 | |
1066 | // Constant permutation values must be within the input rank. |
1067 | for (int i = 0, e = inputShape.getRank(); i < e; i++) { |
1068 | if (inputShape.getRank() <= permShape.getDimSize(i)) |
1069 | return failure(); |
1070 | } |
1071 | |
1072 | outputShape.reserve(inputShape.getRank()); |
1073 | for (int i = 0, s = inputShape.getRank(); i < s; i++) { |
1074 | outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i)); |
1075 | } |
1076 | } |
1077 | |
1078 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1079 | return success(); |
1080 | } |
1081 | |
1082 | LogicalResult tosa::TransposeOp::verify() { |
1083 | TensorType inputType = getInput1().getType(); |
1084 | TensorType permType = getPerms().getType(); |
1085 | TensorType outputType = getOutput().getType(); |
1086 | |
1087 | if (permType.hasRank() && permType.getRank() != 1) |
1088 | return emitOpError() |
1089 | << "expected permutation tensor to be rank 1 but got rank " |
1090 | << permType.getRank(); |
1091 | if (inputType.hasRank() && permType.hasRank()) |
1092 | if (!permType.isDynamicDim(0) && |
1093 | permType.getDimSize(0) != inputType.getRank()) |
1094 | return emitOpError() << "expected permutation tensor dim 0 to have size " |
1095 | << inputType.getRank() |
1096 | << " (input rank) but got size " |
1097 | << permType.getDimSize(0); |
1098 | if (inputType.hasRank() && outputType.hasRank() && |
1099 | inputType.getRank() != outputType.getRank()) |
1100 | return emitOpError() |
1101 | << "expected input tensor rank to equal result tensor rank"; |
1102 | if (outputType.hasRank() && permType.hasRank()) |
1103 | if (!permType.isDynamicDim(0) && |
1104 | permType.getDimSize(0) != outputType.getRank()) |
1105 | return emitOpError() << "expected permutation tensor dim 0 to have size " |
1106 | << outputType.getRank() |
1107 | << " (output rank) but got size " |
1108 | << permType.getDimSize(0); |
1109 | |
1110 | SmallVector<int64_t> constantPerms; |
1111 | if (succeeded(getConstantPerms(constantPerms))) { |
1112 | // Assert that the permutation tensor has a rank, which means that the rank |
1113 | // has been verified above. |
1114 | assert(permType.hasRank() && |
1115 | "Unexpectedly found permutation tensor without rank"); |
1116 | if (!isPermutationVector(constantPerms)) |
1117 | return emitOpError() << "expected valid permutation tensor"; |
1118 | } |
1119 | return success(); |
1120 | } |
1121 | |
1122 | LogicalResult tosa::GatherOp::inferReturnTypeComponents( |
1123 | MLIRContext *context, ::std::optional<Location> location, |
1124 | GatherOp::Adaptor adaptor, |
1125 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
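// tosa.gather maps values [N, K, C] and indices [N, W] to output [N, W, C].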
1126 | llvm::SmallVector<int64_t> outputShape; |
1127 | outputShape.resize(3, ShapedType::kDynamic); |
1128 | |
1129 | ShapeAdaptor valuesShape(adaptor.getValues().getType()); |
1130 | if (valuesShape.hasRank()) { |
1131 | outputShape[0] = valuesShape.getDimSize(0); |
1132 | outputShape[2] = valuesShape.getDimSize(2); |
1133 | } |
1134 | |
1135 | ShapeAdaptor indicesShape(adaptor.getIndices().getType()); |
1136 | if (indicesShape.hasRank()) { |
1137 | if (outputShape[0] == ShapedType::kDynamic) |
1138 | outputShape[0] = indicesShape.getDimSize(0); |
1139 | if (outputShape[1] == ShapedType::kDynamic) |
1140 | outputShape[1] = indicesShape.getDimSize(1); |
1141 | } |
1142 | |
1143 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1144 | return success(); |
1145 | } |
1146 | |
1147 | LogicalResult tosa::ResizeOp::inferReturnTypeComponents( |
1148 | MLIRContext *context, ::std::optional<Location> location, |
1149 | ResizeOp::Adaptor adaptor, |
1150 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1151 | llvm::SmallVector<int64_t, 4> outputShape; |
1152 | outputShape.resize(4, ShapedType::kDynamic); |
1153 | |
1154 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1155 | if (!inputShape.hasRank()) |
1156 | return failure(); |
1157 | |
1158 | outputShape[0] = inputShape.getDimSize(0); |
1159 | outputShape[3] = inputShape.getDimSize(3); |
1160 | int64_t inputHeight = inputShape.getDimSize(1); |
1161 | int64_t inputWidth = inputShape.getDimSize(2); |
1162 | |
1163 | if ((inputHeight == ShapedType::kDynamic) || |
1164 | (inputWidth == ShapedType::kDynamic)) |
1165 | return failure(); |
1166 | |
1167 | llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale(); |
1168 | llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset(); |
1169 | llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder(); |
1170 | |
1171 | // Compute the output shape based on attributes: scale, offset, and border. |
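// scale holds [scale_y_n, scale_y_d, scale_x_n, scale_x_d], offset holds
// [offset_y, offset_x] and border holds [border_y, border_x], giving
// out = ((in - 1) * scale_n - offset + border) / scale_d + 1 per spatial dim.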
1172 | outputShape[1] = |
1173 | (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / |
1174 | scaleInt[1]) + |
1175 | 1; |
1176 | |
1177 | outputShape[2] = |
1178 | (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / |
1179 | scaleInt[3]) + |
1180 | 1; |
1181 | |
1182 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1183 | return success(); |
1184 | } |
1185 | |
1186 | LogicalResult tosa::ScatterOp::inferReturnTypeComponents( |
1187 | MLIRContext *context, ::std::optional<Location> location, |
1188 | ScatterOp::Adaptor adaptor, |
1189 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
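// tosa.scatter maps values_in [N, K, C], indices [N, W] and input [N, W, C]
// to values_out [N, K, C].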
1190 | llvm::SmallVector<int64_t> outputShape; |
1191 | outputShape.resize(3, ShapedType::kDynamic); |
1192 | |
1193 | ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType()); |
1194 | if (valuesInShape.hasRank()) { |
1195 | outputShape[0] = valuesInShape.getDimSize(0); |
1196 | outputShape[1] = valuesInShape.getDimSize(1); |
1197 | outputShape[2] = valuesInShape.getDimSize(2); |
1198 | } |
1199 | |
1200 | ShapeAdaptor indicesShape(adaptor.getIndices().getType()); |
1201 | if (indicesShape.hasRank()) { |
1202 | if (outputShape[0] == ShapedType::kDynamic) |
1203 | outputShape[0] = indicesShape.getDimSize(0); |
1204 | } |
1205 | |
1206 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1207 | if (inputShape.hasRank()) { |
1208 | if (outputShape[0] == ShapedType::kDynamic) |
1209 | outputShape[0] = inputShape.getDimSize(0); |
1210 | if (outputShape[2] == ShapedType::kDynamic) |
1211 | outputShape[2] = inputShape.getDimSize(2); |
1212 | } |
1213 | |
1214 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1215 | return success(); |
1216 | } |
1217 | |
1218 | static LogicalResult ReduceInferReturnTypes( |
1219 | ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, |
1220 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
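// A reduction keeps the input rank and sets the reduced axis to size 1.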
1221 | int64_t axisVal = axis.getValue().getSExtValue(); |
1222 | if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) { |
1223 | inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); |
1224 | return success(); |
1225 | } |
1226 | |
1227 | SmallVector<int64_t> outputShape; |
1228 | operandShape.getDims(outputShape); |
1229 | outputShape[axisVal] = 1; |
1230 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); |
1231 | return success(); |
1232 | } |
1233 | |
1234 | #define COMPATIBLE_RETURN_TYPES(OP) \ |
1235 | bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ |
1236 | if (l.size() != r.size() || l.size() != 1) \ |
1237 | return false; \ |
1238 | if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ |
1239 | return false; \ |
1240 | return succeeded(verifyCompatibleShape(l[0], r[0])); \ |
1241 | } |
1242 | |
1243 | #define REDUCE_SHAPE_INFER(OP) \ |
1244 | LogicalResult OP::inferReturnTypeComponents( \ |
1245 | MLIRContext *context, ::std::optional<Location> location, \ |
1246 | OP::Adaptor adaptor, \ |
1247 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \ |
1248 | Type inputType = \ |
1249 | llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \ |
1250 | ShapeAdaptor inputShape(adaptor.getInput().getType()); \ |
1251 | const Properties &prop = adaptor.getProperties(); \ |
1252 | return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \ |
1253 | inferredReturnShapes); \ |
1254 | } \ |
1255 | COMPATIBLE_RETURN_TYPES(OP) |
1256 | |
1257 | REDUCE_SHAPE_INFER(tosa::ReduceAllOp) |
1258 | REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) |
1259 | REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) |
1260 | REDUCE_SHAPE_INFER(tosa::ReduceMinOp) |
1261 | REDUCE_SHAPE_INFER(tosa::ReduceProdOp) |
1262 | REDUCE_SHAPE_INFER(tosa::ReduceSumOp) |
1263 | #undef REDUCE_SHAPE_INFER |
1264 | COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) |
1265 | #undef COMPATIBLE_RETURN_TYPES |
1266 | |
1267 | template <typename T> |
1268 | static LogicalResult verifyReduceOp(T op) { |
1269 | // All TOSA reduce Ops have input, output and axis. |
1270 | TensorType inputType = op.getInput().getType(); |
1271 | TensorType outputType = op.getOutput().getType(); |
1272 | int32_t reduceAxis = op.getAxis(); |
1273 | |
1274 | if (reduceAxis < 0) { |
1275 | op.emitOpError("reduce axis must not be negative"); |
1276 | return failure(); |
1277 | } |
1278 | if (inputType.hasRank()) { |
1279 | int64_t inputRank = inputType.getRank(); |
1280 | // We allow for a special case where the input/output shape has rank 0 and |
1281 | // axis is also 0. |
1282 | if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) { |
1283 | op.emitOpError("expect input tensor rank (") |
1284 | << inputRank << ") to be larger than reduce axis (" << reduceAxis |
1285 | << ")"; |
1286 | return failure(); |
1287 | } |
1288 | } |
1289 | if (outputType.hasRank()) { |
1290 | int64_t outputRank = outputType.getRank(); |
1291 | if (inputType.hasRank() && outputRank != inputType.getRank()) { |
1292 | op.emitOpError( |
1293 | "expect output tensor rank to be equal to input tensor rank"); |
1294 | return failure(); |
1295 | } |
1296 | if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) { |
1297 | op.emitOpError("expect output tensor rank (") |
1298 | << outputRank << ") to be larger than reduce axis (" << reduceAxis |
1299 | << ")"; |
1300 | return failure(); |
1301 | } |
1302 | // We can only verify the reduced dimension size to be 1 if this is not the |
1303 | // special case of output rank == 0. |
1304 | if (outputRank != 0) { |
1305 | auto outputShape = outputType.getShape(); |
1306 | if (!outputType.isDynamicDim(reduceAxis) && |
1307 | outputShape[reduceAxis] != 1) { |
1308 | op.emitOpError("expect reduced dimension size to be 1, got ") |
1309 | << outputShape[reduceAxis]; |
1310 | return failure(); |
1311 | } |
1312 | } |
1313 | } |
1314 | return success(); |
1315 | } |
1316 | |
1317 | LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); } |
1318 | LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); } |
1319 | LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); } |
1320 | LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); } |
1321 | LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); } |
1322 | LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); } |
1323 | |
1324 | static LogicalResult NAryInferReturnTypes( |
1325 | const ValueShapeRange &operands, |
1326 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1327 | llvm::SmallVector<int64_t> outShape; |
1328 | if (resolveBroadcastShape(operands, outShape).failed()) { |
1329 | inferredReturnShapes.push_back(ShapedTypeComponents()); |
1330 | } else { |
1331 | inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); |
1332 | } |
1333 | return success(); |
1334 | } |
1335 | |
1336 | #define NARY_SHAPE_INFER(OP) \ |
1337 | LogicalResult OP::inferReturnTypeComponents( \ |
1338 | MLIRContext *context, ::std::optional<Location> location, \ |
1339 | ValueShapeRange operands, DictionaryAttr attributes, \ |
1340 | OpaqueProperties properties, RegionRange regions, \ |
1341 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \ |
1342 | return NAryInferReturnTypes(operands, inferredReturnShapes); \ |
1343 | } |
1344 | |
1345 | NARY_SHAPE_INFER(tosa::AbsOp) |
1346 | NARY_SHAPE_INFER(tosa::AddOp) |
1347 | NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp) |
1348 | NARY_SHAPE_INFER(tosa::BitwiseAndOp) |
1349 | NARY_SHAPE_INFER(tosa::BitwiseOrOp) |
1350 | NARY_SHAPE_INFER(tosa::BitwiseXorOp) |
1351 | NARY_SHAPE_INFER(tosa::BitwiseNotOp) |
1352 | NARY_SHAPE_INFER(tosa::CastOp) |
1353 | NARY_SHAPE_INFER(tosa::CeilOp) |
1354 | NARY_SHAPE_INFER(tosa::ClampOp) |
1355 | NARY_SHAPE_INFER(tosa::ClzOp) |
1356 | NARY_SHAPE_INFER(tosa::CosOp) |
1357 | NARY_SHAPE_INFER(tosa::DivOp) |
1358 | NARY_SHAPE_INFER(tosa::ExpOp) |
1359 | NARY_SHAPE_INFER(tosa::FloorOp) |
1360 | NARY_SHAPE_INFER(tosa::GreaterEqualOp) |
1361 | NARY_SHAPE_INFER(tosa::GreaterOp) |
1362 | NARY_SHAPE_INFER(tosa::IdentityOp) |
1363 | NARY_SHAPE_INFER(tosa::LogOp) |
1364 | NARY_SHAPE_INFER(tosa::LogicalAndOp) |
1365 | NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp) |
1366 | NARY_SHAPE_INFER(tosa::LogicalNotOp) |
1367 | NARY_SHAPE_INFER(tosa::LogicalOrOp) |
1368 | NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) |
1369 | NARY_SHAPE_INFER(tosa::LogicalXorOp) |
1370 | NARY_SHAPE_INFER(tosa::MaximumOp) |
1371 | NARY_SHAPE_INFER(tosa::MinimumOp) |
1372 | NARY_SHAPE_INFER(tosa::MulOp) |
1373 | NARY_SHAPE_INFER(tosa::NegateOp) |
1374 | NARY_SHAPE_INFER(tosa::PowOp) |
1375 | NARY_SHAPE_INFER(tosa::ReciprocalOp) |
1376 | NARY_SHAPE_INFER(tosa::RescaleOp) |
1377 | NARY_SHAPE_INFER(tosa::ReverseOp) |
1378 | NARY_SHAPE_INFER(tosa::RsqrtOp) |
1379 | NARY_SHAPE_INFER(tosa::SinOp) |
1380 | NARY_SHAPE_INFER(tosa::SelectOp) |
1381 | NARY_SHAPE_INFER(tosa::SubOp) |
1382 | NARY_SHAPE_INFER(tosa::TanhOp) |
1383 | NARY_SHAPE_INFER(tosa::ErfOp) |
1384 | NARY_SHAPE_INFER(tosa::SigmoidOp) |
#undef NARY_SHAPE_INFER
1386 | |
1387 | static LogicalResult poolingInferReturnTypes( |
1388 | ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, |
1389 | ArrayRef<int64_t> pad, |
1390 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1391 | llvm::SmallVector<int64_t> outputShape; |
1392 | outputShape.resize(4, ShapedType::kDynamic); |
1393 | |
  // If the input type is unranked, all we can report is a rank-4 result with
  // every dimension dynamic.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1397 | return success(); |
1398 | } |
1399 | |
  // Pooling leaves the batch and channel dimensions unchanged.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);
1406 | |
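  // Standard pooling output size per spatial dimension, matching the code
  // below: out = (in + pad_before + pad_after - kernel) / stride + 1 (floor
  // division). For example, height 32 with pad {1, 1}, kernel 3, stride 2
  // gives (32 + 1 + 1 - 3) / 2 + 1 = 16.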
1407 | if (!ShapedType::isDynamic(height)) { |
1408 | int64_t padded = height + pad[0] + pad[1] - kernel[0]; |
1409 | outputShape[1] = padded / stride[0] + 1; |
1410 | } |
1411 | |
1412 | if (!ShapedType::isDynamic(width)) { |
1413 | int64_t padded = width + pad[2] + pad[3] - kernel[1]; |
1414 | outputShape[2] = padded / stride[1] + 1; |
1415 | } |
1416 | |
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1418 | return success(); |
1419 | } |
1420 | |
1421 | LogicalResult Conv2DOp::inferReturnTypeComponents( |
1422 | MLIRContext *context, ::std::optional<Location> location, |
1423 | Conv2DOp::Adaptor adaptor, |
1424 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1425 | llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic); |
1426 | |
1427 | int64_t inputWidth = ShapedType::kDynamic; |
1428 | int64_t inputHeight = ShapedType::kDynamic; |
1429 | int64_t weightWidth = ShapedType::kDynamic; |
1430 | int64_t weightHeight = ShapedType::kDynamic; |
1431 | |
  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
1435 | if (inputShape.hasRank()) { |
1436 | outputShape[0] = inputShape.getDimSize(0); |
1437 | inputHeight = inputShape.getDimSize(1); |
1438 | inputWidth = inputShape.getDimSize(2); |
1439 | } |
1440 | |
  // Weight shape describes the filter height/width and the output channels.
1442 | ShapeAdaptor weightShape(adaptor.getWeight().getType()); |
1443 | if (weightShape.hasRank()) { |
1444 | outputShape[3] = weightShape.getDimSize(0); |
1445 | weightHeight = weightShape.getDimSize(1); |
1446 | weightWidth = weightShape.getDimSize(2); |
1447 | } |
1448 | |
1449 | // Bias shape can describe the output channels. |
1450 | ShapeAdaptor biasShape(adaptor.getBias().getType()); |
1451 | if (biasShape.hasRank()) { |
1452 | outputShape[3] = ShapedType::isDynamic(outputShape[3]) |
1453 | ? biasShape.getDimSize(0) |
1454 | : outputShape[3]; |
1455 | } |
1456 | |
1457 | llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); |
1458 | llvm::ArrayRef<int64_t> stride = adaptor.getStride(); |
1459 | llvm::ArrayRef<int64_t> padding = adaptor.getPad(); |
1460 | |
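  // Convolution output size per spatial dimension, matching the computation
  // below: out = (in + pad_before + pad_after - dilated_kernel) / stride + 1,
  // where dilated_kernel = (kernel - 1) * dilation + 1. For example, height 14
  // with pad {1, 1}, kernel 3, dilation 1, stride 1 gives
  // (14 + 2 - 3) / 1 + 1 = 14.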
1461 | if (!ShapedType::isDynamic(inputHeight) && |
1462 | !ShapedType::isDynamic(weightHeight)) { |
1463 | int64_t inputSize = inputHeight + padding[0] + padding[1]; |
1464 | int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; |
1465 | int64_t unstridedResult = inputSize - filterSize + 1; |
1466 | outputShape[1] = (unstridedResult - 1) / stride[0] + 1; |
1467 | } |
1468 | |
1469 | if (!ShapedType::isDynamic(inputWidth) && |
1470 | !ShapedType::isDynamic(weightWidth)) { |
1471 | int64_t inputSize = inputWidth + padding[2] + padding[3]; |
1472 | int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; |
1473 | int64_t unstridedResult = inputSize - filterSize + 1; |
1474 | outputShape[2] = (unstridedResult - 1) / stride[1] + 1; |
1475 | } |
1476 | |
1477 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1478 | return success(); |
1479 | } |
1480 | |
1481 | LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } |
1482 | |
1483 | LogicalResult Conv3DOp::inferReturnTypeComponents( |
1484 | MLIRContext *context, ::std::optional<Location> location, |
1485 | Conv3DOp::Adaptor adaptor, |
1486 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1487 | llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic); |
1488 | |
1489 | int64_t inputWidth = ShapedType::kDynamic; |
1490 | int64_t inputHeight = ShapedType::kDynamic; |
1491 | int64_t inputDepth = ShapedType::kDynamic; |
1492 | |
1493 | int64_t weightWidth = ShapedType::kDynamic; |
1494 | int64_t weightHeight = ShapedType::kDynamic; |
1495 | int64_t weightDepth = ShapedType::kDynamic; |
1496 | |
  // Input shape describes input depth/height/width and batch.
1498 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1499 | if (inputShape.hasRank()) { |
1500 | outputShape[0] = inputShape.getDimSize(0); |
1501 | inputDepth = inputShape.getDimSize(1); |
1502 | inputHeight = inputShape.getDimSize(2); |
1503 | inputWidth = inputShape.getDimSize(3); |
1504 | } |
1505 | |
  // Weight shape describes the filter depth/height/width and the output
  // channels.
1507 | ShapeAdaptor weightShape(adaptor.getWeight().getType()); |
1508 | if (weightShape.hasRank()) { |
1509 | outputShape[4] = weightShape.getDimSize(0); |
1510 | weightDepth = weightShape.getDimSize(1); |
1511 | weightHeight = weightShape.getDimSize(2); |
1512 | weightWidth = weightShape.getDimSize(3); |
1513 | } |
1514 | |
1515 | // Bias shape can describe the output channels. |
1516 | ShapeAdaptor biasShape(adaptor.getBias().getType()); |
1517 | if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { |
1518 | outputShape[4] = biasShape.getDimSize(0); |
1519 | } |
1520 | |
1521 | llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); |
1522 | llvm::ArrayRef<int64_t> stride = adaptor.getStride(); |
1523 | llvm::ArrayRef<int64_t> pad = adaptor.getPad(); |
1524 | |
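  // Same output-size computation as Conv2DOp, applied to depth, height, and
  // width in turn; pad is indexed as {depth_before, depth_after, height_before,
  // height_after, width_before, width_after} below.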
1525 | if (!ShapedType::isDynamic(inputDepth) && |
1526 | !ShapedType::isDynamic(weightDepth)) { |
    int64_t inputSize = inputDepth + pad[0] + pad[1];
    int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
1530 | outputShape[1] = (unstridedResult - 1) / stride[0] + 1; |
1531 | } |
1532 | |
1533 | if (!ShapedType::isDynamic(inputHeight) && |
1534 | !ShapedType::isDynamic(weightHeight)) { |
    int64_t inputSize = inputHeight + pad[2] + pad[3];
    int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
1538 | outputShape[2] = (unstridedResult - 1) / stride[1] + 1; |
1539 | } |
1540 | |
1541 | if (!ShapedType::isDynamic(inputWidth) && |
1542 | !ShapedType::isDynamic(weightWidth)) { |
    int64_t inputSize = inputWidth + pad[4] + pad[5];
    int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
1546 | outputShape[3] = (unstridedResult - 1) / stride[2] + 1; |
1547 | } |
1548 | |
1549 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1550 | return success(); |
1551 | } |
1552 | |
1553 | LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } |
1554 | |
1555 | LogicalResult AvgPool2dOp::inferReturnTypeComponents( |
1556 | MLIRContext *context, ::std::optional<Location> location, |
1557 | AvgPool2dOp::Adaptor adaptor, |
1558 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1559 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1560 | const Properties &prop = adaptor.getProperties(); |
1561 | return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, |
1562 | inferredReturnShapes); |
1563 | } |
1564 | |
1565 | LogicalResult MaxPool2dOp::inferReturnTypeComponents( |
1566 | MLIRContext *context, ::std::optional<Location> location, |
1567 | MaxPool2dOp::Adaptor adaptor, |
1568 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1569 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1570 | const Properties &prop = adaptor.getProperties(); |
1571 | return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, |
1572 | inferredReturnShapes); |
1573 | } |
1574 | |
1575 | LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( |
1576 | MLIRContext *context, ::std::optional<Location> location, |
1577 | DepthwiseConv2DOp::Adaptor adaptor, |
1578 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1579 | llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic); |
1580 | |
1581 | int64_t inputWidth = ShapedType::kDynamic; |
1582 | int64_t inputHeight = ShapedType::kDynamic; |
1583 | int64_t inputChannels = ShapedType::kDynamic; |
1584 | |
1585 | int64_t weightWidth = ShapedType::kDynamic; |
1586 | int64_t weightHeight = ShapedType::kDynamic; |
1587 | int64_t depthChannels = ShapedType::kDynamic; |
1588 | |
1589 | // Input shape describes input width/height and batch. |
1590 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1591 | if (inputShape.hasRank()) { |
1592 | outputShape[0] = inputShape.getDimSize(0); |
1593 | inputHeight = inputShape.getDimSize(1); |
1594 | inputWidth = inputShape.getDimSize(2); |
1595 | inputChannels = inputShape.getDimSize(3); |
1596 | } |
1597 | |
  // Weight shape describes the filter height/width, the input channels, and
  // the channel multiplier.
1599 | ShapeAdaptor weightShape(adaptor.getWeight().getType()); |
1600 | if (weightShape.hasRank()) { |
1601 | weightHeight = weightShape.getDimSize(0); |
1602 | weightWidth = weightShape.getDimSize(1); |
1603 | inputChannels = ShapedType::isDynamic(inputChannels) |
1604 | ? weightShape.getDimSize(2) |
1605 | : inputChannels; |
1606 | depthChannels = weightShape.getDimSize(3); |
1607 | } |
1608 | |
1609 | // If both inputChannels and depthChannels are available we can determine |
1610 | // the output channels. |
1611 | if (!ShapedType::isDynamic(inputChannels) && |
1612 | !ShapedType::isDynamic(depthChannels)) { |
1613 | outputShape[3] = inputChannels * depthChannels; |
1614 | } |
1615 | |
1616 | // Bias shape can describe the output channels. |
1617 | ShapeAdaptor biasShape(adaptor.getBias().getType()); |
1618 | if (biasShape.hasRank()) { |
1619 | outputShape[3] = ShapedType::isDynamic(outputShape[3]) |
1620 | ? biasShape.getDimSize(0) |
1621 | : outputShape[3]; |
1622 | } |
1623 | |
1624 | llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); |
1625 | llvm::ArrayRef<int64_t> padding = adaptor.getPad(); |
1626 | llvm::ArrayRef<int64_t> stride = adaptor.getStride(); |
1627 | |
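  // The spatial dimensions follow the same output-size computation as
  // Conv2DOp; only the channel dimension differs (input channels times the
  // channel multiplier, handled above).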
1628 | if (!ShapedType::isDynamic(inputHeight) && |
1629 | !ShapedType::isDynamic(weightHeight)) { |
1630 | int64_t inputSize = inputHeight + padding[0] + padding[1]; |
1631 | int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; |
1632 | int64_t unstridedResult = inputSize - filterSize + 1; |
1633 | outputShape[1] = (unstridedResult - 1) / stride[0] + 1; |
1634 | } |
1635 | |
1636 | if (!ShapedType::isDynamic(inputWidth) && |
1637 | !ShapedType::isDynamic(weightWidth)) { |
1638 | int64_t inputSize = inputWidth + padding[2] + padding[3]; |
1639 | int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; |
1640 | int64_t unstridedResult = inputSize - filterSize + 1; |
1641 | outputShape[2] = (unstridedResult - 1) / stride[1] + 1; |
1642 | } |
1643 | |
1644 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1645 | return success(); |
1646 | } |
1647 | |
1648 | LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } |
1649 | |
1650 | LogicalResult TransposeConv2DOp::inferReturnTypeComponents( |
1651 | MLIRContext *context, ::std::optional<Location> location, |
1652 | TransposeConv2DOp::Adaptor adaptor, |
1653 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
  // outputShape starts from the out_shape attribute and is refined below.
1655 | llvm::SmallVector<int64_t> outputShape = |
1656 | convertToMlirShape(adaptor.getOutShape()); |
1657 | |
1658 | int64_t inputWidth = ShapedType::kDynamic; |
1659 | int64_t inputHeight = ShapedType::kDynamic; |
1660 | int64_t weightWidth = ShapedType::kDynamic; |
1661 | int64_t weightHeight = ShapedType::kDynamic; |
1662 | |
1663 | // Input shape describes input width/height and batch. |
1664 | ShapeAdaptor inputShape(adaptor.getInput().getType()); |
1665 | if (inputShape.hasRank()) { |
1666 | outputShape[0] = ShapedType::isDynamic(outputShape[0]) |
1667 | ? inputShape.getDimSize(0) |
1668 | : outputShape[0]; |
1669 | inputHeight = inputShape.getDimSize(1); |
1670 | inputWidth = inputShape.getDimSize(2); |
1671 | } |
1672 | |
  // Weight shape describes the filter height/width and the output channels.
1674 | ShapeAdaptor weightShape(adaptor.getFilter().getType()); |
1675 | if (weightShape.hasRank()) { |
1676 | outputShape[3] = ShapedType::isDynamic(outputShape[3]) |
1677 | ? weightShape.getDimSize(0) |
1678 | : outputShape[3]; |
1679 | weightHeight = weightShape.getDimSize(1); |
1680 | weightWidth = weightShape.getDimSize(2); |
1681 | } |
1682 | |
1683 | // Bias shape can describe the output channels. |
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
1686 | outputShape[3] = ShapedType::isDynamic(outputShape[3]) |
1687 | ? biasShape.getDimSize(0) |
1688 | : outputShape[3]; |
1689 | } |
1690 | |
1691 | llvm::ArrayRef<int64_t> padding = adaptor.getOutPad(); |
1692 | llvm::ArrayRef<int64_t> stride = adaptor.getStride(); |
1693 | |
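  // Transpose convolution grows the spatial dimensions:
  //   out = (in - 1) * stride + out_pad_before + out_pad_after + kernel.
  // For example, height 4 with stride 2, zero out_pad, and kernel 3 gives
  // (4 - 1) * 2 + 0 + 0 + 3 = 9.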
1694 | if (!ShapedType::isDynamic(inputHeight) && |
1695 | !ShapedType::isDynamic(weightHeight)) { |
1696 | int64_t calculateSize = |
1697 | (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight; |
1698 | outputShape[1] = |
1699 | ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1]; |
1700 | } |
1701 | |
1702 | if (!ShapedType::isDynamic(inputWidth) && |
1703 | !ShapedType::isDynamic(weightWidth)) { |
1704 | int64_t calculateSize = |
1705 | (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth; |
1706 | outputShape[2] = |
1707 | ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2]; |
1708 | } |
1709 | |
1710 | inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); |
1711 | return success(); |
1712 | } |
1713 | |
1714 | LogicalResult IfOp::inferReturnTypeComponents( |
1715 | MLIRContext *context, ::std::optional<Location> location, |
1716 | IfOp::Adaptor adaptor, |
1717 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1718 | llvm::SmallVector<tosa::YieldOp> yieldOps; |
1719 | for (Region *region : adaptor.getRegions()) { |
1720 | for (auto &block : *region) |
1721 | if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) |
1722 | yieldOps.push_back(returnOp); |
1723 | } |
1724 | |
1725 | if (yieldOps.empty()) |
1726 | return failure(); |
1727 | |
1728 | // Get the initial type information for the yield op. |
1729 | llvm::SmallVector<ValueKnowledge> resultKnowledge; |
1730 | resultKnowledge.reserve(yieldOps.front().getNumOperands()); |
1731 | for (auto operand : yieldOps.front().getOperands()) { |
1732 | resultKnowledge.push_back( |
1733 | ValueKnowledge::getKnowledgeFromType(operand.getType())); |
1734 | } |
1735 | |
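  // Refine the per-result knowledge across every yield: meet() keeps only the
  // shape and element-type information on which all yielding branches agree,
  // and incompatible results are simply skipped.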
1736 | for (auto yieldOp : yieldOps) { |
1737 | if (resultKnowledge.size() != yieldOp.getNumOperands()) |
1738 | return failure(); |
1739 | |
1740 | for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { |
1741 | int32_t index = it.index(); |
1742 | auto meet = ValueKnowledge::meet( |
1743 | resultKnowledge[index], |
1744 | ValueKnowledge::getKnowledgeFromType(it.value().getType())); |
1745 | if (!meet) |
1746 | continue; |
1747 | resultKnowledge[index] = meet; |
1748 | } |
1749 | } |
1750 | |
1751 | for (const ValueKnowledge &result : resultKnowledge) { |
1752 | inferredReturnShapes.push_back(result.getShapedTypeComponents()); |
1753 | } |
1754 | |
1755 | return success(); |
1756 | } |
1757 | |
1758 | LogicalResult WhileOp::inferReturnTypeComponents( |
1759 | MLIRContext *context, ::std::optional<Location> location, |
1760 | WhileOp::Adaptor adaptor, |
1761 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1762 | llvm::SmallVector<tosa::YieldOp> yieldOps; |
1763 | for (auto &block : adaptor.getBody()) |
1764 | if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) |
1765 | yieldOps.push_back(returnOp); |
1766 | |
  // A TOSA while must have a tosa.yield as its terminator. If none is found,
  // this tosa.while is invalid and nothing can be inferred.
1769 | if (yieldOps.empty()) |
1770 | return failure(); |
1771 | |
1772 | // Get the initial type information from the operand types. |
1773 | llvm::SmallVector<ValueKnowledge> resultKnowledge; |
1774 | resultKnowledge.reserve(yieldOps.front().getNumOperands()); |
1775 | for (auto operand : yieldOps.front().getOperands()) { |
1776 | resultKnowledge.push_back( |
1777 | ValueKnowledge::getKnowledgeFromType(operand.getType())); |
1778 | } |
1779 | |
1780 | for (auto yieldOp : yieldOps) { |
1781 | if (resultKnowledge.size() != yieldOp.getNumOperands()) |
1782 | return failure(); |
1783 | |
1784 | for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { |
1785 | int32_t index = it.index(); |
1786 | if (auto meet = ValueKnowledge::meet( |
1787 | resultKnowledge[index], |
1788 | ValueKnowledge::getKnowledgeFromType(it.value().getType()))) { |
1789 | resultKnowledge[index] = meet; |
1790 | } |
1791 | } |
1792 | } |
1793 | |
1794 | for (const ValueKnowledge &result : resultKnowledge) { |
1795 | inferredReturnShapes.push_back(result.getShapedTypeComponents()); |
1796 | } |
1797 | |
1798 | return success(); |
1799 | } |
1800 | |
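// Report the vector result shape used when unrolling a vectorized apply_scale;
// a non-vector (scalar) apply_scale returns std::nullopt, i.e. nothing to
// unroll.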
1801 | std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { |
1802 | if (auto vt = llvm::dyn_cast<VectorType>(getType())) |
1803 | return llvm::to_vector<4>(vt.getShape()); |
1804 | return std::nullopt; |
1805 | } |
1806 | |
// The parse and print methods of IfOp follow the SCF dialect implementation.
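// The custom assembly resembles scf.if; an illustrative (not normative) form:
//   tosa.cond_if %condition -> (tensor<f32>) {
//     tosa.yield %then_value : tensor<f32>
//   } else {
//     tosa.yield %else_value : tensor<f32>
//   }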
1808 | ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { |
1809 | // Create the regions for 'then'. |
1810 | result.regions.reserve(2); |
1811 | Region *thenRegion = result.addRegion(); |
1812 | Region *elseRegion = result.addRegion(); |
1813 | |
1814 | auto &builder = parser.getBuilder(); |
1815 | OpAsmParser::UnresolvedOperand cond; |
  // Create an i1 tensor type for the boolean condition.
1817 | Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); |
1818 | if (parser.parseOperand(cond) || |
1819 | parser.resolveOperand(cond, i1Type, result.operands)) |
1820 | return failure(); |
1821 | // Parse optional results type list. |
1822 | if (parser.parseOptionalArrowTypeList(result.types)) |
1823 | return failure(); |
1824 | // Parse the 'then' region. |
1825 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
1826 | return failure(); |
1827 | |
1828 | // If we find an 'else' keyword then parse the 'else' region. |
1829 | if (!parser.parseOptionalKeyword("else" )) { |
1830 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
1831 | return failure(); |
1832 | } |
1833 | |
1834 | // Parse the optional attribute list. |
1835 | if (parser.parseOptionalAttrDict(result.attributes)) |
1836 | return failure(); |
1837 | return success(); |
1838 | } |
1839 | |
1840 | void IfOp::print(OpAsmPrinter &p) { |
1841 | bool printBlockTerminators = false; |
1842 | |
1843 | p << " " << getCond(); |
1844 | if (!getResults().empty()) { |
1845 | p << " -> (" << getResultTypes() << ")" ; |
1846 | // Print yield explicitly if the op defines values. |
1847 | printBlockTerminators = true; |
1848 | } |
1849 | p << ' '; |
1850 | p.printRegion(getThenBranch(), |
1851 | /*printEntryBlockArgs=*/false, |
1852 | /*printBlockTerminators=*/printBlockTerminators); |
1853 | |
  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseBranch();
  if (!elseRegion.empty()) {
    p << " else ";
1858 | p.printRegion(elseRegion, |
1859 | /*printEntryBlockArgs=*/false, |
1860 | /*printBlockTerminators=*/printBlockTerminators); |
1861 | } |
1862 | |
1863 | p.printOptionalAttrDict((*this)->getAttrs()); |
1864 | } |
1865 | |
1866 | LogicalResult ReverseOp::verify() { |
1867 | TensorType inputType = getInput().getType(); |
1868 | TensorType outputType = getOutput().getType(); |
1869 | int32_t reverseAxis = getAxis(); |
1870 | |
1871 | if (reverseAxis < 0) |
1872 | return emitOpError("expected non-negative reverse axis" ); |
1873 | if (inputType.hasRank()) { |
1874 | int64_t inputRank = inputType.getRank(); |
1875 | // We allow for a special case where the input/output shape has rank 0 and |
1876 | // axis is also 0. |
1877 | if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0)) |
1878 | return emitOpError("expect input tensor rank (" ) |
1879 | << inputRank << ") to be larger than reverse axis (" << reverseAxis |
1880 | << ")" ; |
1881 | } |
1882 | if (outputType.hasRank()) { |
1883 | int64_t outputRank = outputType.getRank(); |
1884 | if (inputType.hasRank() && outputRank != inputType.getRank()) |
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
1887 | if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0)) |
1888 | return emitOpError("expect output tensor rank (" ) |
1889 | << outputRank << ") to be larger than reverse axis (" |
1890 | << reverseAxis << ")" ; |
1891 | } |
1892 | return success(); |
1893 | } |
1894 | |
// The parse and print methods of WhileOp follow the SCF dialect implementation.
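// The custom assembly resembles scf.while; an illustrative (not normative) form:
//   tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...
//     tosa.yield %continue_cond : tensor<i1>
//   } do {
//   ^bb0(%arg1: tensor<i32>):
//     ...
//     tosa.yield %next_value : tensor<i32>
//   }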
1896 | ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { |
1897 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
1898 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
1899 | Region *cond = result.addRegion(); |
1900 | Region *body = result.addRegion(); |
1901 | |
1902 | OptionalParseResult listResult = |
1903 | parser.parseOptionalAssignmentList(regionArgs, operands); |
1904 | if (listResult.has_value() && failed(listResult.value())) |
1905 | return failure(); |
1906 | |
1907 | FunctionType functionType; |
1908 | SMLoc typeLoc = parser.getCurrentLocation(); |
1909 | if (failed(parser.parseColonType(functionType))) |
1910 | return failure(); |
1911 | |
1912 | result.addTypes(functionType.getResults()); |
1913 | |
1914 | if (functionType.getNumInputs() != operands.size()) { |
1915 | return parser.emitError(typeLoc) |
1916 | << "expected as many input types as operands " |
1917 | << "(expected " << operands.size() << " got " |
           << functionType.getNumInputs() << ")";
1919 | } |
1920 | |
1921 | // Resolve input operands. |
1922 | if (failed(parser.resolveOperands(operands, functionType.getInputs(), |
1923 | parser.getCurrentLocation(), |
1924 | result.operands))) |
1925 | return failure(); |
1926 | |
1927 | // Propagate the types into the region arguments. |
1928 | for (size_t i = 0, e = regionArgs.size(); i != e; ++i) |
1929 | regionArgs[i].type = functionType.getInput(i); |
1930 | |
1931 | return failure(parser.parseRegion(*cond, regionArgs) || |
1932 | parser.parseKeyword("do" ) || parser.parseRegion(*body) || |
1933 | parser.parseOptionalAttrDictWithKeyword(result.attributes)); |
1934 | } |
1935 | |
1936 | static void printInitializationList(OpAsmPrinter &parser, |
1937 | Block::BlockArgListType blocksArgs, |
1938 | ValueRange initializers, |
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
1942 | if (initializers.empty()) |
1943 | return; |
1944 | |
  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
1950 | } |
1951 | |
1952 | void WhileOp::print(OpAsmPrinter &parser) { |
  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
                          " ");
  parser << " : ";
1956 | parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes()); |
1957 | parser << ' '; |
1958 | parser.printRegion(getCond(), /*printEntryBlockArgs=*/false); |
1959 | parser << " do " ; |
1960 | parser.printRegion(getBody()); |
1961 | parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); |
1962 | } |
1963 | |
1964 | //===----------------------------------------------------------------------===// |
1965 | // TOSA Attribute Definitions. |
1966 | //===----------------------------------------------------------------------===// |
1967 | |
1968 | #define GET_ATTRDEF_CLASSES |
1969 | #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" |
1970 | |
1971 | //===----------------------------------------------------------------------===// |
1972 | // TOSA Operator Definitions. |
1973 | //===----------------------------------------------------------------------===// |
1974 | |
1975 | #define GET_OP_CLASSES |
1976 | #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" |
1977 | |