1 | //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ThreadPool.h"

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <optional>
36 | |
37 | using namespace mlir; |
38 | |
39 | //===----------------------------------------------------------------------===// |
40 | // Context API. |
41 | //===----------------------------------------------------------------------===// |
42 | |
43 | MlirContext mlirContextCreate() { |
44 | auto *context = new MLIRContext; |
45 | return wrap(cpp: context); |
46 | } |
47 | |
48 | static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { |
49 | return threadingEnabled ? MLIRContext::Threading::ENABLED |
50 | : MLIRContext::Threading::DISABLED; |
51 | } |
52 | |
53 | MlirContext mlirContextCreateWithThreading(bool threadingEnabled) { |
54 | auto *context = new MLIRContext(toThreadingEnum(threadingEnabled)); |
55 | return wrap(cpp: context); |
56 | } |
57 | |
58 | MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry, |
59 | bool threadingEnabled) { |
60 | auto *context = |
61 | new MLIRContext(*unwrap(c: registry), toThreadingEnum(threadingEnabled)); |
62 | return wrap(cpp: context); |
63 | } |
64 | |
65 | bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { |
66 | return unwrap(c: ctx1) == unwrap(c: ctx2); |
67 | } |
68 | |
69 | void mlirContextDestroy(MlirContext context) { delete unwrap(c: context); } |
70 | |
71 | void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { |
72 | unwrap(c: context)->allowUnregisteredDialects(allow); |
73 | } |
74 | |
75 | bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { |
76 | return unwrap(c: context)->allowsUnregisteredDialects(); |
77 | } |
78 | intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { |
79 | return static_cast<intptr_t>(unwrap(c: context)->getAvailableDialects().size()); |
80 | } |
81 | |
82 | void mlirContextAppendDialectRegistry(MlirContext ctx, |
83 | MlirDialectRegistry registry) { |
84 | unwrap(c: ctx)->appendDialectRegistry(registry: *unwrap(c: registry)); |
85 | } |
86 | |
87 | // TODO: expose a cheaper way than constructing + sorting a vector only to take |
88 | // its size. |
89 | intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { |
90 | return static_cast<intptr_t>(unwrap(c: context)->getLoadedDialects().size()); |
91 | } |
92 | |
93 | MlirDialect mlirContextGetOrLoadDialect(MlirContext context, |
94 | MlirStringRef name) { |
95 | return wrap(cpp: unwrap(c: context)->getOrLoadDialect(name: unwrap(ref: name))); |
96 | } |
97 | |
98 | bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { |
99 | return unwrap(c: context)->isOperationRegistered(name: unwrap(ref: name)); |
100 | } |
101 | |
102 | void mlirContextEnableMultithreading(MlirContext context, bool enable) { |
103 | return unwrap(c: context)->enableMultithreading(enable); |
104 | } |
105 | |
106 | void mlirContextLoadAllAvailableDialects(MlirContext context) { |
107 | unwrap(c: context)->loadAllAvailableDialects(); |
108 | } |
109 | |
110 | void mlirContextSetThreadPool(MlirContext context, |
111 | MlirLlvmThreadPool threadPool) { |
112 | unwrap(c: context)->setThreadPool(*unwrap(c: threadPool)); |
113 | } |
114 | |
115 | //===----------------------------------------------------------------------===// |
116 | // Dialect API. |
117 | //===----------------------------------------------------------------------===// |
118 | |
119 | MlirContext mlirDialectGetContext(MlirDialect dialect) { |
120 | return wrap(cpp: unwrap(c: dialect)->getContext()); |
121 | } |
122 | |
123 | bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { |
124 | return unwrap(c: dialect1) == unwrap(c: dialect2); |
125 | } |
126 | |
127 | MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { |
128 | return wrap(ref: unwrap(c: dialect)->getNamespace()); |
129 | } |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // DialectRegistry API. |
133 | //===----------------------------------------------------------------------===// |
134 | |
135 | MlirDialectRegistry mlirDialectRegistryCreate() { |
136 | return wrap(cpp: new DialectRegistry()); |
137 | } |
138 | |
139 | void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { |
140 | delete unwrap(c: registry); |
141 | } |
142 | |
143 | //===----------------------------------------------------------------------===// |
144 | // AsmState API. |
145 | //===----------------------------------------------------------------------===// |
146 | |
147 | MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, |
148 | MlirOpPrintingFlags flags) { |
149 | return wrap(cpp: new AsmState(unwrap(c: op), *unwrap(c: flags))); |
150 | } |
151 | |
152 | static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
153 | do { |
154 | // If we are printing local scope, stop at the first operation that is |
155 | // isolated from above. |
156 | if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
157 | break; |
158 | |
159 | // Otherwise, traverse up to the next parent. |
160 | Operation *parentOp = op->getParentOp(); |
161 | if (!parentOp) |
162 | break; |
163 | op = parentOp; |
164 | } while (true); |
165 | return op; |
166 | } |
167 | |
168 | MlirAsmState mlirAsmStateCreateForValue(MlirValue value, |
169 | MlirOpPrintingFlags flags) { |
170 | Operation *op; |
171 | mlir::Value val = unwrap(c: value); |
172 | if (auto result = llvm::dyn_cast<OpResult>(Val&: val)) { |
173 | op = result.getOwner(); |
174 | } else { |
175 | op = llvm::cast<BlockArgument>(Val&: val).getOwner()->getParentOp(); |
176 | if (!op) { |
177 | emitError(loc: val.getLoc()) << "<<UNKNOWN SSA VALUE>>" ; |
178 | return {.ptr: nullptr}; |
179 | } |
180 | } |
181 | op = findParent(op, shouldUseLocalScope: unwrap(c: flags)->shouldUseLocalScope()); |
182 | return wrap(cpp: new AsmState(op, *unwrap(c: flags))); |
183 | } |
184 | |
185 | /// Destroys printing flags created with mlirAsmStateCreate. |
186 | void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(c: state); } |
187 | |
188 | //===----------------------------------------------------------------------===// |
189 | // Printing flags API. |
190 | //===----------------------------------------------------------------------===// |
191 | |
192 | MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { |
193 | return wrap(cpp: new OpPrintingFlags()); |
194 | } |
195 | |
196 | void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { |
197 | delete unwrap(c: flags); |
198 | } |
199 | |
200 | void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, |
201 | intptr_t largeElementLimit) { |
202 | unwrap(c: flags)->elideLargeElementsAttrs(largeElementLimit); |
203 | } |
204 | |
205 | void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, |
206 | bool prettyForm) { |
207 | unwrap(c: flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); |
208 | } |
209 | |
210 | void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { |
211 | unwrap(c: flags)->printGenericOpForm(); |
212 | } |
213 | |
214 | void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { |
215 | unwrap(c: flags)->useLocalScope(); |
216 | } |
217 | |
218 | void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { |
219 | unwrap(c: flags)->assumeVerified(); |
220 | } |
221 | |
222 | //===----------------------------------------------------------------------===// |
223 | // Bytecode printing flags API. |
224 | //===----------------------------------------------------------------------===// |
225 | |
226 | MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() { |
227 | return wrap(cpp: new BytecodeWriterConfig()); |
228 | } |
229 | |
230 | void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) { |
231 | delete unwrap(c: config); |
232 | } |
233 | |
234 | void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, |
235 | int64_t version) { |
236 | unwrap(c: flags)->setDesiredBytecodeVersion(version); |
237 | } |
238 | |
239 | //===----------------------------------------------------------------------===// |
240 | // Location API. |
241 | //===----------------------------------------------------------------------===// |
242 | |
243 | MlirAttribute mlirLocationGetAttribute(MlirLocation location) { |
244 | return wrap(cpp: LocationAttr(unwrap(c: location))); |
245 | } |
246 | |
247 | MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { |
248 | return wrap(cpp: Location(llvm::cast<LocationAttr>(Val: unwrap(c: attribute)))); |
249 | } |
250 | |
251 | MlirLocation mlirLocationFileLineColGet(MlirContext context, |
252 | MlirStringRef filename, unsigned line, |
253 | unsigned col) { |
254 | return wrap(Location( |
255 | FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); |
256 | } |
257 | |
258 | MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { |
259 | return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); |
260 | } |
261 | |
262 | MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, |
263 | MlirLocation const *locations, |
264 | MlirAttribute metadata) { |
265 | SmallVector<Location, 4> locs; |
266 | ArrayRef<Location> unwrappedLocs = unwrapList(size: nLocations, first: locations, storage&: locs); |
267 | return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); |
268 | } |
269 | |
270 | MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, |
271 | MlirLocation childLoc) { |
272 | if (mlirLocationIsNull(childLoc)) |
273 | return wrap( |
274 | Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); |
275 | return wrap(Location(NameLoc::get( |
276 | StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); |
277 | } |
278 | |
279 | MlirLocation mlirLocationUnknownGet(MlirContext context) { |
280 | return wrap(Location(UnknownLoc::get(unwrap(context)))); |
281 | } |
282 | |
283 | bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { |
284 | return unwrap(c: l1) == unwrap(c: l2); |
285 | } |
286 | |
287 | MlirContext mlirLocationGetContext(MlirLocation location) { |
288 | return wrap(cpp: unwrap(c: location).getContext()); |
289 | } |
290 | |
291 | void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, |
292 | void *userData) { |
293 | detail::CallbackOstream stream(callback, userData); |
294 | unwrap(c: location).print(os&: stream); |
295 | } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | // Module API. |
299 | //===----------------------------------------------------------------------===// |
300 | |
301 | MlirModule mlirModuleCreateEmpty(MlirLocation location) { |
302 | return wrap(ModuleOp::create(unwrap(location))); |
303 | } |
304 | |
305 | MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { |
306 | OwningOpRef<ModuleOp> owning = |
307 | parseSourceString<ModuleOp>(unwrap(ref: module), unwrap(c: context)); |
308 | if (!owning) |
309 | return MlirModule{.ptr: nullptr}; |
310 | return MlirModule{owning.release().getOperation()}; |
311 | } |
312 | |
313 | MlirContext mlirModuleGetContext(MlirModule module) { |
314 | return wrap(unwrap(module).getContext()); |
315 | } |
316 | |
317 | MlirBlock mlirModuleGetBody(MlirModule module) { |
318 | return wrap(unwrap(module).getBody()); |
319 | } |
320 | |
321 | void mlirModuleDestroy(MlirModule module) { |
322 | // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is |
323 | // called. |
324 | OwningOpRef<ModuleOp>(unwrap(module)); |
325 | } |
326 | |
327 | MlirOperation mlirModuleGetOperation(MlirModule module) { |
328 | return wrap(unwrap(module).getOperation()); |
329 | } |
330 | |
331 | MlirModule mlirModuleFromOperation(MlirOperation op) { |
332 | return wrap(dyn_cast<ModuleOp>(unwrap(c: op))); |
333 | } |
334 | |
335 | //===----------------------------------------------------------------------===// |
336 | // Operation state API. |
337 | //===----------------------------------------------------------------------===// |
338 | |
339 | MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { |
340 | MlirOperationState state; |
341 | state.name = name; |
342 | state.location = loc; |
343 | state.nResults = 0; |
344 | state.results = nullptr; |
345 | state.nOperands = 0; |
346 | state.operands = nullptr; |
347 | state.nRegions = 0; |
348 | state.regions = nullptr; |
349 | state.nSuccessors = 0; |
350 | state.successors = nullptr; |
351 | state.nAttributes = 0; |
352 | state.attributes = nullptr; |
353 | state.enableResultTypeInference = false; |
354 | return state; |
355 | } |
356 | |
/// Appends `n` elements from `elems` to the malloc-allocated array `storage`,
/// which currently holds `size` elements, and updates `size`.
///
/// The array is grown with realloc so it stays compatible with the free()
/// calls that release the state arrays in mlirOperationCreate. An allocation
/// failure aborts the process, because the C API has no way to report it.
/// The previous code wrote through the null pointer returned by realloc and
/// leaked the old buffer.
template <typename T>
static void appendElems(T *&storage, intptr_t &size, intptr_t n,
                        const T *elems) {
  // Nothing to append: skip realloc(..., 0) and memcpy on a null pointer.
  if (n <= 0)
    return;
  void *grown = std::realloc(
      storage, static_cast<std::size_t>(size + n) * sizeof(T));
  if (!grown)
    std::abort();
  storage = static_cast<T *>(grown);
  std::memcpy(storage + size, elems, static_cast<std::size_t>(n) * sizeof(T));
  size += n;
}

/// Appends the `n` elements of the local `elemName` array to
/// `state->elemName`, growing `state->sizeName`. The expansion site must have
/// `state` and `n` in scope.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  appendElems<type>(state->elemName, state->sizeName, n, elemName);
362 | |
363 | void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, |
364 | MlirType const *results) { |
365 | APPEND_ELEMS(MlirType, nResults, results); |
366 | } |
367 | |
368 | void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, |
369 | MlirValue const *operands) { |
370 | APPEND_ELEMS(MlirValue, nOperands, operands); |
371 | } |
372 | void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, |
373 | MlirRegion const *regions) { |
374 | APPEND_ELEMS(MlirRegion, nRegions, regions); |
375 | } |
376 | void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, |
377 | MlirBlock const *successors) { |
378 | APPEND_ELEMS(MlirBlock, nSuccessors, successors); |
379 | } |
380 | void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, |
381 | MlirNamedAttribute const *attributes) { |
382 | APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); |
383 | } |
384 | |
385 | void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { |
386 | state->enableResultTypeInference = true; |
387 | } |
388 | |
389 | //===----------------------------------------------------------------------===// |
390 | // Operation API. |
391 | //===----------------------------------------------------------------------===// |
392 | |
393 | static LogicalResult inferOperationTypes(OperationState &state) { |
394 | MLIRContext *context = state.getContext(); |
395 | std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); |
396 | if (!info) { |
397 | emitError(loc: state.location) |
398 | << "type inference was requested for the operation " << state.name |
399 | << ", but the operation was not registered; ensure that the dialect " |
400 | "containing the operation is linked into MLIR and registered with " |
401 | "the context" ; |
402 | return failure(); |
403 | } |
404 | |
405 | auto *inferInterface = info->getInterface<InferTypeOpInterface>(); |
406 | if (!inferInterface) { |
407 | emitError(loc: state.location) |
408 | << "type inference was requested for the operation " << state.name |
409 | << ", but the operation does not support type inference; result " |
410 | "types must be specified explicitly" ; |
411 | return failure(); |
412 | } |
413 | |
414 | DictionaryAttr attributes = state.attributes.getDictionary(context); |
415 | OpaqueProperties properties = state.getRawProperties(); |
416 | |
417 | if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { |
418 | auto prop = std::make_unique<char[]>(num: info->getOpPropertyByteSize()); |
419 | properties = OpaqueProperties(prop.get()); |
420 | if (properties) { |
421 | auto emitError = [&]() { |
422 | return mlir::emitError(state.location) |
423 | << " failed properties conversion while building " |
424 | << state.name.getStringRef() << " with `" << attributes << "`: " ; |
425 | }; |
426 | if (failed(info->setOpPropertiesFromAttribute(opName: state.name, properties, |
427 | attr: attributes, emitError))) |
428 | return failure(); |
429 | } |
430 | if (succeeded(inferInterface->inferReturnTypes( |
431 | context, state.location, state.operands, attributes, properties, |
432 | state.regions, state.types))) { |
433 | return success(); |
434 | } |
435 | // Diagnostic emitted by interface. |
436 | return failure(); |
437 | } |
438 | |
439 | if (succeeded(inferInterface->inferReturnTypes( |
440 | context, state.location, state.operands, attributes, properties, |
441 | state.regions, state.types))) |
442 | return success(); |
443 | |
444 | // Diagnostic emitted by interface. |
445 | return failure(); |
446 | } |
447 | |
448 | MlirOperation mlirOperationCreate(MlirOperationState *state) { |
449 | assert(state); |
450 | OperationState cppState(unwrap(c: state->location), unwrap(ref: state->name)); |
451 | SmallVector<Type, 4> resultStorage; |
452 | SmallVector<Value, 8> operandStorage; |
453 | SmallVector<Block *, 2> successorStorage; |
454 | cppState.addTypes(newTypes: unwrapList(size: state->nResults, first: state->results, storage&: resultStorage)); |
455 | cppState.addOperands( |
456 | newOperands: unwrapList(size: state->nOperands, first: state->operands, storage&: operandStorage)); |
457 | cppState.addSuccessors( |
458 | newSuccessors: unwrapList(size: state->nSuccessors, first: state->successors, storage&: successorStorage)); |
459 | |
460 | cppState.attributes.reserve(N: state->nAttributes); |
461 | for (intptr_t i = 0; i < state->nAttributes; ++i) |
462 | cppState.addAttribute(unwrap(state->attributes[i].name), |
463 | unwrap(c: state->attributes[i].attribute)); |
464 | |
465 | for (intptr_t i = 0; i < state->nRegions; ++i) |
466 | cppState.addRegion(region: std::unique_ptr<Region>(unwrap(c: state->regions[i]))); |
467 | |
468 | free(ptr: state->results); |
469 | free(ptr: state->operands); |
470 | free(ptr: state->successors); |
471 | free(ptr: state->regions); |
472 | free(ptr: state->attributes); |
473 | |
474 | // Infer result types. |
475 | if (state->enableResultTypeInference) { |
476 | assert(cppState.types.empty() && |
477 | "result type inference enabled and result types provided" ); |
478 | if (failed(result: inferOperationTypes(state&: cppState))) |
479 | return {.ptr: nullptr}; |
480 | } |
481 | |
482 | return wrap(cpp: Operation::create(state: cppState)); |
483 | } |
484 | |
485 | MlirOperation mlirOperationCreateParse(MlirContext context, |
486 | MlirStringRef sourceStr, |
487 | MlirStringRef sourceName) { |
488 | |
489 | return wrap( |
490 | cpp: parseSourceString(sourceStr: unwrap(ref: sourceStr), config: unwrap(c: context), sourceName: unwrap(ref: sourceName)) |
491 | .release()); |
492 | } |
493 | |
494 | MlirOperation mlirOperationClone(MlirOperation op) { |
495 | return wrap(cpp: unwrap(c: op)->clone()); |
496 | } |
497 | |
498 | void mlirOperationDestroy(MlirOperation op) { unwrap(c: op)->erase(); } |
499 | |
500 | void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(c: op)->remove(); } |
501 | |
502 | bool mlirOperationEqual(MlirOperation op, MlirOperation other) { |
503 | return unwrap(c: op) == unwrap(c: other); |
504 | } |
505 | |
506 | MlirContext mlirOperationGetContext(MlirOperation op) { |
507 | return wrap(cpp: unwrap(c: op)->getContext()); |
508 | } |
509 | |
510 | MlirLocation mlirOperationGetLocation(MlirOperation op) { |
511 | return wrap(cpp: unwrap(c: op)->getLoc()); |
512 | } |
513 | |
514 | MlirTypeID mlirOperationGetTypeID(MlirOperation op) { |
515 | if (auto info = unwrap(c: op)->getRegisteredInfo()) |
516 | return wrap(cpp: info->getTypeID()); |
517 | return {.ptr: nullptr}; |
518 | } |
519 | |
520 | MlirIdentifier mlirOperationGetName(MlirOperation op) { |
521 | return wrap(unwrap(c: op)->getName().getIdentifier()); |
522 | } |
523 | |
524 | MlirBlock mlirOperationGetBlock(MlirOperation op) { |
525 | return wrap(cpp: unwrap(c: op)->getBlock()); |
526 | } |
527 | |
528 | MlirOperation mlirOperationGetParentOperation(MlirOperation op) { |
529 | return wrap(cpp: unwrap(c: op)->getParentOp()); |
530 | } |
531 | |
532 | intptr_t mlirOperationGetNumRegions(MlirOperation op) { |
533 | return static_cast<intptr_t>(unwrap(c: op)->getNumRegions()); |
534 | } |
535 | |
536 | MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { |
537 | return wrap(cpp: &unwrap(c: op)->getRegion(index: static_cast<unsigned>(pos))); |
538 | } |
539 | |
540 | MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { |
541 | Operation *cppOp = unwrap(c: op); |
542 | if (cppOp->getNumRegions() == 0) |
543 | return wrap(cpp: static_cast<Region *>(nullptr)); |
544 | return wrap(cpp: &cppOp->getRegion(index: 0)); |
545 | } |
546 | |
547 | MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { |
548 | Region *cppRegion = unwrap(c: region); |
549 | Operation *parent = cppRegion->getParentOp(); |
550 | intptr_t next = cppRegion->getRegionNumber() + 1; |
551 | if (parent->getNumRegions() > next) |
552 | return wrap(cpp: &parent->getRegion(index: next)); |
553 | return wrap(cpp: static_cast<Region *>(nullptr)); |
554 | } |
555 | |
556 | MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { |
557 | return wrap(unwrap(c: op)->getNextNode()); |
558 | } |
559 | |
560 | intptr_t mlirOperationGetNumOperands(MlirOperation op) { |
561 | return static_cast<intptr_t>(unwrap(c: op)->getNumOperands()); |
562 | } |
563 | |
564 | MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { |
565 | return wrap(cpp: unwrap(c: op)->getOperand(idx: static_cast<unsigned>(pos))); |
566 | } |
567 | |
568 | void mlirOperationSetOperand(MlirOperation op, intptr_t pos, |
569 | MlirValue newValue) { |
570 | unwrap(c: op)->setOperand(idx: static_cast<unsigned>(pos), value: unwrap(c: newValue)); |
571 | } |
572 | |
573 | void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands, |
574 | MlirValue const *operands) { |
575 | SmallVector<Value> ops; |
576 | unwrap(c: op)->setOperands(unwrapList(size: nOperands, first: operands, storage&: ops)); |
577 | } |
578 | |
579 | intptr_t mlirOperationGetNumResults(MlirOperation op) { |
580 | return static_cast<intptr_t>(unwrap(c: op)->getNumResults()); |
581 | } |
582 | |
583 | MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { |
584 | return wrap(cpp: unwrap(c: op)->getResult(idx: static_cast<unsigned>(pos))); |
585 | } |
586 | |
587 | intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { |
588 | return static_cast<intptr_t>(unwrap(c: op)->getNumSuccessors()); |
589 | } |
590 | |
591 | MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { |
592 | return wrap(cpp: unwrap(c: op)->getSuccessor(index: static_cast<unsigned>(pos))); |
593 | } |
594 | |
595 | MLIR_CAPI_EXPORTED bool |
596 | mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { |
597 | std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name)); |
598 | return attr.has_value(); |
599 | } |
600 | |
601 | MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, |
602 | MlirStringRef name) { |
603 | std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name)); |
604 | if (attr.has_value()) |
605 | return wrap(cpp: *attr); |
606 | return {}; |
607 | } |
608 | |
609 | void mlirOperationSetInherentAttributeByName(MlirOperation op, |
610 | MlirStringRef name, |
611 | MlirAttribute attr) { |
612 | unwrap(op)->setInherentAttr( |
613 | StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); |
614 | } |
615 | |
616 | intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { |
617 | return static_cast<intptr_t>( |
618 | llvm::range_size(unwrap(c: op)->getDiscardableAttrs())); |
619 | } |
620 | |
621 | MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, |
622 | intptr_t pos) { |
623 | NamedAttribute attr = |
624 | *std::next(unwrap(c: op)->getDiscardableAttrs().begin(), pos); |
625 | return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())}; |
626 | } |
627 | |
628 | MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, |
629 | MlirStringRef name) { |
630 | return wrap(cpp: unwrap(c: op)->getDiscardableAttr(name: unwrap(ref: name))); |
631 | } |
632 | |
633 | void mlirOperationSetDiscardableAttributeByName(MlirOperation op, |
634 | MlirStringRef name, |
635 | MlirAttribute attr) { |
636 | unwrap(c: op)->setDiscardableAttr(name: unwrap(ref: name), value: unwrap(c: attr)); |
637 | } |
638 | |
639 | bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, |
640 | MlirStringRef name) { |
641 | return !!unwrap(c: op)->removeDiscardableAttr(name: unwrap(ref: name)); |
642 | } |
643 | |
644 | void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, |
645 | MlirBlock block) { |
646 | unwrap(c: op)->setSuccessor(block: unwrap(c: block), index: static_cast<unsigned>(pos)); |
647 | } |
648 | |
649 | intptr_t mlirOperationGetNumAttributes(MlirOperation op) { |
650 | return static_cast<intptr_t>(unwrap(c: op)->getAttrs().size()); |
651 | } |
652 | |
653 | MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { |
654 | NamedAttribute attr = unwrap(c: op)->getAttrs()[pos]; |
655 | return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())}; |
656 | } |
657 | |
658 | MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, |
659 | MlirStringRef name) { |
660 | return wrap(cpp: unwrap(c: op)->getAttr(name: unwrap(ref: name))); |
661 | } |
662 | |
663 | void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, |
664 | MlirAttribute attr) { |
665 | unwrap(c: op)->setAttr(name: unwrap(ref: name), value: unwrap(c: attr)); |
666 | } |
667 | |
668 | bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { |
669 | return !!unwrap(c: op)->removeAttr(name: unwrap(ref: name)); |
670 | } |
671 | |
672 | void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, |
673 | void *userData) { |
674 | detail::CallbackOstream stream(callback, userData); |
675 | unwrap(c: op)->print(os&: stream); |
676 | } |
677 | |
678 | void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, |
679 | MlirStringCallback callback, void *userData) { |
680 | detail::CallbackOstream stream(callback, userData); |
681 | unwrap(c: op)->print(os&: stream, flags: *unwrap(c: flags)); |
682 | } |
683 | |
684 | void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, |
685 | MlirStringCallback callback, void *userData) { |
686 | detail::CallbackOstream stream(callback, userData); |
687 | if (state.ptr) |
688 | unwrap(c: op)->print(os&: stream, state&: *unwrap(c: state)); |
689 | unwrap(c: op)->print(os&: stream); |
690 | } |
691 | |
692 | void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, |
693 | void *userData) { |
694 | detail::CallbackOstream stream(callback, userData); |
695 | // As no desired version is set, no failure can occur. |
696 | (void)writeBytecodeToFile(op: unwrap(c: op), os&: stream); |
697 | } |
698 | |
699 | MlirLogicalResult mlirOperationWriteBytecodeWithConfig( |
700 | MlirOperation op, MlirBytecodeWriterConfig config, |
701 | MlirStringCallback callback, void *userData) { |
702 | detail::CallbackOstream stream(callback, userData); |
703 | return wrap(res: writeBytecodeToFile(op: unwrap(c: op), os&: stream, config: *unwrap(c: config))); |
704 | } |
705 | |
706 | void mlirOperationDump(MlirOperation op) { return unwrap(c: op)->dump(); } |
707 | |
708 | bool mlirOperationVerify(MlirOperation op) { |
709 | return succeeded(result: verify(op: unwrap(c: op))); |
710 | } |
711 | |
712 | void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { |
713 | return unwrap(c: op)->moveAfter(existingOp: unwrap(c: other)); |
714 | } |
715 | |
716 | void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { |
717 | return unwrap(c: op)->moveBefore(existingOp: unwrap(c: other)); |
718 | } |
719 | |
720 | static mlir::WalkResult unwrap(MlirWalkResult result) { |
721 | switch (result) { |
722 | case MlirWalkResultAdvance: |
723 | return mlir::WalkResult::advance(); |
724 | |
725 | case MlirWalkResultInterrupt: |
726 | return mlir::WalkResult::interrupt(); |
727 | |
728 | case MlirWalkResultSkip: |
729 | return mlir::WalkResult::skip(); |
730 | } |
731 | } |
732 | |
733 | void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, |
734 | void *userData, MlirWalkOrder walkOrder) { |
735 | switch (walkOrder) { |
736 | |
737 | case MlirWalkPreOrder: |
738 | unwrap(c: op)->walk<mlir::WalkOrder::PreOrder>( |
739 | callback: [callback, userData](Operation *op) { |
740 | return unwrap(result: callback(wrap(cpp: op), userData)); |
741 | }); |
742 | break; |
743 | case MlirWalkPostOrder: |
744 | unwrap(c: op)->walk<mlir::WalkOrder::PostOrder>( |
745 | callback: [callback, userData](Operation *op) { |
746 | return unwrap(result: callback(wrap(cpp: op), userData)); |
747 | }); |
748 | } |
749 | } |
750 | |
751 | //===----------------------------------------------------------------------===// |
752 | // Region API. |
753 | //===----------------------------------------------------------------------===// |
754 | |
755 | MlirRegion mlirRegionCreate() { return wrap(cpp: new Region); } |
756 | |
757 | bool mlirRegionEqual(MlirRegion region, MlirRegion other) { |
758 | return unwrap(c: region) == unwrap(c: other); |
759 | } |
760 | |
761 | MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { |
762 | Region *cppRegion = unwrap(c: region); |
763 | if (cppRegion->empty()) |
764 | return wrap(cpp: static_cast<Block *>(nullptr)); |
765 | return wrap(cpp: &cppRegion->front()); |
766 | } |
767 | |
768 | void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { |
769 | unwrap(c: region)->push_back(block: unwrap(c: block)); |
770 | } |
771 | |
772 | void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, |
773 | MlirBlock block) { |
774 | auto &blockList = unwrap(c: region)->getBlocks(); |
775 | blockList.insert(where: std::next(x: blockList.begin(), n: pos), New: unwrap(c: block)); |
776 | } |
777 | |
778 | void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, |
779 | MlirBlock block) { |
780 | Region *cppRegion = unwrap(c: region); |
781 | if (mlirBlockIsNull(block: reference)) { |
782 | cppRegion->getBlocks().insert(where: cppRegion->begin(), New: unwrap(c: block)); |
783 | return; |
784 | } |
785 | |
786 | assert(unwrap(reference)->getParent() == unwrap(region) && |
787 | "expected reference block to belong to the region" ); |
788 | cppRegion->getBlocks().insertAfter(where: Region::iterator(unwrap(c: reference)), |
789 | New: unwrap(c: block)); |
790 | } |
791 | |
792 | void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, |
793 | MlirBlock block) { |
794 | if (mlirBlockIsNull(block: reference)) |
795 | return mlirRegionAppendOwnedBlock(region, block); |
796 | |
797 | assert(unwrap(reference)->getParent() == unwrap(region) && |
798 | "expected reference block to belong to the region" ); |
799 | unwrap(c: region)->getBlocks().insert(where: Region::iterator(unwrap(c: reference)), |
800 | New: unwrap(c: block)); |
801 | } |
802 | |
803 | void mlirRegionDestroy(MlirRegion region) { |
804 | delete static_cast<Region *>(region.ptr); |
805 | } |
806 | |
807 | void mlirRegionTakeBody(MlirRegion target, MlirRegion source) { |
808 | unwrap(c: target)->takeBody(other&: *unwrap(c: source)); |
809 | } |
810 | |
811 | //===----------------------------------------------------------------------===// |
812 | // Block API. |
813 | //===----------------------------------------------------------------------===// |
814 | |
815 | MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, |
816 | MlirLocation const *locs) { |
817 | Block *b = new Block; |
818 | for (intptr_t i = 0; i < nArgs; ++i) |
819 | b->addArgument(type: unwrap(c: args[i]), loc: unwrap(c: locs[i])); |
820 | return wrap(cpp: b); |
821 | } |
822 | |
823 | bool mlirBlockEqual(MlirBlock block, MlirBlock other) { |
824 | return unwrap(c: block) == unwrap(c: other); |
825 | } |
826 | |
827 | MlirOperation mlirBlockGetParentOperation(MlirBlock block) { |
828 | return wrap(cpp: unwrap(c: block)->getParentOp()); |
829 | } |
830 | |
831 | MlirRegion mlirBlockGetParentRegion(MlirBlock block) { |
832 | return wrap(cpp: unwrap(c: block)->getParent()); |
833 | } |
834 | |
835 | MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { |
836 | return wrap(cpp: unwrap(c: block)->getNextNode()); |
837 | } |
838 | |
839 | MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { |
840 | Block *cppBlock = unwrap(c: block); |
841 | if (cppBlock->empty()) |
842 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
843 | return wrap(cpp: &cppBlock->front()); |
844 | } |
845 | |
846 | MlirOperation mlirBlockGetTerminator(MlirBlock block) { |
847 | Block *cppBlock = unwrap(c: block); |
848 | if (cppBlock->empty()) |
849 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
850 | Operation &back = cppBlock->back(); |
851 | if (!back.hasTrait<OpTrait::IsTerminator>()) |
852 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
853 | return wrap(cpp: &back); |
854 | } |
855 | |
856 | void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { |
857 | unwrap(c: block)->push_back(op: unwrap(c: operation)); |
858 | } |
859 | |
860 | void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, |
861 | MlirOperation operation) { |
862 | auto &opList = unwrap(c: block)->getOperations(); |
863 | opList.insert(where: std::next(x: opList.begin(), n: pos), New: unwrap(c: operation)); |
864 | } |
865 | |
866 | void mlirBlockInsertOwnedOperationAfter(MlirBlock block, |
867 | MlirOperation reference, |
868 | MlirOperation operation) { |
869 | Block *cppBlock = unwrap(c: block); |
870 | if (mlirOperationIsNull(op: reference)) { |
871 | cppBlock->getOperations().insert(where: cppBlock->begin(), New: unwrap(c: operation)); |
872 | return; |
873 | } |
874 | |
875 | assert(unwrap(reference)->getBlock() == unwrap(block) && |
876 | "expected reference operation to belong to the block" ); |
877 | cppBlock->getOperations().insertAfter(where: Block::iterator(unwrap(c: reference)), |
878 | New: unwrap(c: operation)); |
879 | } |
880 | |
881 | void mlirBlockInsertOwnedOperationBefore(MlirBlock block, |
882 | MlirOperation reference, |
883 | MlirOperation operation) { |
884 | if (mlirOperationIsNull(op: reference)) |
885 | return mlirBlockAppendOwnedOperation(block, operation); |
886 | |
887 | assert(unwrap(reference)->getBlock() == unwrap(block) && |
888 | "expected reference operation to belong to the block" ); |
889 | unwrap(c: block)->getOperations().insert(where: Block::iterator(unwrap(c: reference)), |
890 | New: unwrap(c: operation)); |
891 | } |
892 | |
893 | void mlirBlockDestroy(MlirBlock block) { delete unwrap(c: block); } |
894 | |
895 | void mlirBlockDetach(MlirBlock block) { |
896 | Block *b = unwrap(c: block); |
897 | b->getParent()->getBlocks().remove(IT: b); |
898 | } |
899 | |
900 | intptr_t mlirBlockGetNumArguments(MlirBlock block) { |
901 | return static_cast<intptr_t>(unwrap(c: block)->getNumArguments()); |
902 | } |
903 | |
904 | MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, |
905 | MlirLocation loc) { |
906 | return wrap(cpp: unwrap(c: block)->addArgument(type: unwrap(c: type), loc: unwrap(c: loc))); |
907 | } |
908 | |
909 | MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, |
910 | MlirLocation loc) { |
911 | return wrap(cpp: unwrap(c: block)->insertArgument(index: pos, type: unwrap(c: type), loc: unwrap(c: loc))); |
912 | } |
913 | |
914 | MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { |
915 | return wrap(cpp: unwrap(c: block)->getArgument(i: static_cast<unsigned>(pos))); |
916 | } |
917 | |
918 | void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, |
919 | void *userData) { |
920 | detail::CallbackOstream stream(callback, userData); |
921 | unwrap(c: block)->print(os&: stream); |
922 | } |
923 | |
924 | //===----------------------------------------------------------------------===// |
925 | // Value API. |
926 | //===----------------------------------------------------------------------===// |
927 | |
928 | bool mlirValueEqual(MlirValue value1, MlirValue value2) { |
929 | return unwrap(c: value1) == unwrap(c: value2); |
930 | } |
931 | |
932 | bool mlirValueIsABlockArgument(MlirValue value) { |
933 | return llvm::isa<BlockArgument>(Val: unwrap(c: value)); |
934 | } |
935 | |
936 | bool mlirValueIsAOpResult(MlirValue value) { |
937 | return llvm::isa<OpResult>(Val: unwrap(c: value)); |
938 | } |
939 | |
940 | MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { |
941 | return wrap(cpp: llvm::cast<BlockArgument>(Val: unwrap(c: value)).getOwner()); |
942 | } |
943 | |
944 | intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { |
945 | return static_cast<intptr_t>( |
946 | llvm::cast<BlockArgument>(Val: unwrap(c: value)).getArgNumber()); |
947 | } |
948 | |
949 | void mlirBlockArgumentSetType(MlirValue value, MlirType type) { |
950 | llvm::cast<BlockArgument>(Val: unwrap(c: value)).setType(unwrap(c: type)); |
951 | } |
952 | |
953 | MlirOperation mlirOpResultGetOwner(MlirValue value) { |
954 | return wrap(cpp: llvm::cast<OpResult>(Val: unwrap(c: value)).getOwner()); |
955 | } |
956 | |
957 | intptr_t mlirOpResultGetResultNumber(MlirValue value) { |
958 | return static_cast<intptr_t>( |
959 | llvm::cast<OpResult>(Val: unwrap(c: value)).getResultNumber()); |
960 | } |
961 | |
962 | MlirType mlirValueGetType(MlirValue value) { |
963 | return wrap(cpp: unwrap(c: value).getType()); |
964 | } |
965 | |
966 | void mlirValueSetType(MlirValue value, MlirType type) { |
967 | unwrap(c: value).setType(unwrap(c: type)); |
968 | } |
969 | |
970 | void mlirValueDump(MlirValue value) { unwrap(c: value).dump(); } |
971 | |
972 | void mlirValuePrint(MlirValue value, MlirStringCallback callback, |
973 | void *userData) { |
974 | detail::CallbackOstream stream(callback, userData); |
975 | unwrap(c: value).print(os&: stream); |
976 | } |
977 | |
978 | void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, |
979 | MlirStringCallback callback, void *userData) { |
980 | detail::CallbackOstream stream(callback, userData); |
981 | Value cppValue = unwrap(c: value); |
982 | cppValue.printAsOperand(os&: stream, state&: *unwrap(c: state)); |
983 | } |
984 | |
985 | MlirOpOperand mlirValueGetFirstUse(MlirValue value) { |
986 | Value cppValue = unwrap(c: value); |
987 | if (cppValue.use_empty()) |
988 | return {}; |
989 | |
990 | OpOperand *opOperand = cppValue.use_begin().getOperand(); |
991 | |
992 | return wrap(cpp: opOperand); |
993 | } |
994 | |
995 | void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { |
996 | unwrap(c: oldValue).replaceAllUsesWith(newValue: unwrap(c: newValue)); |
997 | } |
998 | |
999 | //===----------------------------------------------------------------------===// |
1000 | // OpOperand API. |
1001 | //===----------------------------------------------------------------------===// |
1002 | |
1003 | bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; } |
1004 | |
1005 | MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { |
1006 | return wrap(cpp: unwrap(c: opOperand)->getOwner()); |
1007 | } |
1008 | |
1009 | MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) { |
1010 | return wrap(cpp: unwrap(c: opOperand)->get()); |
1011 | } |
1012 | |
1013 | unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { |
1014 | return unwrap(c: opOperand)->getOperandNumber(); |
1015 | } |
1016 | |
1017 | MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) { |
1018 | if (mlirOpOperandIsNull(opOperand)) |
1019 | return {}; |
1020 | |
1021 | OpOperand *nextOpOperand = static_cast<OpOperand *>( |
1022 | unwrap(c: opOperand)->getNextOperandUsingThisValue()); |
1023 | |
1024 | if (!nextOpOperand) |
1025 | return {}; |
1026 | |
1027 | return wrap(cpp: nextOpOperand); |
1028 | } |
1029 | |
1030 | //===----------------------------------------------------------------------===// |
1031 | // Type API. |
1032 | //===----------------------------------------------------------------------===// |
1033 | |
1034 | MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { |
1035 | return wrap(cpp: mlir::parseType(typeStr: unwrap(ref: type), context: unwrap(c: context))); |
1036 | } |
1037 | |
1038 | MlirContext mlirTypeGetContext(MlirType type) { |
1039 | return wrap(cpp: unwrap(c: type).getContext()); |
1040 | } |
1041 | |
1042 | MlirTypeID mlirTypeGetTypeID(MlirType type) { |
1043 | return wrap(cpp: unwrap(c: type).getTypeID()); |
1044 | } |
1045 | |
1046 | MlirDialect mlirTypeGetDialect(MlirType type) { |
1047 | return wrap(cpp: &unwrap(c: type).getDialect()); |
1048 | } |
1049 | |
1050 | bool mlirTypeEqual(MlirType t1, MlirType t2) { |
1051 | return unwrap(c: t1) == unwrap(c: t2); |
1052 | } |
1053 | |
1054 | void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { |
1055 | detail::CallbackOstream stream(callback, userData); |
1056 | unwrap(c: type).print(os&: stream); |
1057 | } |
1058 | |
1059 | void mlirTypeDump(MlirType type) { unwrap(c: type).dump(); } |
1060 | |
1061 | //===----------------------------------------------------------------------===// |
1062 | // Attribute API. |
1063 | //===----------------------------------------------------------------------===// |
1064 | |
1065 | MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { |
1066 | return wrap(cpp: mlir::parseAttribute(attrStr: unwrap(ref: attr), context: unwrap(c: context))); |
1067 | } |
1068 | |
1069 | MlirContext mlirAttributeGetContext(MlirAttribute attribute) { |
1070 | return wrap(cpp: unwrap(c: attribute).getContext()); |
1071 | } |
1072 | |
1073 | MlirType mlirAttributeGetType(MlirAttribute attribute) { |
1074 | Attribute attr = unwrap(c: attribute); |
1075 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) |
1076 | return wrap(typedAttr.getType()); |
1077 | return wrap(NoneType::get(attr.getContext())); |
1078 | } |
1079 | |
1080 | MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { |
1081 | return wrap(cpp: unwrap(c: attr).getTypeID()); |
1082 | } |
1083 | |
1084 | MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { |
1085 | return wrap(cpp: &unwrap(c: attr).getDialect()); |
1086 | } |
1087 | |
1088 | bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { |
1089 | return unwrap(c: a1) == unwrap(c: a2); |
1090 | } |
1091 | |
1092 | void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, |
1093 | void *userData) { |
1094 | detail::CallbackOstream stream(callback, userData); |
1095 | unwrap(c: attr).print(os&: stream); |
1096 | } |
1097 | |
1098 | void mlirAttributeDump(MlirAttribute attr) { unwrap(c: attr).dump(); } |
1099 | |
1100 | MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, |
1101 | MlirAttribute attr) { |
1102 | return MlirNamedAttribute{.name: name, .attribute: attr}; |
1103 | } |
1104 | |
1105 | //===----------------------------------------------------------------------===// |
1106 | // Identifier API. |
1107 | //===----------------------------------------------------------------------===// |
1108 | |
1109 | MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { |
1110 | return wrap(StringAttr::get(unwrap(context), unwrap(str))); |
1111 | } |
1112 | |
1113 | MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { |
1114 | return wrap(unwrap(ident).getContext()); |
1115 | } |
1116 | |
1117 | bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { |
1118 | return unwrap(ident) == unwrap(other); |
1119 | } |
1120 | |
1121 | MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { |
1122 | return wrap(unwrap(ident).strref()); |
1123 | } |
1124 | |
1125 | //===----------------------------------------------------------------------===// |
1126 | // Symbol and SymbolTable API. |
1127 | //===----------------------------------------------------------------------===// |
1128 | |
1129 | MlirStringRef mlirSymbolTableGetSymbolAttributeName() { |
1130 | return wrap(ref: SymbolTable::getSymbolAttrName()); |
1131 | } |
1132 | |
1133 | MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { |
1134 | return wrap(ref: SymbolTable::getVisibilityAttrName()); |
1135 | } |
1136 | |
1137 | MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { |
1138 | if (!unwrap(c: operation)->hasTrait<OpTrait::SymbolTable>()) |
1139 | return wrap(cpp: static_cast<SymbolTable *>(nullptr)); |
1140 | return wrap(cpp: new SymbolTable(unwrap(c: operation))); |
1141 | } |
1142 | |
1143 | void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { |
1144 | delete unwrap(c: symbolTable); |
1145 | } |
1146 | |
1147 | MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, |
1148 | MlirStringRef name) { |
1149 | return wrap(cpp: unwrap(c: symbolTable)->lookup(name: StringRef(name.data, name.length))); |
1150 | } |
1151 | |
1152 | MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, |
1153 | MlirOperation operation) { |
1154 | return wrap(cpp: (Attribute)unwrap(c: symbolTable)->insert(unwrap(c: operation))); |
1155 | } |
1156 | |
1157 | void mlirSymbolTableErase(MlirSymbolTable symbolTable, |
1158 | MlirOperation operation) { |
1159 | unwrap(c: symbolTable)->erase(symbol: unwrap(c: operation)); |
1160 | } |
1161 | |
1162 | MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, |
1163 | MlirStringRef newSymbol, |
1164 | MlirOperation from) { |
1165 | auto *cppFrom = unwrap(c: from); |
1166 | auto *context = cppFrom->getContext(); |
1167 | auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); |
1168 | auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); |
1169 | return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, |
1170 | unwrap(c: from))); |
1171 | } |
1172 | |
1173 | void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, |
1174 | void (*callback)(MlirOperation, bool, |
1175 | void *userData), |
1176 | void *userData) { |
1177 | SymbolTable::walkSymbolTables(op: unwrap(c: from), allSymUsesVisible, |
1178 | callback: [&](Operation *foundOpCpp, bool isVisible) { |
1179 | callback(wrap(cpp: foundOpCpp), isVisible, |
1180 | userData); |
1181 | }); |
1182 | } |
1183 | |