1//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/DestinationStyleOpInterface.h"
10
11using namespace mlir;
12
namespace mlir {
// Textually include the interface's generated definitions (`.cpp.inc`).
// The include sits inside `namespace mlir` so everything it defines lands in
// that namespace. NOTE(review): presumably produced by mlir-tblgen from the
// interface's ODS description — confirm against the build rules.
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
} // namespace mlir
16
17namespace {
18size_t getNumTensorResults(Operation *op) {
19 size_t numTensorResults = 0;
20 for (auto t : op->getResultTypes()) {
21 if (isa<TensorType>(Val: t)) {
22 ++numTensorResults;
23 }
24 }
25 return numTensorResults;
26}
27} // namespace
28
29LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
30 DestinationStyleOpInterface dstStyleOp =
31 cast<DestinationStyleOpInterface>(op);
32
33 SmallVector<OpOperand *> outputTensorOperands;
34 for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
35 Type type = operand.get().getType();
36 if (isa<TensorType>(type)) {
37 outputTensorOperands.push_back(&operand);
38 } else if (!isa<BaseMemRefType>(type)) {
39 return op->emitOpError("expected that operand #")
40 << operand.getOperandNumber() << " is a tensor or a memref";
41 }
42 }
43
44 // Verify the number of tensor results matches the number of output tensors.
45 if (getNumTensorResults(op) != outputTensorOperands.size())
46 return op->emitOpError(message: "expected the number of tensor results (")
47 << getNumTensorResults(op)
48 << ") to be equal to the number of output tensors ("
49 << outputTensorOperands.size() << ")";
50
51 for (OpOperand *opOperand : outputTensorOperands) {
52 OpResult result = dstStyleOp.getTiedOpResult(opOperand);
53 if (result.getType() != opOperand->get().getType())
54 return op->emitOpError(message: "expected type of operand #")
55 << opOperand->getOperandNumber() << " ("
56 << opOperand->get().getType() << ")"
57 << " to match type of corresponding result (" << result.getType()
58 << ")";
59 }
60
61 return success();
62}
63

// Source code of mlir/lib/Interfaces/DestinationStyleOpInterface.cpp