| 1 | //===- CastInterfaces.cpp -------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Interfaces/CastInterfaces.h" |
| 10 | |
| 11 | #include "mlir/IR/BuiltinDialect.h" |
| 12 | #include "mlir/IR/BuiltinOps.h" |
| 13 | |
| 14 | using namespace mlir; |
| 15 | |
| 16 | //===----------------------------------------------------------------------===// |
| 17 | // Helper functions for CastOpInterface |
| 18 | //===----------------------------------------------------------------------===// |
| 19 | |
| 20 | /// Attempt to fold the given cast operation. |
| 21 | LogicalResult |
| 22 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, |
| 23 | SmallVectorImpl<OpFoldResult> &foldResults) { |
| 24 | OperandRange operands = op->getOperands(); |
| 25 | if (operands.empty()) |
| 26 | return failure(); |
| 27 | ResultRange results = op->getResults(); |
| 28 | |
| 29 | // Check for the case where the input and output types match 1-1. |
| 30 | if (operands.getTypes() == results.getTypes()) { |
| 31 | foldResults.append(in_start: operands.begin(), in_end: operands.end()); |
| 32 | return success(); |
| 33 | } |
| 34 | |
| 35 | return failure(); |
| 36 | } |
| 37 | |
| 38 | /// Attempt to verify the given cast operation. |
| 39 | LogicalResult impl::verifyCastInterfaceOp(Operation *op) { |
| 40 | auto resultTypes = op->getResultTypes(); |
| 41 | if (resultTypes.empty()) |
| 42 | return op->emitOpError() |
| 43 | << "expected at least one result for cast operation" ; |
| 44 | |
| 45 | auto operandTypes = op->getOperandTypes(); |
| 46 | if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) { |
| 47 | InFlightDiagnostic diag = op->emitOpError(message: "operand type" ); |
| 48 | if (operandTypes.empty()) |
| 49 | diag << "s []" ; |
| 50 | else if (llvm::size(Range&: operandTypes) == 1) |
| 51 | diag << " " << *operandTypes.begin(); |
| 52 | else |
| 53 | diag << "s " << operandTypes; |
| 54 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s " ) |
| 55 | << resultTypes << " are cast incompatible" ; |
| 56 | } |
| 57 | |
| 58 | return success(); |
| 59 | } |
| 60 | |
| 61 | //===----------------------------------------------------------------------===// |
| 62 | // External model for BuiltinDialect ops |
| 63 | //===----------------------------------------------------------------------===// |
| 64 | |
| 65 | namespace mlir { |
| 66 | namespace { |
| 67 | // This interface cannot be implemented directly on the op because the IR build |
| 68 | // unit cannot depend on the Interfaces build unit. |
| 69 | struct UnrealizedConversionCastOpInterface |
| 70 | : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface, |
| 71 | UnrealizedConversionCastOp> { |
| 72 | static bool areCastCompatible(TypeRange inputs, TypeRange outputs) { |
| 73 | // `UnrealizedConversionCastOp` is agnostic of the input/output types. |
| 74 | return true; |
| 75 | } |
| 76 | }; |
| 77 | } // namespace |
| 78 | } // namespace mlir |
| 79 | |
| 80 | void mlir::builtin::registerCastOpInterfaceExternalModels( |
| 81 | DialectRegistry ®istry) { |
| 82 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
| 83 | UnrealizedConversionCastOp::attachInterface< |
| 84 | UnrealizedConversionCastOpInterface>(*ctx); |
| 85 | }); |
| 86 | } |
| 87 | |
| 88 | //===----------------------------------------------------------------------===// |
| 89 | // Table-generated class definitions |
| 90 | //===----------------------------------------------------------------------===// |
| 91 | |
| 92 | #include "mlir/Interfaces/CastInterfaces.cpp.inc" |
| 93 | |