1
2
3//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir-c/Interfaces.h"
12
13#include "mlir/CAPI/IR.h"
14#include "mlir/CAPI/Interfaces.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/CAPI/Wrap.h"
17#include "mlir/IR/ValueRange.h"
18#include "mlir/Interfaces/InferTypeOpInterface.h"
19#include "llvm/ADT/ScopeExit.h"
20#include <optional>
21
22using namespace mlir;
23
24namespace {
25
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28 StringRef name(opName.data, opName.length);
29 std::optional<RegisteredOperationName> info =
30 RegisteredOperationName::lookup(name, ctx: unwrap(c: context));
31 return info;
32}
33
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
36 if (!mlirLocationIsNull(location))
37 maybeLocation = unwrap(c: location);
38 return maybeLocation;
39}
40
41SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42 SmallVector<Value> unwrappedOperands;
43 (void)unwrapList(size: nOperands, first: operands, storage&: unwrappedOperands);
44 return unwrappedOperands;
45}
46
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attr: attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(unwrap(c: attributes));
51 return attributeDict;
52}
53
54SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55 MlirRegion *regions) {
56 // Create a vector of unique pointers to regions and make sure they are not
57 // deleted when exiting the scope. This is a hack caused by C++ API expecting
58 // an list of unique pointers to regions (without ownership transfer
59 // semantics) and C API making ownership transfer explicit.
60 SmallVector<std::unique_ptr<Region>> unwrappedRegions;
61 unwrappedRegions.reserve(N: nRegions);
62 for (intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(Args: unwrap(c: *(regions + i)));
64 auto cleaner = llvm::make_scope_exit(F: [&]() {
65 for (auto &region : unwrappedRegions)
66 region.release();
67 });
68 return unwrappedRegions;
69}
70
71} // namespace
72
73bool mlirOperationImplementsInterface(MlirOperation operation,
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(c: operation)->getRegisteredInfo();
77 return info && info->hasInterface(interfaceID: unwrap(c: interfaceTypeID));
78}
79
80bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
81 MlirContext context,
82 MlirTypeID interfaceTypeID) {
83 std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84 name: StringRef(operationName.data, operationName.length), ctx: unwrap(c: context));
85 return info && info->hasInterface(interfaceID: unwrap(c: interfaceTypeID));
86}
87
88MlirTypeID mlirInferTypeOpInterfaceTypeID() {
89 return wrap(InferTypeOpInterface::getInterfaceID());
90}
91
92MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
96 MlirTypesCallback callback, void *userData) {
97 StringRef name(opName.data, opName.length);
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
100 if (!info)
101 return mlirLogicalResultFailure();
102
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
104 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
106 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107 unwrapRegions(nRegions, regions);
108
109 SmallVector<Type> inferredTypes;
110 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112 properties, unwrappedRegions, inferredTypes)))
113 return mlirLogicalResultFailure();
114
115 SmallVector<MlirType> wrappedInferredTypes;
116 wrappedInferredTypes.reserve(N: inferredTypes.size());
117 for (Type t : inferredTypes)
118 wrappedInferredTypes.push_back(Elt: wrap(cpp: t));
119 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
120 return mlirLogicalResultSuccess();
121}
122
123MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
124 return wrap(InferShapedTypeOpInterface::getInterfaceID());
125}
126
127MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
128 MlirStringRef opName, MlirContext context, MlirLocation location,
129 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130 void *properties, intptr_t nRegions, MlirRegion *regions,
131 MlirShapedTypeComponentsCallback callback, void *userData) {
132 std::optional<RegisteredOperationName> info =
133 getRegisteredOperationName(context, opName);
134 if (!info)
135 return mlirLogicalResultFailure();
136
137 std::optional<Location> maybeLocation = maybeGetLocation(location);
138 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139 DictionaryAttr attributeDict = unwrapAttributes(attributes);
140 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141 unwrapRegions(nRegions, regions);
142
143 SmallVector<ShapedTypeComponents> inferredTypeComponents;
144 if (failed(info->getInterface<InferShapedTypeOpInterface>()
145 ->inferReturnTypeComponents(
146 unwrap(context), maybeLocation,
147 mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148 attributeDict, properties, unwrappedRegions,
149 inferredTypeComponents)))
150 return mlirLogicalResultFailure();
151
152 bool hasRank;
153 intptr_t rank;
154 const int64_t *shapeData;
155 for (const ShapedTypeComponents &t : inferredTypeComponents) {
156 if (t.hasRank()) {
157 hasRank = true;
158 rank = t.getDims().size();
159 shapeData = t.getDims().data();
160 } else {
161 hasRank = false;
162 rank = 0;
163 shapeData = nullptr;
164 }
165 callback(hasRank, rank, shapeData, wrap(cpp: t.getElementType()),
166 wrap(cpp: t.getAttribute()), userData);
167 }
168 return mlirLogicalResultSuccess();
169}
170

source code of mlir/lib/CAPI/Interfaces/Interfaces.cpp