1 | |
2 | |
3 | //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// |
4 | // |
5 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
6 | // See https://llvm.org/LICENSE.txt for license information. |
7 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
8 | // |
9 | //===----------------------------------------------------------------------===// |
10 | |
11 | #include "mlir-c/Interfaces.h" |
12 | |
13 | #include "mlir/CAPI/IR.h" |
14 | #include "mlir/CAPI/Interfaces.h" |
15 | #include "mlir/CAPI/Support.h" |
16 | #include "mlir/CAPI/Wrap.h" |
17 | #include "mlir/IR/ValueRange.h" |
18 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
19 | #include "llvm/ADT/ScopeExit.h" |
20 | #include <optional> |
21 | |
22 | using namespace mlir; |
23 | |
24 | namespace { |
25 | |
26 | std::optional<RegisteredOperationName> |
27 | getRegisteredOperationName(MlirContext context, MlirStringRef opName) { |
28 | StringRef name(opName.data, opName.length); |
29 | std::optional<RegisteredOperationName> info = |
30 | RegisteredOperationName::lookup(name, ctx: unwrap(c: context)); |
31 | return info; |
32 | } |
33 | |
34 | std::optional<Location> maybeGetLocation(MlirLocation location) { |
35 | std::optional<Location> maybeLocation; |
36 | if (!mlirLocationIsNull(location)) |
37 | maybeLocation = unwrap(c: location); |
38 | return maybeLocation; |
39 | } |
40 | |
41 | SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) { |
42 | SmallVector<Value> unwrappedOperands; |
43 | (void)unwrapList(size: nOperands, first: operands, storage&: unwrappedOperands); |
44 | return unwrappedOperands; |
45 | } |
46 | |
47 | DictionaryAttr unwrapAttributes(MlirAttribute attributes) { |
48 | DictionaryAttr attributeDict; |
49 | if (!mlirAttributeIsNull(attr: attributes)) |
50 | attributeDict = llvm::cast<DictionaryAttr>(unwrap(c: attributes)); |
51 | return attributeDict; |
52 | } |
53 | |
54 | SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions, |
55 | MlirRegion *regions) { |
56 | // Create a vector of unique pointers to regions and make sure they are not |
57 | // deleted when exiting the scope. This is a hack caused by C++ API expecting |
58 | // an list of unique pointers to regions (without ownership transfer |
59 | // semantics) and C API making ownership transfer explicit. |
60 | SmallVector<std::unique_ptr<Region>> unwrappedRegions; |
61 | unwrappedRegions.reserve(N: nRegions); |
62 | for (intptr_t i = 0; i < nRegions; ++i) |
63 | unwrappedRegions.emplace_back(Args: unwrap(c: *(regions + i))); |
64 | auto cleaner = llvm::make_scope_exit(F: [&]() { |
65 | for (auto ®ion : unwrappedRegions) |
66 | region.release(); |
67 | }); |
68 | return unwrappedRegions; |
69 | } |
70 | |
71 | } // namespace |
72 | |
73 | bool mlirOperationImplementsInterface(MlirOperation operation, |
74 | MlirTypeID interfaceTypeID) { |
75 | std::optional<RegisteredOperationName> info = |
76 | unwrap(c: operation)->getRegisteredInfo(); |
77 | return info && info->hasInterface(interfaceID: unwrap(c: interfaceTypeID)); |
78 | } |
79 | |
80 | bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, |
81 | MlirContext context, |
82 | MlirTypeID interfaceTypeID) { |
83 | std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup( |
84 | name: StringRef(operationName.data, operationName.length), ctx: unwrap(c: context)); |
85 | return info && info->hasInterface(interfaceID: unwrap(c: interfaceTypeID)); |
86 | } |
87 | |
88 | MlirTypeID mlirInferTypeOpInterfaceTypeID() { |
89 | return wrap(InferTypeOpInterface::getInterfaceID()); |
90 | } |
91 | |
92 | MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( |
93 | MlirStringRef opName, MlirContext context, MlirLocation location, |
94 | intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, |
95 | void *properties, intptr_t nRegions, MlirRegion *regions, |
96 | MlirTypesCallback callback, void *userData) { |
97 | StringRef name(opName.data, opName.length); |
98 | std::optional<RegisteredOperationName> info = |
99 | getRegisteredOperationName(context, opName); |
100 | if (!info) |
101 | return mlirLogicalResultFailure(); |
102 | |
103 | std::optional<Location> maybeLocation = maybeGetLocation(location); |
104 | SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands); |
105 | DictionaryAttr attributeDict = unwrapAttributes(attributes); |
106 | SmallVector<std::unique_ptr<Region>> unwrappedRegions = |
107 | unwrapRegions(nRegions, regions); |
108 | |
109 | SmallVector<Type> inferredTypes; |
110 | if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes( |
111 | unwrap(context), maybeLocation, unwrappedOperands, attributeDict, |
112 | properties, unwrappedRegions, inferredTypes))) |
113 | return mlirLogicalResultFailure(); |
114 | |
115 | SmallVector<MlirType> wrappedInferredTypes; |
116 | wrappedInferredTypes.reserve(N: inferredTypes.size()); |
117 | for (Type t : inferredTypes) |
118 | wrappedInferredTypes.push_back(Elt: wrap(cpp: t)); |
119 | callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); |
120 | return mlirLogicalResultSuccess(); |
121 | } |
122 | |
123 | MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() { |
124 | return wrap(InferShapedTypeOpInterface::getInterfaceID()); |
125 | } |
126 | |
127 | MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes( |
128 | MlirStringRef opName, MlirContext context, MlirLocation location, |
129 | intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, |
130 | void *properties, intptr_t nRegions, MlirRegion *regions, |
131 | MlirShapedTypeComponentsCallback callback, void *userData) { |
132 | std::optional<RegisteredOperationName> info = |
133 | getRegisteredOperationName(context, opName); |
134 | if (!info) |
135 | return mlirLogicalResultFailure(); |
136 | |
137 | std::optional<Location> maybeLocation = maybeGetLocation(location); |
138 | SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands); |
139 | DictionaryAttr attributeDict = unwrapAttributes(attributes); |
140 | SmallVector<std::unique_ptr<Region>> unwrappedRegions = |
141 | unwrapRegions(nRegions, regions); |
142 | |
143 | SmallVector<ShapedTypeComponents> inferredTypeComponents; |
144 | if (failed(info->getInterface<InferShapedTypeOpInterface>() |
145 | ->inferReturnTypeComponents( |
146 | unwrap(context), maybeLocation, |
147 | mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), |
148 | attributeDict, properties, unwrappedRegions, |
149 | inferredTypeComponents))) |
150 | return mlirLogicalResultFailure(); |
151 | |
152 | bool hasRank; |
153 | intptr_t rank; |
154 | const int64_t *shapeData; |
155 | for (const ShapedTypeComponents &t : inferredTypeComponents) { |
156 | if (t.hasRank()) { |
157 | hasRank = true; |
158 | rank = t.getDims().size(); |
159 | shapeData = t.getDims().data(); |
160 | } else { |
161 | hasRank = false; |
162 | rank = 0; |
163 | shapeData = nullptr; |
164 | } |
165 | callback(hasRank, rank, shapeData, wrap(cpp: t.getElementType()), |
166 | wrap(cpp: t.getAttribute()), userData); |
167 | } |
168 | return mlirLogicalResultSuccess(); |
169 | } |
170 | |