1//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/ThreadPool.h"

#include <cstddef>
#include <memory>
#include <optional>
38
39using namespace mlir;
40
41//===----------------------------------------------------------------------===//
42// Context API.
43//===----------------------------------------------------------------------===//
44
45MlirContext mlirContextCreate() {
46 auto *context = new MLIRContext;
47 return wrap(cpp: context);
48}
49
50static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
51 return threadingEnabled ? MLIRContext::Threading::ENABLED
52 : MLIRContext::Threading::DISABLED;
53}
54
55MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
56 auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
57 return wrap(cpp: context);
58}
59
60MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
61 bool threadingEnabled) {
62 auto *context =
63 new MLIRContext(*unwrap(c: registry), toThreadingEnum(threadingEnabled));
64 return wrap(cpp: context);
65}
66
67bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
68 return unwrap(c: ctx1) == unwrap(c: ctx2);
69}
70
71void mlirContextDestroy(MlirContext context) { delete unwrap(c: context); }
72
73void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
74 unwrap(c: context)->allowUnregisteredDialects(allow);
75}
76
77bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
78 return unwrap(c: context)->allowsUnregisteredDialects();
79}
80intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
81 return static_cast<intptr_t>(unwrap(c: context)->getAvailableDialects().size());
82}
83
84void mlirContextAppendDialectRegistry(MlirContext ctx,
85 MlirDialectRegistry registry) {
86 unwrap(c: ctx)->appendDialectRegistry(registry: *unwrap(c: registry));
87}
88
89// TODO: expose a cheaper way than constructing + sorting a vector only to take
90// its size.
91intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
92 return static_cast<intptr_t>(unwrap(c: context)->getLoadedDialects().size());
93}
94
95MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
96 MlirStringRef name) {
97 return wrap(cpp: unwrap(c: context)->getOrLoadDialect(name: unwrap(ref: name)));
98}
99
100bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
101 return unwrap(c: context)->isOperationRegistered(name: unwrap(ref: name));
102}
103
104void mlirContextEnableMultithreading(MlirContext context, bool enable) {
105 return unwrap(c: context)->enableMultithreading(enable);
106}
107
108void mlirContextLoadAllAvailableDialects(MlirContext context) {
109 unwrap(c: context)->loadAllAvailableDialects();
110}
111
112void mlirContextSetThreadPool(MlirContext context,
113 MlirLlvmThreadPool threadPool) {
114 unwrap(c: context)->setThreadPool(*unwrap(c: threadPool));
115}
116
117unsigned mlirContextGetNumThreads(MlirContext context) {
118 return unwrap(c: context)->getNumThreads();
119}
120
121MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
122 return wrap(cpp: &unwrap(c: context)->getThreadPool());
123}
124
125//===----------------------------------------------------------------------===//
126// Dialect API.
127//===----------------------------------------------------------------------===//
128
129MlirContext mlirDialectGetContext(MlirDialect dialect) {
130 return wrap(cpp: unwrap(c: dialect)->getContext());
131}
132
133bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
134 return unwrap(c: dialect1) == unwrap(c: dialect2);
135}
136
137MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
138 return wrap(ref: unwrap(c: dialect)->getNamespace());
139}
140
141//===----------------------------------------------------------------------===//
142// DialectRegistry API.
143//===----------------------------------------------------------------------===//
144
145MlirDialectRegistry mlirDialectRegistryCreate() {
146 return wrap(cpp: new DialectRegistry());
147}
148
149void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
150 delete unwrap(c: registry);
151}
152
153//===----------------------------------------------------------------------===//
154// AsmState API.
155//===----------------------------------------------------------------------===//
156
157MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
158 MlirOpPrintingFlags flags) {
159 return wrap(cpp: new AsmState(unwrap(c: op), *unwrap(c: flags)));
160}
161
162static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
163 do {
164 // If we are printing local scope, stop at the first operation that is
165 // isolated from above.
166 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
167 break;
168
169 // Otherwise, traverse up to the next parent.
170 Operation *parentOp = op->getParentOp();
171 if (!parentOp)
172 break;
173 op = parentOp;
174 } while (true);
175 return op;
176}
177
178MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
179 MlirOpPrintingFlags flags) {
180 Operation *op;
181 mlir::Value val = unwrap(c: value);
182 if (auto result = llvm::dyn_cast<OpResult>(Val&: val)) {
183 op = result.getOwner();
184 } else {
185 op = llvm::cast<BlockArgument>(Val&: val).getOwner()->getParentOp();
186 if (!op) {
187 emitError(loc: val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
188 return {.ptr: nullptr};
189 }
190 }
191 op = findParent(op, shouldUseLocalScope: unwrap(c: flags)->shouldUseLocalScope());
192 return wrap(cpp: new AsmState(op, *unwrap(c: flags)));
193}
194
195/// Destroys printing flags created with mlirAsmStateCreate.
196void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(c: state); }
197
198//===----------------------------------------------------------------------===//
199// Printing flags API.
200//===----------------------------------------------------------------------===//
201
202MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
203 return wrap(cpp: new OpPrintingFlags());
204}
205
206void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
207 delete unwrap(c: flags);
208}
209
210void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
211 intptr_t largeElementLimit) {
212 unwrap(c: flags)->elideLargeElementsAttrs(largeElementLimit);
213}
214
215void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags,
216 intptr_t largeResourceLimit) {
217 unwrap(c: flags)->elideLargeResourceString(largeResourceLimit);
218}
219
220void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable,
221 bool prettyForm) {
222 unwrap(c: flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm);
223}
224
225void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
226 unwrap(c: flags)->printGenericOpForm();
227}
228
229void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) {
230 unwrap(c: flags)->printNameLocAsPrefix();
231}
232
233void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
234 unwrap(c: flags)->useLocalScope();
235}
236
237void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
238 unwrap(c: flags)->assumeVerified();
239}
240
241void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
242 unwrap(c: flags)->skipRegions();
243}
244//===----------------------------------------------------------------------===//
245// Bytecode printing flags API.
246//===----------------------------------------------------------------------===//
247
248MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() {
249 return wrap(cpp: new BytecodeWriterConfig());
250}
251
252void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) {
253 delete unwrap(c: config);
254}
255
256void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
257 int64_t version) {
258 unwrap(c: flags)->setDesiredBytecodeVersion(version);
259}
260
261//===----------------------------------------------------------------------===//
262// Location API.
263//===----------------------------------------------------------------------===//
264
265MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
266 return wrap(cpp: LocationAttr(unwrap(c: location)));
267}
268
269MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
270 return wrap(cpp: Location(llvm::dyn_cast<LocationAttr>(Val: unwrap(c: attribute))));
271}
272
273MlirLocation mlirLocationFileLineColGet(MlirContext context,
274 MlirStringRef filename, unsigned line,
275 unsigned col) {
276 return wrap(cpp: Location(
277 FileLineColLoc::get(context: unwrap(c: context), fileName: unwrap(ref: filename), line, column: col)));
278}
279
280MlirLocation
281mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
282 unsigned startLine, unsigned startCol,
283 unsigned endLine, unsigned endCol) {
284 return wrap(
285 Location(FileLineColRange::get(unwrap(context), unwrap(filename),
286 startLine, startCol, endLine, endCol)));
287}
288
289MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
290 return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(c: location)).getFilename());
291}
292
293int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
294 if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
295 return loc.getStartLine();
296 return -1;
297}
298
299int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
300 if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
301 return loc.getStartColumn();
302 return -1;
303}
304
305int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
306 if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
307 return loc.getEndLine();
308 return -1;
309}
310
311int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
312 if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
313 return loc.getEndColumn();
314 return -1;
315}
316
317MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
318 return wrap(FileLineColRange::getTypeID());
319}
320
321bool mlirLocationIsAFileLineColRange(MlirLocation location) {
322 return isa<FileLineColRange>(Val: unwrap(c: location));
323}
324
325MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
326 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
327}
328
329MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
330 return wrap(
331 Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCallee()));
332}
333
334MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
335 return wrap(
336 Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCaller()));
337}
338
339MlirTypeID mlirLocationCallSiteGetTypeID() {
340 return wrap(CallSiteLoc::getTypeID());
341}
342
343bool mlirLocationIsACallSite(MlirLocation location) {
344 return isa<CallSiteLoc>(unwrap(location));
345}
346
347MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
348 MlirLocation const *locations,
349 MlirAttribute metadata) {
350 SmallVector<Location, 4> locs;
351 ArrayRef<Location> unwrappedLocs = unwrapList(size: nLocations, first: locations, storage&: locs);
352 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
353}
354
355unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
356 if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location)))
357 return locationsArrRef.getLocations().size();
358 return 0;
359}
360
361void mlirLocationFusedGetLocations(MlirLocation location,
362 MlirLocation *locationsCPtr) {
363 if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) {
364 for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations()))
365 locationsCPtr[i] = wrap(location);
366 }
367}
368
369MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
370 return wrap(llvm::dyn_cast<FusedLoc>(unwrap(location)).getMetadata());
371}
372
373MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); }
374
375bool mlirLocationIsAFused(MlirLocation location) {
376 return isa<FusedLoc>(unwrap(location));
377}
378
379MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
380 MlirLocation childLoc) {
381 if (mlirLocationIsNull(childLoc))
382 return wrap(
383 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
384 return wrap(Location(NameLoc::get(
385 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
386}
387
388MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
389 return wrap((llvm::dyn_cast<NameLoc>(unwrap(location)).getName()));
390}
391
392MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
393 return wrap(
394 Location(llvm::dyn_cast<NameLoc>(unwrap(location)).getChildLoc()));
395}
396
397MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); }
398
399bool mlirLocationIsAName(MlirLocation location) {
400 return isa<NameLoc>(unwrap(location));
401}
402
403MlirLocation mlirLocationUnknownGet(MlirContext context) {
404 return wrap(Location(UnknownLoc::get(unwrap(context))));
405}
406
407bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
408 return unwrap(c: l1) == unwrap(c: l2);
409}
410
411MlirContext mlirLocationGetContext(MlirLocation location) {
412 return wrap(cpp: unwrap(c: location).getContext());
413}
414
415void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
416 void *userData) {
417 detail::CallbackOstream stream(callback, userData);
418 unwrap(c: location).print(os&: stream);
419}
420
421//===----------------------------------------------------------------------===//
422// Module API.
423//===----------------------------------------------------------------------===//
424
425MlirModule mlirModuleCreateEmpty(MlirLocation location) {
426 return wrap(ModuleOp::create(unwrap(location)));
427}
428
429MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
430 OwningOpRef<ModuleOp> owning =
431 parseSourceString<ModuleOp>(unwrap(ref: module), unwrap(c: context));
432 if (!owning)
433 return MlirModule{.ptr: nullptr};
434 return MlirModule{owning.release().getOperation()};
435}
436
437MlirModule mlirModuleCreateParseFromFile(MlirContext context,
438 MlirStringRef fileName) {
439 OwningOpRef<ModuleOp> owning =
440 parseSourceFile<ModuleOp>(filename: unwrap(ref: fileName), config: unwrap(c: context));
441 if (!owning)
442 return MlirModule{.ptr: nullptr};
443 return MlirModule{owning.release().getOperation()};
444}
445
446MlirContext mlirModuleGetContext(MlirModule module) {
447 return wrap(unwrap(module).getContext());
448}
449
450MlirBlock mlirModuleGetBody(MlirModule module) {
451 return wrap(unwrap(module).getBody());
452}
453
454void mlirModuleDestroy(MlirModule module) {
455 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
456 // called.
457 OwningOpRef<ModuleOp>(unwrap(module));
458}
459
460MlirOperation mlirModuleGetOperation(MlirModule module) {
461 return wrap(unwrap(module).getOperation());
462}
463
464MlirModule mlirModuleFromOperation(MlirOperation op) {
465 return wrap(dyn_cast<ModuleOp>(unwrap(c: op)));
466}
467
468//===----------------------------------------------------------------------===//
469// Operation state API.
470//===----------------------------------------------------------------------===//
471
472MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
473 MlirOperationState state;
474 state.name = name;
475 state.location = loc;
476 state.nResults = 0;
477 state.results = nullptr;
478 state.nOperands = 0;
479 state.operands = nullptr;
480 state.nRegions = 0;
481 state.regions = nullptr;
482 state.nSuccessors = 0;
483 state.successors = nullptr;
484 state.nAttributes = 0;
485 state.attributes = nullptr;
486 state.enableResultTypeInference = false;
487 return state;
488}
489
// Appends `n` elements of `type` from the function parameter named `elemName`
// to the heap array `state->elemName`, whose current length is
// `state->sizeName`. The expansion site must have `state`
// (MlirOperationState *) and `n` (intptr_t) in scope.
//
// The arrays are grown with realloc-family calls because mlirOperationCreate
// releases them with free(). Design points:
//  - A non-positive `n` is a no-op. This avoids memcpy from a possibly null
//    source and any zero-sized reallocation.
//  - llvm::safe_realloc reports a fatal allocation error instead of returning
//    nullptr. The old buffer is therefore never leaked by overwriting
//    `state->elemName` with null, and memcpy never writes through null.
//  - The do/while(false) wrapper makes the macro behave as one statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    state->elemName = static_cast<type *>(llvm::safe_realloc(                  \
        state->elemName, (state->sizeName + n) * sizeof(type)));               \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
495
496void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
497 MlirType const *results) {
498 APPEND_ELEMS(MlirType, nResults, results);
499}
500
501void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
502 MlirValue const *operands) {
503 APPEND_ELEMS(MlirValue, nOperands, operands);
504}
505void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
506 MlirRegion const *regions) {
507 APPEND_ELEMS(MlirRegion, nRegions, regions);
508}
509void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
510 MlirBlock const *successors) {
511 APPEND_ELEMS(MlirBlock, nSuccessors, successors);
512}
513void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
514 MlirNamedAttribute const *attributes) {
515 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
516}
517
518void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
519 state->enableResultTypeInference = true;
520}
521
522//===----------------------------------------------------------------------===//
523// Operation API.
524//===----------------------------------------------------------------------===//
525
526static LogicalResult inferOperationTypes(OperationState &state) {
527 MLIRContext *context = state.getContext();
528 std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
529 if (!info) {
530 emitError(loc: state.location)
531 << "type inference was requested for the operation " << state.name
532 << ", but the operation was not registered; ensure that the dialect "
533 "containing the operation is linked into MLIR and registered with "
534 "the context";
535 return failure();
536 }
537
538 auto *inferInterface = info->getInterface<InferTypeOpInterface>();
539 if (!inferInterface) {
540 emitError(loc: state.location)
541 << "type inference was requested for the operation " << state.name
542 << ", but the operation does not support type inference; result "
543 "types must be specified explicitly";
544 return failure();
545 }
546
547 DictionaryAttr attributes = state.attributes.getDictionary(context);
548 OpaqueProperties properties = state.getRawProperties();
549
550 if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
551 auto prop = std::make_unique<char[]>(num: info->getOpPropertyByteSize());
552 properties = OpaqueProperties(prop.get());
553 if (properties) {
554 auto emitError = [&]() {
555 return mlir::emitError(state.location)
556 << " failed properties conversion while building "
557 << state.name.getStringRef() << " with `" << attributes << "`: ";
558 };
559 if (failed(info->setOpPropertiesFromAttribute(opName: state.name, properties,
560 attr: attributes, emitError)))
561 return failure();
562 }
563 if (succeeded(inferInterface->inferReturnTypes(
564 context, state.location, state.operands, attributes, properties,
565 state.regions, state.types))) {
566 return success();
567 }
568 // Diagnostic emitted by interface.
569 return failure();
570 }
571
572 if (succeeded(inferInterface->inferReturnTypes(
573 context, state.location, state.operands, attributes, properties,
574 state.regions, state.types)))
575 return success();
576
577 // Diagnostic emitted by interface.
578 return failure();
579}
580
581MlirOperation mlirOperationCreate(MlirOperationState *state) {
582 assert(state);
583 OperationState cppState(unwrap(c: state->location), unwrap(ref: state->name));
584 SmallVector<Type, 4> resultStorage;
585 SmallVector<Value, 8> operandStorage;
586 SmallVector<Block *, 2> successorStorage;
587 cppState.addTypes(newTypes: unwrapList(size: state->nResults, first: state->results, storage&: resultStorage));
588 cppState.addOperands(
589 newOperands: unwrapList(size: state->nOperands, first: state->operands, storage&: operandStorage));
590 cppState.addSuccessors(
591 newSuccessors: unwrapList(size: state->nSuccessors, first: state->successors, storage&: successorStorage));
592
593 cppState.attributes.reserve(N: state->nAttributes);
594 for (intptr_t i = 0; i < state->nAttributes; ++i)
595 cppState.addAttribute(unwrap(state->attributes[i].name),
596 unwrap(c: state->attributes[i].attribute));
597
598 for (intptr_t i = 0; i < state->nRegions; ++i)
599 cppState.addRegion(region: std::unique_ptr<Region>(unwrap(c: state->regions[i])));
600
601 free(ptr: state->results);
602 free(ptr: state->operands);
603 free(ptr: state->successors);
604 free(ptr: state->regions);
605 free(ptr: state->attributes);
606
607 // Infer result types.
608 if (state->enableResultTypeInference) {
609 assert(cppState.types.empty() &&
610 "result type inference enabled and result types provided");
611 if (failed(Result: inferOperationTypes(state&: cppState)))
612 return {.ptr: nullptr};
613 }
614
615 return wrap(cpp: Operation::create(state: cppState));
616}
617
618MlirOperation mlirOperationCreateParse(MlirContext context,
619 MlirStringRef sourceStr,
620 MlirStringRef sourceName) {
621
622 return wrap(
623 cpp: parseSourceString(sourceStr: unwrap(ref: sourceStr), config: unwrap(c: context), sourceName: unwrap(ref: sourceName))
624 .release());
625}
626
627MlirOperation mlirOperationClone(MlirOperation op) {
628 return wrap(cpp: unwrap(c: op)->clone());
629}
630
631void mlirOperationDestroy(MlirOperation op) { unwrap(c: op)->erase(); }
632
633void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(c: op)->remove(); }
634
635bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
636 return unwrap(c: op) == unwrap(c: other);
637}
638
639MlirContext mlirOperationGetContext(MlirOperation op) {
640 return wrap(cpp: unwrap(c: op)->getContext());
641}
642
643MlirLocation mlirOperationGetLocation(MlirOperation op) {
644 return wrap(cpp: unwrap(c: op)->getLoc());
645}
646
647MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
648 if (auto info = unwrap(c: op)->getRegisteredInfo())
649 return wrap(cpp: info->getTypeID());
650 return {.ptr: nullptr};
651}
652
653MlirIdentifier mlirOperationGetName(MlirOperation op) {
654 return wrap(unwrap(c: op)->getName().getIdentifier());
655}
656
657MlirBlock mlirOperationGetBlock(MlirOperation op) {
658 return wrap(cpp: unwrap(c: op)->getBlock());
659}
660
661MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
662 return wrap(cpp: unwrap(c: op)->getParentOp());
663}
664
665intptr_t mlirOperationGetNumRegions(MlirOperation op) {
666 return static_cast<intptr_t>(unwrap(c: op)->getNumRegions());
667}
668
669MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
670 return wrap(cpp: &unwrap(c: op)->getRegion(index: static_cast<unsigned>(pos)));
671}
672
673MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
674 Operation *cppOp = unwrap(c: op);
675 if (cppOp->getNumRegions() == 0)
676 return wrap(cpp: static_cast<Region *>(nullptr));
677 return wrap(cpp: &cppOp->getRegion(index: 0));
678}
679
680MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
681 Region *cppRegion = unwrap(c: region);
682 Operation *parent = cppRegion->getParentOp();
683 intptr_t next = cppRegion->getRegionNumber() + 1;
684 if (parent->getNumRegions() > next)
685 return wrap(cpp: &parent->getRegion(index: next));
686 return wrap(cpp: static_cast<Region *>(nullptr));
687}
688
689MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
690 return wrap(unwrap(c: op)->getNextNode());
691}
692
693intptr_t mlirOperationGetNumOperands(MlirOperation op) {
694 return static_cast<intptr_t>(unwrap(c: op)->getNumOperands());
695}
696
697MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
698 return wrap(cpp: unwrap(c: op)->getOperand(idx: static_cast<unsigned>(pos)));
699}
700
701void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
702 MlirValue newValue) {
703 unwrap(c: op)->setOperand(idx: static_cast<unsigned>(pos), value: unwrap(c: newValue));
704}
705
706void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands,
707 MlirValue const *operands) {
708 SmallVector<Value> ops;
709 unwrap(c: op)->setOperands(unwrapList(size: nOperands, first: operands, storage&: ops));
710}
711
712intptr_t mlirOperationGetNumResults(MlirOperation op) {
713 return static_cast<intptr_t>(unwrap(c: op)->getNumResults());
714}
715
716MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
717 return wrap(cpp: unwrap(c: op)->getResult(idx: static_cast<unsigned>(pos)));
718}
719
720intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
721 return static_cast<intptr_t>(unwrap(c: op)->getNumSuccessors());
722}
723
724MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
725 return wrap(cpp: unwrap(c: op)->getSuccessor(index: static_cast<unsigned>(pos)));
726}
727
728MLIR_CAPI_EXPORTED bool
729mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
730 std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name));
731 return attr.has_value();
732}
733
734MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
735 MlirStringRef name) {
736 std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name));
737 if (attr.has_value())
738 return wrap(cpp: *attr);
739 return {};
740}
741
742void mlirOperationSetInherentAttributeByName(MlirOperation op,
743 MlirStringRef name,
744 MlirAttribute attr) {
745 unwrap(op)->setInherentAttr(
746 StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
747}
748
749intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
750 return static_cast<intptr_t>(
751 llvm::range_size(unwrap(c: op)->getDiscardableAttrs()));
752}
753
754MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
755 intptr_t pos) {
756 NamedAttribute attr =
757 *std::next(unwrap(c: op)->getDiscardableAttrs().begin(), pos);
758 return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())};
759}
760
761MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
762 MlirStringRef name) {
763 return wrap(cpp: unwrap(c: op)->getDiscardableAttr(name: unwrap(ref: name)));
764}
765
766void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
767 MlirStringRef name,
768 MlirAttribute attr) {
769 unwrap(c: op)->setDiscardableAttr(name: unwrap(ref: name), value: unwrap(c: attr));
770}
771
772bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
773 MlirStringRef name) {
774 return !!unwrap(c: op)->removeDiscardableAttr(name: unwrap(ref: name));
775}
776
777void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos,
778 MlirBlock block) {
779 unwrap(c: op)->setSuccessor(block: unwrap(c: block), index: static_cast<unsigned>(pos));
780}
781
782intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
783 return static_cast<intptr_t>(unwrap(c: op)->getAttrs().size());
784}
785
786MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
787 NamedAttribute attr = unwrap(c: op)->getAttrs()[pos];
788 return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())};
789}
790
791MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
792 MlirStringRef name) {
793 return wrap(cpp: unwrap(c: op)->getAttr(name: unwrap(ref: name)));
794}
795
796void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
797 MlirAttribute attr) {
798 unwrap(c: op)->setAttr(name: unwrap(ref: name), value: unwrap(c: attr));
799}
800
801bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
802 return !!unwrap(c: op)->removeAttr(name: unwrap(ref: name));
803}
804
805void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
806 void *userData) {
807 detail::CallbackOstream stream(callback, userData);
808 unwrap(c: op)->print(os&: stream);
809}
810
811void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
812 MlirStringCallback callback, void *userData) {
813 detail::CallbackOstream stream(callback, userData);
814 unwrap(c: op)->print(os&: stream, flags: *unwrap(c: flags));
815}
816
817void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state,
818 MlirStringCallback callback, void *userData) {
819 detail::CallbackOstream stream(callback, userData);
820 if (state.ptr)
821 unwrap(c: op)->print(os&: stream, state&: *unwrap(c: state));
822 unwrap(c: op)->print(os&: stream);
823}
824
825void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
826 void *userData) {
827 detail::CallbackOstream stream(callback, userData);
828 // As no desired version is set, no failure can occur.
829 (void)writeBytecodeToFile(op: unwrap(c: op), os&: stream);
830}
831
832MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
833 MlirOperation op, MlirBytecodeWriterConfig config,
834 MlirStringCallback callback, void *userData) {
835 detail::CallbackOstream stream(callback, userData);
836 return wrap(res: writeBytecodeToFile(op: unwrap(c: op), os&: stream, config: *unwrap(c: config)));
837}
838
839void mlirOperationDump(MlirOperation op) { return unwrap(c: op)->dump(); }
840
841bool mlirOperationVerify(MlirOperation op) {
842 return succeeded(Result: verify(op: unwrap(c: op)));
843}
844
845void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
846 return unwrap(c: op)->moveAfter(existingOp: unwrap(c: other));
847}
848
849void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
850 return unwrap(c: op)->moveBefore(existingOp: unwrap(c: other));
851}
852
853static mlir::WalkResult unwrap(MlirWalkResult result) {
854 switch (result) {
855 case MlirWalkResultAdvance:
856 return mlir::WalkResult::advance();
857
858 case MlirWalkResultInterrupt:
859 return mlir::WalkResult::interrupt();
860
861 case MlirWalkResultSkip:
862 return mlir::WalkResult::skip();
863 }
864 llvm_unreachable("unknown result in WalkResult::unwrap");
865}
866
867void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
868 void *userData, MlirWalkOrder walkOrder) {
869 switch (walkOrder) {
870
871 case MlirWalkPreOrder:
872 unwrap(c: op)->walk<mlir::WalkOrder::PreOrder>(
873 callback: [callback, userData](Operation *op) {
874 return unwrap(result: callback(wrap(cpp: op), userData));
875 });
876 break;
877 case MlirWalkPostOrder:
878 unwrap(c: op)->walk<mlir::WalkOrder::PostOrder>(
879 callback: [callback, userData](Operation *op) {
880 return unwrap(result: callback(wrap(cpp: op), userData));
881 });
882 }
883}
884
885//===----------------------------------------------------------------------===//
886// Region API.
887//===----------------------------------------------------------------------===//
888
889MlirRegion mlirRegionCreate() { return wrap(cpp: new Region); }
890
891bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
892 return unwrap(c: region) == unwrap(c: other);
893}
894
895MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
896 Region *cppRegion = unwrap(c: region);
897 if (cppRegion->empty())
898 return wrap(cpp: static_cast<Block *>(nullptr));
899 return wrap(cpp: &cppRegion->front());
900}
901
902void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
903 unwrap(c: region)->push_back(block: unwrap(c: block));
904}
905
906void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
907 MlirBlock block) {
908 auto &blockList = unwrap(c: region)->getBlocks();
909 blockList.insert(where: std::next(x: blockList.begin(), n: pos), New: unwrap(c: block));
910}
911
912void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
913 MlirBlock block) {
914 Region *cppRegion = unwrap(c: region);
915 if (mlirBlockIsNull(block: reference)) {
916 cppRegion->getBlocks().insert(where: cppRegion->begin(), New: unwrap(c: block));
917 return;
918 }
919
920 assert(unwrap(reference)->getParent() == unwrap(region) &&
921 "expected reference block to belong to the region");
922 cppRegion->getBlocks().insertAfter(where: Region::iterator(unwrap(c: reference)),
923 New: unwrap(c: block));
924}
925
926void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
927 MlirBlock block) {
928 if (mlirBlockIsNull(block: reference))
929 return mlirRegionAppendOwnedBlock(region, block);
930
931 assert(unwrap(reference)->getParent() == unwrap(region) &&
932 "expected reference block to belong to the region");
933 unwrap(c: region)->getBlocks().insert(where: Region::iterator(unwrap(c: reference)),
934 New: unwrap(c: block));
935}
936
937void mlirRegionDestroy(MlirRegion region) {
938 delete static_cast<Region *>(region.ptr);
939}
940
941void mlirRegionTakeBody(MlirRegion target, MlirRegion source) {
942 unwrap(c: target)->takeBody(other&: *unwrap(c: source));
943}
944
945//===----------------------------------------------------------------------===//
946// Block API.
947//===----------------------------------------------------------------------===//
948
949MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
950 MlirLocation const *locs) {
951 Block *b = new Block;
952 for (intptr_t i = 0; i < nArgs; ++i)
953 b->addArgument(type: unwrap(c: args[i]), loc: unwrap(c: locs[i]));
954 return wrap(cpp: b);
955}
956
957bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
958 return unwrap(c: block) == unwrap(c: other);
959}
960
961MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
962 return wrap(cpp: unwrap(c: block)->getParentOp());
963}
964
965MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
966 return wrap(cpp: unwrap(c: block)->getParent());
967}
968
969MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
970 return wrap(cpp: unwrap(c: block)->getNextNode());
971}
972
973MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
974 Block *cppBlock = unwrap(c: block);
975 if (cppBlock->empty())
976 return wrap(cpp: static_cast<Operation *>(nullptr));
977 return wrap(cpp: &cppBlock->front());
978}
979
980MlirOperation mlirBlockGetTerminator(MlirBlock block) {
981 Block *cppBlock = unwrap(c: block);
982 if (cppBlock->empty())
983 return wrap(cpp: static_cast<Operation *>(nullptr));
984 Operation &back = cppBlock->back();
985 if (!back.hasTrait<OpTrait::IsTerminator>())
986 return wrap(cpp: static_cast<Operation *>(nullptr));
987 return wrap(cpp: &back);
988}
989
990void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
991 unwrap(c: block)->push_back(op: unwrap(c: operation));
992}
993
994void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
995 MlirOperation operation) {
996 auto &opList = unwrap(c: block)->getOperations();
997 opList.insert(where: std::next(x: opList.begin(), n: pos), New: unwrap(c: operation));
998}
999
1000void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
1001 MlirOperation reference,
1002 MlirOperation operation) {
1003 Block *cppBlock = unwrap(c: block);
1004 if (mlirOperationIsNull(op: reference)) {
1005 cppBlock->getOperations().insert(where: cppBlock->begin(), New: unwrap(c: operation));
1006 return;
1007 }
1008
1009 assert(unwrap(reference)->getBlock() == unwrap(block) &&
1010 "expected reference operation to belong to the block");
1011 cppBlock->getOperations().insertAfter(where: Block::iterator(unwrap(c: reference)),
1012 New: unwrap(c: operation));
1013}
1014
1015void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
1016 MlirOperation reference,
1017 MlirOperation operation) {
1018 if (mlirOperationIsNull(op: reference))
1019 return mlirBlockAppendOwnedOperation(block, operation);
1020
1021 assert(unwrap(reference)->getBlock() == unwrap(block) &&
1022 "expected reference operation to belong to the block");
1023 unwrap(c: block)->getOperations().insert(where: Block::iterator(unwrap(c: reference)),
1024 New: unwrap(c: operation));
1025}
1026
1027void mlirBlockDestroy(MlirBlock block) { delete unwrap(c: block); }
1028
1029void mlirBlockDetach(MlirBlock block) {
1030 Block *b = unwrap(c: block);
1031 b->getParent()->getBlocks().remove(IT: b);
1032}
1033
1034intptr_t mlirBlockGetNumArguments(MlirBlock block) {
1035 return static_cast<intptr_t>(unwrap(c: block)->getNumArguments());
1036}
1037
1038MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
1039 MlirLocation loc) {
1040 return wrap(cpp: unwrap(c: block)->addArgument(type: unwrap(c: type), loc: unwrap(c: loc)));
1041}
1042
1043void mlirBlockEraseArgument(MlirBlock block, unsigned index) {
1044 return unwrap(c: block)->eraseArgument(index);
1045}
1046
1047MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type,
1048 MlirLocation loc) {
1049 return wrap(cpp: unwrap(c: block)->insertArgument(index: pos, type: unwrap(c: type), loc: unwrap(c: loc)));
1050}
1051
1052MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
1053 return wrap(cpp: unwrap(c: block)->getArgument(i: static_cast<unsigned>(pos)));
1054}
1055
1056void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
1057 void *userData) {
1058 detail::CallbackOstream stream(callback, userData);
1059 unwrap(c: block)->print(os&: stream);
1060}
1061
1062//===----------------------------------------------------------------------===//
1063// Value API.
1064//===----------------------------------------------------------------------===//
1065
1066bool mlirValueEqual(MlirValue value1, MlirValue value2) {
1067 return unwrap(c: value1) == unwrap(c: value2);
1068}
1069
1070bool mlirValueIsABlockArgument(MlirValue value) {
1071 return llvm::isa<BlockArgument>(Val: unwrap(c: value));
1072}
1073
1074bool mlirValueIsAOpResult(MlirValue value) {
1075 return llvm::isa<OpResult>(Val: unwrap(c: value));
1076}
1077
1078MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
1079 return wrap(cpp: llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getOwner());
1080}
1081
1082intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
1083 return static_cast<intptr_t>(
1084 llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getArgNumber());
1085}
1086
1087void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
1088 if (auto blockArg = llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)))
1089 blockArg.setType(unwrap(c: type));
1090}
1091
1092MlirOperation mlirOpResultGetOwner(MlirValue value) {
1093 return wrap(cpp: llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getOwner());
1094}
1095
1096intptr_t mlirOpResultGetResultNumber(MlirValue value) {
1097 return static_cast<intptr_t>(
1098 llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getResultNumber());
1099}
1100
1101MlirType mlirValueGetType(MlirValue value) {
1102 return wrap(cpp: unwrap(c: value).getType());
1103}
1104
1105void mlirValueSetType(MlirValue value, MlirType type) {
1106 unwrap(c: value).setType(unwrap(c: type));
1107}
1108
1109void mlirValueDump(MlirValue value) { unwrap(c: value).dump(); }
1110
1111void mlirValuePrint(MlirValue value, MlirStringCallback callback,
1112 void *userData) {
1113 detail::CallbackOstream stream(callback, userData);
1114 unwrap(c: value).print(os&: stream);
1115}
1116
1117void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
1118 MlirStringCallback callback, void *userData) {
1119 detail::CallbackOstream stream(callback, userData);
1120 Value cppValue = unwrap(c: value);
1121 cppValue.printAsOperand(os&: stream, state&: *unwrap(c: state));
1122}
1123
1124MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
1125 Value cppValue = unwrap(c: value);
1126 if (cppValue.use_empty())
1127 return {};
1128
1129 OpOperand *opOperand = cppValue.use_begin().getOperand();
1130
1131 return wrap(cpp: opOperand);
1132}
1133
1134void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
1135 unwrap(c: oldValue).replaceAllUsesWith(newValue: unwrap(c: newValue));
1136}
1137
1138void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
1139 intptr_t numExceptions,
1140 MlirOperation *exceptions) {
1141 Value oldValueCpp = unwrap(c: oldValue);
1142 Value newValueCpp = unwrap(c: newValue);
1143
1144 llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
1145 for (intptr_t i = 0; i < numExceptions; ++i) {
1146 exceptionSet.insert(Ptr: unwrap(c: exceptions[i]));
1147 }
1148
1149 oldValueCpp.replaceAllUsesExcept(newValue: newValueCpp, exceptions: exceptionSet);
1150}
1151
1152MlirLocation mlirValueGetLocation(MlirValue v) {
1153 return wrap(cpp: unwrap(c: v).getLoc());
1154}
1155
1156MlirContext mlirValueGetContext(MlirValue v) {
1157 return wrap(cpp: unwrap(c: v).getContext());
1158}
1159
1160//===----------------------------------------------------------------------===//
1161// OpOperand API.
1162//===----------------------------------------------------------------------===//
1163
1164bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; }
1165
1166MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) {
1167 return wrap(cpp: unwrap(c: opOperand)->getOwner());
1168}
1169
1170MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) {
1171 return wrap(cpp: unwrap(c: opOperand)->get());
1172}
1173
1174unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) {
1175 return unwrap(c: opOperand)->getOperandNumber();
1176}
1177
1178MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) {
1179 if (mlirOpOperandIsNull(opOperand))
1180 return {};
1181
1182 OpOperand *nextOpOperand = static_cast<OpOperand *>(
1183 unwrap(c: opOperand)->getNextOperandUsingThisValue());
1184
1185 if (!nextOpOperand)
1186 return {};
1187
1188 return wrap(cpp: nextOpOperand);
1189}
1190
1191//===----------------------------------------------------------------------===//
1192// Type API.
1193//===----------------------------------------------------------------------===//
1194
1195MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
1196 return wrap(cpp: mlir::parseType(typeStr: unwrap(ref: type), context: unwrap(c: context)));
1197}
1198
1199MlirContext mlirTypeGetContext(MlirType type) {
1200 return wrap(cpp: unwrap(c: type).getContext());
1201}
1202
1203MlirTypeID mlirTypeGetTypeID(MlirType type) {
1204 return wrap(cpp: unwrap(c: type).getTypeID());
1205}
1206
1207MlirDialect mlirTypeGetDialect(MlirType type) {
1208 return wrap(cpp: &unwrap(c: type).getDialect());
1209}
1210
1211bool mlirTypeEqual(MlirType t1, MlirType t2) {
1212 return unwrap(c: t1) == unwrap(c: t2);
1213}
1214
1215void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
1216 detail::CallbackOstream stream(callback, userData);
1217 unwrap(c: type).print(os&: stream);
1218}
1219
1220void mlirTypeDump(MlirType type) { unwrap(c: type).dump(); }
1221
1222//===----------------------------------------------------------------------===//
1223// Attribute API.
1224//===----------------------------------------------------------------------===//
1225
1226MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
1227 return wrap(cpp: mlir::parseAttribute(attrStr: unwrap(ref: attr), context: unwrap(c: context)));
1228}
1229
1230MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
1231 return wrap(cpp: unwrap(c: attribute).getContext());
1232}
1233
1234MlirType mlirAttributeGetType(MlirAttribute attribute) {
1235 Attribute attr = unwrap(c: attribute);
1236 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
1237 return wrap(typedAttr.getType());
1238 return wrap(NoneType::get(attr.getContext()));
1239}
1240
1241MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
1242 return wrap(cpp: unwrap(c: attr).getTypeID());
1243}
1244
1245MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
1246 return wrap(cpp: &unwrap(c: attr).getDialect());
1247}
1248
1249bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
1250 return unwrap(c: a1) == unwrap(c: a2);
1251}
1252
1253void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
1254 void *userData) {
1255 detail::CallbackOstream stream(callback, userData);
1256 unwrap(c: attr).print(os&: stream);
1257}
1258
1259void mlirAttributeDump(MlirAttribute attr) { unwrap(c: attr).dump(); }
1260
1261MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
1262 MlirAttribute attr) {
1263 return MlirNamedAttribute{.name: name, .attribute: attr};
1264}
1265
1266//===----------------------------------------------------------------------===//
1267// Identifier API.
1268//===----------------------------------------------------------------------===//
1269
1270MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
1271 return wrap(StringAttr::get(unwrap(context), unwrap(str)));
1272}
1273
1274MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
1275 return wrap(unwrap(ident).getContext());
1276}
1277
1278bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
1279 return unwrap(ident) == unwrap(other);
1280}
1281
1282MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
1283 return wrap(unwrap(ident).strref());
1284}
1285
1286//===----------------------------------------------------------------------===//
1287// Symbol and SymbolTable API.
1288//===----------------------------------------------------------------------===//
1289
1290MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
1291 return wrap(ref: SymbolTable::getSymbolAttrName());
1292}
1293
1294MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
1295 return wrap(ref: SymbolTable::getVisibilityAttrName());
1296}
1297
1298MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
1299 if (!unwrap(c: operation)->hasTrait<OpTrait::SymbolTable>())
1300 return wrap(cpp: static_cast<SymbolTable *>(nullptr));
1301 return wrap(cpp: new SymbolTable(unwrap(c: operation)));
1302}
1303
1304void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
1305 delete unwrap(c: symbolTable);
1306}
1307
1308MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
1309 MlirStringRef name) {
1310 return wrap(cpp: unwrap(c: symbolTable)->lookup(name: StringRef(name.data, name.length)));
1311}
1312
1313MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
1314 MlirOperation operation) {
1315 return wrap(cpp: (Attribute)unwrap(c: symbolTable)->insert(unwrap(c: operation)));
1316}
1317
1318void mlirSymbolTableErase(MlirSymbolTable symbolTable,
1319 MlirOperation operation) {
1320 unwrap(c: symbolTable)->erase(symbol: unwrap(c: operation));
1321}
1322
1323MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
1324 MlirStringRef newSymbol,
1325 MlirOperation from) {
1326 auto *cppFrom = unwrap(c: from);
1327 auto *context = cppFrom->getContext();
1328 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
1329 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
1330 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
1331 unwrap(c: from)));
1332}
1333
1334void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
1335 void (*callback)(MlirOperation, bool,
1336 void *userData),
1337 void *userData) {
1338 SymbolTable::walkSymbolTables(op: unwrap(c: from), allSymUsesVisible,
1339 callback: [&](Operation *foundOpCpp, bool isVisible) {
1340 callback(wrap(cpp: foundOpCpp), isVisible,
1341 userData);
1342 });
1343}
1344

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/CAPI/IR/IR.cpp