//===- CastInterfaces.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Interfaces/CastInterfaces.h" |
10 | |
11 | #include "mlir/IR/BuiltinDialect.h" |
12 | #include "mlir/IR/BuiltinOps.h" |
13 | |
14 | using namespace mlir; |
15 | |
//===----------------------------------------------------------------------===//
// Helper functions for CastOpInterface
//===----------------------------------------------------------------------===//
19 | |
20 | /// Attempt to fold the given cast operation. |
21 | LogicalResult |
22 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, |
23 | SmallVectorImpl<OpFoldResult> &foldResults) { |
24 | OperandRange operands = op->getOperands(); |
25 | if (operands.empty()) |
26 | return failure(); |
27 | ResultRange results = op->getResults(); |
28 | |
29 | // Check for the case where the input and output types match 1-1. |
30 | if (operands.getTypes() == results.getTypes()) { |
31 | foldResults.append(in_start: operands.begin(), in_end: operands.end()); |
32 | return success(); |
33 | } |
34 | |
35 | return failure(); |
36 | } |
37 | |
38 | /// Attempt to verify the given cast operation. |
39 | LogicalResult impl::verifyCastInterfaceOp(Operation *op) { |
40 | auto resultTypes = op->getResultTypes(); |
41 | if (resultTypes.empty()) |
42 | return op->emitOpError() |
43 | << "expected at least one result for cast operation" ; |
44 | |
45 | auto operandTypes = op->getOperandTypes(); |
46 | if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) { |
47 | InFlightDiagnostic diag = op->emitOpError(message: "operand type" ); |
48 | if (operandTypes.empty()) |
49 | diag << "s []" ; |
50 | else if (llvm::size(Range&: operandTypes) == 1) |
51 | diag << " " << *operandTypes.begin(); |
52 | else |
53 | diag << "s " << operandTypes; |
54 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s " ) |
55 | << resultTypes << " are cast incompatible" ; |
56 | } |
57 | |
58 | return success(); |
59 | } |
60 | |
//===----------------------------------------------------------------------===//
// External model for BuiltinDialect ops
//===----------------------------------------------------------------------===//
64 | |
65 | namespace mlir { |
66 | namespace { |
67 | // This interface cannot be implemented directly on the op because the IR build |
68 | // unit cannot depend on the Interfaces build unit. |
69 | struct UnrealizedConversionCastOpInterface |
70 | : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface, |
71 | UnrealizedConversionCastOp> { |
72 | static bool areCastCompatible(TypeRange inputs, TypeRange outputs) { |
73 | // `UnrealizedConversionCastOp` is agnostic of the input/output types. |
74 | return true; |
75 | } |
76 | }; |
77 | } // namespace |
78 | } // namespace mlir |
79 | |
80 | void mlir::builtin::registerCastOpInterfaceExternalModels( |
81 | DialectRegistry ®istry) { |
82 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
83 | UnrealizedConversionCastOp::attachInterface< |
84 | UnrealizedConversionCastOpInterface>(*ctx); |
85 | }); |
86 | } |
87 | |
//===----------------------------------------------------------------------===//
// Table-generated class definitions
//===----------------------------------------------------------------------===//
91 | |
92 | #include "mlir/Interfaces/CastInterfaces.cpp.inc" |
93 | |