1//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/IR.h"
10#include "mlir-c/Support.h"
11
12#include "mlir/AsmParser/AsmParser.h"
13#include "mlir/Bytecode/BytecodeWriter.h"
14#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/CAPI/Utils.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/Dialect.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/OperationSupport.h"
25#include "mlir/IR/OwningOpRef.h"
26#include "mlir/IR/Types.h"
27#include "mlir/IR/Value.h"
28#include "mlir/IR/Verifier.h"
29#include "mlir/IR/Visitors.h"
30#include "mlir/Interfaces/InferTypeOpInterface.h"
31#include "mlir/Parser/Parser.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/ThreadPool.h"
34
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <optional>
38
39using namespace mlir;
40
41//===----------------------------------------------------------------------===//
42// Context API.
43//===----------------------------------------------------------------------===//
44
45MlirContext mlirContextCreate() {
46 auto *context = new MLIRContext;
47 return wrap(cpp: context);
48}
49
50static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
51 return threadingEnabled ? MLIRContext::Threading::ENABLED
52 : MLIRContext::Threading::DISABLED;
53}
54
55MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
56 auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
57 return wrap(cpp: context);
58}
59
60MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
61 bool threadingEnabled) {
62 auto *context =
63 new MLIRContext(*unwrap(c: registry), toThreadingEnum(threadingEnabled));
64 return wrap(cpp: context);
65}
66
67bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
68 return unwrap(c: ctx1) == unwrap(c: ctx2);
69}
70
71void mlirContextDestroy(MlirContext context) { delete unwrap(c: context); }
72
73void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
74 unwrap(c: context)->allowUnregisteredDialects(allow);
75}
76
77bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
78 return unwrap(c: context)->allowsUnregisteredDialects();
79}
80intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
81 return static_cast<intptr_t>(unwrap(c: context)->getAvailableDialects().size());
82}
83
84void mlirContextAppendDialectRegistry(MlirContext ctx,
85 MlirDialectRegistry registry) {
86 unwrap(c: ctx)->appendDialectRegistry(registry: *unwrap(c: registry));
87}
88
89// TODO: expose a cheaper way than constructing + sorting a vector only to take
90// its size.
91intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
92 return static_cast<intptr_t>(unwrap(c: context)->getLoadedDialects().size());
93}
94
95MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
96 MlirStringRef name) {
97 return wrap(cpp: unwrap(c: context)->getOrLoadDialect(name: unwrap(ref: name)));
98}
99
100bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
101 return unwrap(c: context)->isOperationRegistered(name: unwrap(ref: name));
102}
103
104void mlirContextEnableMultithreading(MlirContext context, bool enable) {
105 return unwrap(c: context)->enableMultithreading(enable);
106}
107
108void mlirContextLoadAllAvailableDialects(MlirContext context) {
109 unwrap(c: context)->loadAllAvailableDialects();
110}
111
112void mlirContextSetThreadPool(MlirContext context,
113 MlirLlvmThreadPool threadPool) {
114 unwrap(c: context)->setThreadPool(*unwrap(c: threadPool));
115}
116
117unsigned mlirContextGetNumThreads(MlirContext context) {
118 return unwrap(c: context)->getNumThreads();
119}
120
121MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
122 return wrap(cpp: &unwrap(c: context)->getThreadPool());
123}
124
125//===----------------------------------------------------------------------===//
126// Dialect API.
127//===----------------------------------------------------------------------===//
128
129MlirContext mlirDialectGetContext(MlirDialect dialect) {
130 return wrap(cpp: unwrap(c: dialect)->getContext());
131}
132
133bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
134 return unwrap(c: dialect1) == unwrap(c: dialect2);
135}
136
137MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
138 return wrap(ref: unwrap(c: dialect)->getNamespace());
139}
140
141//===----------------------------------------------------------------------===//
142// DialectRegistry API.
143//===----------------------------------------------------------------------===//
144
145MlirDialectRegistry mlirDialectRegistryCreate() {
146 return wrap(cpp: new DialectRegistry());
147}
148
149void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
150 delete unwrap(c: registry);
151}
152
153//===----------------------------------------------------------------------===//
154// AsmState API.
155//===----------------------------------------------------------------------===//
156
157MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
158 MlirOpPrintingFlags flags) {
159 return wrap(cpp: new AsmState(unwrap(c: op), *unwrap(c: flags)));
160}
161
162static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
163 do {
164 // If we are printing local scope, stop at the first operation that is
165 // isolated from above.
166 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
167 break;
168
169 // Otherwise, traverse up to the next parent.
170 Operation *parentOp = op->getParentOp();
171 if (!parentOp)
172 break;
173 op = parentOp;
174 } while (true);
175 return op;
176}
177
178MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
179 MlirOpPrintingFlags flags) {
180 Operation *op;
181 mlir::Value val = unwrap(c: value);
182 if (auto result = llvm::dyn_cast<OpResult>(Val&: val)) {
183 op = result.getOwner();
184 } else {
185 op = llvm::cast<BlockArgument>(Val&: val).getOwner()->getParentOp();
186 if (!op) {
187 emitError(loc: val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
188 return {.ptr: nullptr};
189 }
190 }
191 op = findParent(op, shouldUseLocalScope: unwrap(c: flags)->shouldUseLocalScope());
192 return wrap(cpp: new AsmState(op, *unwrap(c: flags)));
193}
194
195/// Destroys printing flags created with mlirAsmStateCreate.
196void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(c: state); }
197
198//===----------------------------------------------------------------------===//
199// Printing flags API.
200//===----------------------------------------------------------------------===//
201
202MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
203 return wrap(cpp: new OpPrintingFlags());
204}
205
206void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
207 delete unwrap(c: flags);
208}
209
210void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
211 intptr_t largeElementLimit) {
212 unwrap(c: flags)->elideLargeElementsAttrs(largeElementLimit);
213}
214
215void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags,
216 intptr_t largeResourceLimit) {
217 unwrap(c: flags)->elideLargeResourceString(largeResourceLimit);
218}
219
220void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable,
221 bool prettyForm) {
222 unwrap(c: flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm);
223}
224
225void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
226 unwrap(c: flags)->printGenericOpForm();
227}
228
229void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) {
230 unwrap(c: flags)->printNameLocAsPrefix();
231}
232
233void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
234 unwrap(c: flags)->useLocalScope();
235}
236
237void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
238 unwrap(c: flags)->assumeVerified();
239}
240
241void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
242 unwrap(c: flags)->skipRegions();
243}
244//===----------------------------------------------------------------------===//
245// Bytecode printing flags API.
246//===----------------------------------------------------------------------===//
247
248MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() {
249 return wrap(cpp: new BytecodeWriterConfig());
250}
251
252void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) {
253 delete unwrap(c: config);
254}
255
256void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
257 int64_t version) {
258 unwrap(c: flags)->setDesiredBytecodeVersion(version);
259}
260
261//===----------------------------------------------------------------------===//
262// Location API.
263//===----------------------------------------------------------------------===//
264
265MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
266 return wrap(cpp: LocationAttr(unwrap(c: location)));
267}
268
269MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
270 return wrap(cpp: Location(llvm::dyn_cast<LocationAttr>(Val: unwrap(c: attribute))));
271}
272
273MlirLocation mlirLocationFileLineColGet(MlirContext context,
274 MlirStringRef filename, unsigned line,
275 unsigned col) {
276 return wrap(cpp: Location(
277 FileLineColLoc::get(context: unwrap(c: context), fileName: unwrap(ref: filename), line, column: col)));
278}
279
280MlirLocation
281mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
282 unsigned startLine, unsigned startCol,
283 unsigned endLine, unsigned endCol) {
284 return wrap(
285 cpp: Location(FileLineColRange::get(context: unwrap(c: context), filename: unwrap(ref: filename),
286 start_line: startLine, start_column: startCol, end_line: endLine, end_column: endCol)));
287}
288
289MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
290 return wrap(cpp: llvm::dyn_cast<FileLineColRange>(Val: unwrap(c: location)).getFilename());
291}
292
293int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
294 if (auto loc = llvm::dyn_cast<FileLineColRange>(Val: unwrap(c: location)))
295 return loc.getStartLine();
296 return -1;
297}
298
299int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
300 if (auto loc = llvm::dyn_cast<FileLineColRange>(Val: unwrap(c: location)))
301 return loc.getStartColumn();
302 return -1;
303}
304
305int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
306 if (auto loc = llvm::dyn_cast<FileLineColRange>(Val: unwrap(c: location)))
307 return loc.getEndLine();
308 return -1;
309}
310
311int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
312 if (auto loc = llvm::dyn_cast<FileLineColRange>(Val: unwrap(c: location)))
313 return loc.getEndColumn();
314 return -1;
315}
316
317MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
318 return wrap(cpp: FileLineColRange::getTypeID());
319}
320
321bool mlirLocationIsAFileLineColRange(MlirLocation location) {
322 return isa<FileLineColRange>(Val: unwrap(c: location));
323}
324
325MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
326 return wrap(cpp: Location(CallSiteLoc::get(callee: unwrap(c: callee), caller: unwrap(c: caller))));
327}
328
329MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
330 return wrap(
331 cpp: Location(llvm::dyn_cast<CallSiteLoc>(Val: unwrap(c: location)).getCallee()));
332}
333
334MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
335 return wrap(
336 cpp: Location(llvm::dyn_cast<CallSiteLoc>(Val: unwrap(c: location)).getCaller()));
337}
338
339MlirTypeID mlirLocationCallSiteGetTypeID() {
340 return wrap(cpp: CallSiteLoc::getTypeID());
341}
342
343bool mlirLocationIsACallSite(MlirLocation location) {
344 return isa<CallSiteLoc>(Val: unwrap(c: location));
345}
346
347MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
348 MlirLocation const *locations,
349 MlirAttribute metadata) {
350 SmallVector<Location, 4> locs;
351 ArrayRef<Location> unwrappedLocs = unwrapList(size: nLocations, first: locations, storage&: locs);
352 return wrap(cpp: FusedLoc::get(locs: unwrappedLocs, metadata: unwrap(c: metadata), context: unwrap(c: ctx)));
353}
354
355unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
356 if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(Val: unwrap(c: location)))
357 return locationsArrRef.getLocations().size();
358 return 0;
359}
360
361void mlirLocationFusedGetLocations(MlirLocation location,
362 MlirLocation *locationsCPtr) {
363 if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(Val: unwrap(c: location))) {
364 for (auto [i, location] : llvm::enumerate(First: locationsArrRef.getLocations()))
365 locationsCPtr[i] = wrap(cpp: location);
366 }
367}
368
369MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
370 return wrap(cpp: llvm::dyn_cast<FusedLoc>(Val: unwrap(c: location)).getMetadata());
371}
372
373MlirTypeID mlirLocationFusedGetTypeID() { return wrap(cpp: FusedLoc::getTypeID()); }
374
375bool mlirLocationIsAFused(MlirLocation location) {
376 return isa<FusedLoc>(Val: unwrap(c: location));
377}
378
379MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
380 MlirLocation childLoc) {
381 if (mlirLocationIsNull(location: childLoc))
382 return wrap(
383 cpp: Location(NameLoc::get(name: StringAttr::get(context: unwrap(c: context), bytes: unwrap(ref: name)))));
384 return wrap(cpp: Location(NameLoc::get(
385 name: StringAttr::get(context: unwrap(c: context), bytes: unwrap(ref: name)), childLoc: unwrap(c: childLoc))));
386}
387
388MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
389 return wrap(cpp: (llvm::dyn_cast<NameLoc>(Val: unwrap(c: location)).getName()));
390}
391
392MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
393 return wrap(
394 cpp: Location(llvm::dyn_cast<NameLoc>(Val: unwrap(c: location)).getChildLoc()));
395}
396
397MlirTypeID mlirLocationNameGetTypeID() { return wrap(cpp: NameLoc::getTypeID()); }
398
399bool mlirLocationIsAName(MlirLocation location) {
400 return isa<NameLoc>(Val: unwrap(c: location));
401}
402
403MlirLocation mlirLocationUnknownGet(MlirContext context) {
404 return wrap(cpp: Location(UnknownLoc::get(context: unwrap(c: context))));
405}
406
407bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
408 return unwrap(c: l1) == unwrap(c: l2);
409}
410
411MlirContext mlirLocationGetContext(MlirLocation location) {
412 return wrap(cpp: unwrap(c: location).getContext());
413}
414
415void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
416 void *userData) {
417 detail::CallbackOstream stream(callback, userData);
418 unwrap(c: location).print(os&: stream);
419}
420
421//===----------------------------------------------------------------------===//
422// Module API.
423//===----------------------------------------------------------------------===//
424
425MlirModule mlirModuleCreateEmpty(MlirLocation location) {
426 return wrap(cpp: ModuleOp::create(loc: unwrap(c: location)));
427}
428
429MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
430 OwningOpRef<ModuleOp> owning =
431 parseSourceString<ModuleOp>(sourceStr: unwrap(ref: module), config: unwrap(c: context));
432 if (!owning)
433 return MlirModule{.ptr: nullptr};
434 return MlirModule{.ptr: owning.release().getOperation()};
435}
436
437MlirModule mlirModuleCreateParseFromFile(MlirContext context,
438 MlirStringRef fileName) {
439 OwningOpRef<ModuleOp> owning =
440 parseSourceFile<ModuleOp>(filename: unwrap(ref: fileName), config: unwrap(c: context));
441 if (!owning)
442 return MlirModule{.ptr: nullptr};
443 return MlirModule{.ptr: owning.release().getOperation()};
444}
445
446MlirContext mlirModuleGetContext(MlirModule module) {
447 return wrap(cpp: unwrap(c: module).getContext());
448}
449
450MlirBlock mlirModuleGetBody(MlirModule module) {
451 return wrap(cpp: unwrap(c: module).getBody());
452}
453
454void mlirModuleDestroy(MlirModule module) {
455 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
456 // called.
457 OwningOpRef<ModuleOp>(unwrap(c: module));
458}
459
460MlirOperation mlirModuleGetOperation(MlirModule module) {
461 return wrap(cpp: unwrap(c: module).getOperation());
462}
463
464MlirModule mlirModuleFromOperation(MlirOperation op) {
465 return wrap(cpp: dyn_cast<ModuleOp>(Val: unwrap(c: op)));
466}
467
468//===----------------------------------------------------------------------===//
469// Operation state API.
470//===----------------------------------------------------------------------===//
471
472MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
473 MlirOperationState state;
474 state.name = name;
475 state.location = loc;
476 state.nResults = 0;
477 state.results = nullptr;
478 state.nOperands = 0;
479 state.operands = nullptr;
480 state.nRegions = 0;
481 state.regions = nullptr;
482 state.nSuccessors = 0;
483 state.successors = nullptr;
484 state.nAttributes = 0;
485 state.attributes = nullptr;
486 state.enableResultTypeInference = false;
487 return state;
488}
489
/// Appends `n` elements copied from `elems` to the malloc-managed array
/// `storage`, whose current length is `size`, then increases `size` by `n`.
/// The array is grown with realloc so it can later be released with free.
///
/// A non-positive `n` is a no-op. This avoids realloc with a zero size and
/// memcpy from a possibly-null source. If allocation fails, the process
/// aborts rather than leaking the old buffer and writing through a null
/// pointer.
template <typename T>
static void appendElems(T *&storage, intptr_t &size, intptr_t n,
                        T const *elems) {
  if (n <= 0)
    return;
  size_t newCount = static_cast<size_t>(size) + static_cast<size_t>(n);
  void *grown = std::realloc(storage, newCount * sizeof(T));
  if (!grown)
    std::abort();
  storage = static_cast<T *>(grown);
  std::memcpy(storage + size, elems, static_cast<size_t>(n) * sizeof(T));
  size += n;
}

/// Appends the `n` elements pointed to by `elemName` (of C type `type`) to
/// `state->elemName`, whose length is tracked in `state->sizeName`. Expects
/// `state` and `n` to be in scope at the expansion site.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  appendElems<type>(state->elemName, state->sizeName, n, elemName)
495
496void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
497 MlirType const *results) {
498 APPEND_ELEMS(MlirType, nResults, results);
499}
500
501void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
502 MlirValue const *operands) {
503 APPEND_ELEMS(MlirValue, nOperands, operands);
504}
505void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
506 MlirRegion const *regions) {
507 APPEND_ELEMS(MlirRegion, nRegions, regions);
508}
509void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
510 MlirBlock const *successors) {
511 APPEND_ELEMS(MlirBlock, nSuccessors, successors);
512}
513void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
514 MlirNamedAttribute const *attributes) {
515 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
516}
517
518void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
519 state->enableResultTypeInference = true;
520}
521
522//===----------------------------------------------------------------------===//
523// Operation API.
524//===----------------------------------------------------------------------===//
525
526static LogicalResult inferOperationTypes(OperationState &state) {
527 MLIRContext *context = state.getContext();
528 std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
529 if (!info) {
530 emitError(loc: state.location)
531 << "type inference was requested for the operation " << state.name
532 << ", but the operation was not registered; ensure that the dialect "
533 "containing the operation is linked into MLIR and registered with "
534 "the context";
535 return failure();
536 }
537
538 auto *inferInterface = info->getInterface<InferTypeOpInterface>();
539 if (!inferInterface) {
540 emitError(loc: state.location)
541 << "type inference was requested for the operation " << state.name
542 << ", but the operation does not support type inference; result "
543 "types must be specified explicitly";
544 return failure();
545 }
546
547 DictionaryAttr attributes = state.attributes.getDictionary(context);
548 OpaqueProperties properties = state.getRawProperties();
549
550 if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
551 auto prop = std::make_unique<char[]>(num: info->getOpPropertyByteSize());
552 properties = OpaqueProperties(prop.get());
553 if (properties) {
554 auto emitError = [&]() {
555 return mlir::emitError(loc: state.location)
556 << " failed properties conversion while building "
557 << state.name.getStringRef() << " with `" << attributes << "`: ";
558 };
559 if (failed(Result: info->setOpPropertiesFromAttribute(opName: state.name, properties,
560 attr: attributes, emitError)))
561 return failure();
562 }
563 if (succeeded(Result: inferInterface->inferReturnTypes(
564 context, state.location, state.operands, attributes, properties,
565 state.regions, state.types))) {
566 return success();
567 }
568 // Diagnostic emitted by interface.
569 return failure();
570 }
571
572 if (succeeded(Result: inferInterface->inferReturnTypes(
573 context, state.location, state.operands, attributes, properties,
574 state.regions, state.types)))
575 return success();
576
577 // Diagnostic emitted by interface.
578 return failure();
579}
580
581MlirOperation mlirOperationCreate(MlirOperationState *state) {
582 assert(state);
583 OperationState cppState(unwrap(c: state->location), unwrap(ref: state->name));
584 SmallVector<Type, 4> resultStorage;
585 SmallVector<Value, 8> operandStorage;
586 SmallVector<Block *, 2> successorStorage;
587 cppState.addTypes(newTypes: unwrapList(size: state->nResults, first: state->results, storage&: resultStorage));
588 cppState.addOperands(
589 newOperands: unwrapList(size: state->nOperands, first: state->operands, storage&: operandStorage));
590 cppState.addSuccessors(
591 newSuccessors: unwrapList(size: state->nSuccessors, first: state->successors, storage&: successorStorage));
592
593 cppState.attributes.reserve(N: state->nAttributes);
594 for (intptr_t i = 0; i < state->nAttributes; ++i)
595 cppState.addAttribute(name: unwrap(c: state->attributes[i].name),
596 attr: unwrap(c: state->attributes[i].attribute));
597
598 for (intptr_t i = 0; i < state->nRegions; ++i)
599 cppState.addRegion(region: std::unique_ptr<Region>(unwrap(c: state->regions[i])));
600
601 free(ptr: state->results);
602 free(ptr: state->operands);
603 free(ptr: state->successors);
604 free(ptr: state->regions);
605 free(ptr: state->attributes);
606
607 // Infer result types.
608 if (state->enableResultTypeInference) {
609 assert(cppState.types.empty() &&
610 "result type inference enabled and result types provided");
611 if (failed(Result: inferOperationTypes(state&: cppState)))
612 return {.ptr: nullptr};
613 }
614
615 return wrap(cpp: Operation::create(state: cppState));
616}
617
618MlirOperation mlirOperationCreateParse(MlirContext context,
619 MlirStringRef sourceStr,
620 MlirStringRef sourceName) {
621
622 return wrap(
623 cpp: parseSourceString(sourceStr: unwrap(ref: sourceStr), config: unwrap(c: context), sourceName: unwrap(ref: sourceName))
624 .release());
625}
626
627MlirOperation mlirOperationClone(MlirOperation op) {
628 return wrap(cpp: unwrap(c: op)->clone());
629}
630
631void mlirOperationDestroy(MlirOperation op) { unwrap(c: op)->erase(); }
632
633void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(c: op)->remove(); }
634
635bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
636 return unwrap(c: op) == unwrap(c: other);
637}
638
639MlirContext mlirOperationGetContext(MlirOperation op) {
640 return wrap(cpp: unwrap(c: op)->getContext());
641}
642
643MlirLocation mlirOperationGetLocation(MlirOperation op) {
644 return wrap(cpp: unwrap(c: op)->getLoc());
645}
646
647MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
648 if (auto info = unwrap(c: op)->getRegisteredInfo())
649 return wrap(cpp: info->getTypeID());
650 return {.ptr: nullptr};
651}
652
653MlirIdentifier mlirOperationGetName(MlirOperation op) {
654 return wrap(cpp: unwrap(c: op)->getName().getIdentifier());
655}
656
657MlirBlock mlirOperationGetBlock(MlirOperation op) {
658 return wrap(cpp: unwrap(c: op)->getBlock());
659}
660
661MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
662 return wrap(cpp: unwrap(c: op)->getParentOp());
663}
664
665intptr_t mlirOperationGetNumRegions(MlirOperation op) {
666 return static_cast<intptr_t>(unwrap(c: op)->getNumRegions());
667}
668
669MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
670 return wrap(cpp: &unwrap(c: op)->getRegion(index: static_cast<unsigned>(pos)));
671}
672
673MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
674 Operation *cppOp = unwrap(c: op);
675 if (cppOp->getNumRegions() == 0)
676 return wrap(cpp: static_cast<Region *>(nullptr));
677 return wrap(cpp: &cppOp->getRegion(index: 0));
678}
679
680MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
681 Region *cppRegion = unwrap(c: region);
682 Operation *parent = cppRegion->getParentOp();
683 intptr_t next = cppRegion->getRegionNumber() + 1;
684 if (parent->getNumRegions() > next)
685 return wrap(cpp: &parent->getRegion(index: next));
686 return wrap(cpp: static_cast<Region *>(nullptr));
687}
688
689MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
690 return wrap(cpp: unwrap(c: op)->getNextNode());
691}
692
693intptr_t mlirOperationGetNumOperands(MlirOperation op) {
694 return static_cast<intptr_t>(unwrap(c: op)->getNumOperands());
695}
696
697MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
698 return wrap(cpp: unwrap(c: op)->getOperand(idx: static_cast<unsigned>(pos)));
699}
700
701void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
702 MlirValue newValue) {
703 unwrap(c: op)->setOperand(idx: static_cast<unsigned>(pos), value: unwrap(c: newValue));
704}
705
706void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands,
707 MlirValue const *operands) {
708 SmallVector<Value> ops;
709 unwrap(c: op)->setOperands(unwrapList(size: nOperands, first: operands, storage&: ops));
710}
711
712intptr_t mlirOperationGetNumResults(MlirOperation op) {
713 return static_cast<intptr_t>(unwrap(c: op)->getNumResults());
714}
715
716MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
717 return wrap(cpp: unwrap(c: op)->getResult(idx: static_cast<unsigned>(pos)));
718}
719
720intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
721 return static_cast<intptr_t>(unwrap(c: op)->getNumSuccessors());
722}
723
724MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
725 return wrap(cpp: unwrap(c: op)->getSuccessor(index: static_cast<unsigned>(pos)));
726}
727
728MLIR_CAPI_EXPORTED bool
729mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
730 std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name));
731 return attr.has_value();
732}
733
734MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
735 MlirStringRef name) {
736 std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name));
737 if (attr.has_value())
738 return wrap(cpp: *attr);
739 return {};
740}
741
742void mlirOperationSetInherentAttributeByName(MlirOperation op,
743 MlirStringRef name,
744 MlirAttribute attr) {
745 unwrap(c: op)->setInherentAttr(
746 name: StringAttr::get(context: unwrap(c: op)->getContext(), bytes: unwrap(ref: name)), value: unwrap(c: attr));
747}
748
749intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
750 return static_cast<intptr_t>(
751 llvm::range_size(Range: unwrap(c: op)->getDiscardableAttrs()));
752}
753
754MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
755 intptr_t pos) {
756 NamedAttribute attr =
757 *std::next(x: unwrap(c: op)->getDiscardableAttrs().begin(), n: pos);
758 return MlirNamedAttribute{.name: wrap(cpp: attr.getName()), .attribute: wrap(cpp: attr.getValue())};
759}
760
761MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
762 MlirStringRef name) {
763 return wrap(cpp: unwrap(c: op)->getDiscardableAttr(name: unwrap(ref: name)));
764}
765
766void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
767 MlirStringRef name,
768 MlirAttribute attr) {
769 unwrap(c: op)->setDiscardableAttr(name: unwrap(ref: name), value: unwrap(c: attr));
770}
771
772bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
773 MlirStringRef name) {
774 return !!unwrap(c: op)->removeDiscardableAttr(name: unwrap(ref: name));
775}
776
777void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos,
778 MlirBlock block) {
779 unwrap(c: op)->setSuccessor(block: unwrap(c: block), index: static_cast<unsigned>(pos));
780}
781
782intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
783 return static_cast<intptr_t>(unwrap(c: op)->getAttrs().size());
784}
785
786MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
787 NamedAttribute attr = unwrap(c: op)->getAttrs()[pos];
788 return MlirNamedAttribute{.name: wrap(cpp: attr.getName()), .attribute: wrap(cpp: attr.getValue())};
789}
790
791MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
792 MlirStringRef name) {
793 return wrap(cpp: unwrap(c: op)->getAttr(name: unwrap(ref: name)));
794}
795
796void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
797 MlirAttribute attr) {
798 unwrap(c: op)->setAttr(name: unwrap(ref: name), value: unwrap(c: attr));
799}
800
801bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
802 return !!unwrap(c: op)->removeAttr(name: unwrap(ref: name));
803}
804
805void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
806 void *userData) {
807 detail::CallbackOstream stream(callback, userData);
808 unwrap(c: op)->print(os&: stream);
809}
810
811void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
812 MlirStringCallback callback, void *userData) {
813 detail::CallbackOstream stream(callback, userData);
814 unwrap(c: op)->print(os&: stream, flags: *unwrap(c: flags));
815}
816
817void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state,
818 MlirStringCallback callback, void *userData) {
819 detail::CallbackOstream stream(callback, userData);
820 if (state.ptr)
821 unwrap(c: op)->print(os&: stream, state&: *unwrap(c: state));
822 unwrap(c: op)->print(os&: stream);
823}
824
825void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
826 void *userData) {
827 detail::CallbackOstream stream(callback, userData);
828 // As no desired version is set, no failure can occur.
829 (void)writeBytecodeToFile(op: unwrap(c: op), os&: stream);
830}
831
832MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
833 MlirOperation op, MlirBytecodeWriterConfig config,
834 MlirStringCallback callback, void *userData) {
835 detail::CallbackOstream stream(callback, userData);
836 return wrap(res: writeBytecodeToFile(op: unwrap(c: op), os&: stream, config: *unwrap(c: config)));
837}
838
839void mlirOperationDump(MlirOperation op) { return unwrap(c: op)->dump(); }
840
841bool mlirOperationVerify(MlirOperation op) {
842 return succeeded(Result: verify(op: unwrap(c: op)));
843}
844
845void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
846 return unwrap(c: op)->moveAfter(existingOp: unwrap(c: other));
847}
848
849void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
850 return unwrap(c: op)->moveBefore(existingOp: unwrap(c: other));
851}
852
853static mlir::WalkResult unwrap(MlirWalkResult result) {
854 switch (result) {
855 case MlirWalkResultAdvance:
856 return mlir::WalkResult::advance();
857
858 case MlirWalkResultInterrupt:
859 return mlir::WalkResult::interrupt();
860
861 case MlirWalkResultSkip:
862 return mlir::WalkResult::skip();
863 }
864 llvm_unreachable("unknown result in WalkResult::unwrap");
865}
866
867void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
868 void *userData, MlirWalkOrder walkOrder) {
869 switch (walkOrder) {
870
871 case MlirWalkPreOrder:
872 unwrap(c: op)->walk<mlir::WalkOrder::PreOrder>(
873 callback: [callback, userData](Operation *op) {
874 return unwrap(result: callback(wrap(cpp: op), userData));
875 });
876 break;
877 case MlirWalkPostOrder:
878 unwrap(c: op)->walk<mlir::WalkOrder::PostOrder>(
879 callback: [callback, userData](Operation *op) {
880 return unwrap(result: callback(wrap(cpp: op), userData));
881 });
882 }
883}
884
885//===----------------------------------------------------------------------===//
886// Region API.
887//===----------------------------------------------------------------------===//
888
889MlirRegion mlirRegionCreate() { return wrap(cpp: new Region); }
890
891bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
892 return unwrap(c: region) == unwrap(c: other);
893}
894
895MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
896 Region *cppRegion = unwrap(c: region);
897 if (cppRegion->empty())
898 return wrap(cpp: static_cast<Block *>(nullptr));
899 return wrap(cpp: &cppRegion->front());
900}
901
902void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
903 unwrap(c: region)->push_back(block: unwrap(c: block));
904}
905
906void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
907 MlirBlock block) {
908 auto &blockList = unwrap(c: region)->getBlocks();
909 blockList.insert(where: std::next(x: blockList.begin(), n: pos), New: unwrap(c: block));
910}
911
912void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
913 MlirBlock block) {
914 Region *cppRegion = unwrap(c: region);
915 if (mlirBlockIsNull(block: reference)) {
916 cppRegion->getBlocks().insert(where: cppRegion->begin(), New: unwrap(c: block));
917 return;
918 }
919
920 assert(unwrap(reference)->getParent() == unwrap(region) &&
921 "expected reference block to belong to the region");
922 cppRegion->getBlocks().insertAfter(where: Region::iterator(unwrap(c: reference)),
923 New: unwrap(c: block));
924}
925
926void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
927 MlirBlock block) {
928 if (mlirBlockIsNull(block: reference))
929 return mlirRegionAppendOwnedBlock(region, block);
930
931 assert(unwrap(reference)->getParent() == unwrap(region) &&
932 "expected reference block to belong to the region");
933 unwrap(c: region)->getBlocks().insert(where: Region::iterator(unwrap(c: reference)),
934 New: unwrap(c: block));
935}
936
937void mlirRegionDestroy(MlirRegion region) {
938 delete static_cast<Region *>(region.ptr);
939}
940
941void mlirRegionTakeBody(MlirRegion target, MlirRegion source) {
942 unwrap(c: target)->takeBody(other&: *unwrap(c: source));
943}
944
945//===----------------------------------------------------------------------===//
946// Block API.
947//===----------------------------------------------------------------------===//
948
949MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
950 MlirLocation const *locs) {
951 Block *b = new Block;
952 for (intptr_t i = 0; i < nArgs; ++i)
953 b->addArgument(type: unwrap(c: args[i]), loc: unwrap(c: locs[i]));
954 return wrap(cpp: b);
955}
956
957bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
958 return unwrap(c: block) == unwrap(c: other);
959}
960
961MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
962 return wrap(cpp: unwrap(c: block)->getParentOp());
963}
964
965MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
966 return wrap(cpp: unwrap(c: block)->getParent());
967}
968
969MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
970 return wrap(cpp: unwrap(c: block)->getNextNode());
971}
972
973MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
974 Block *cppBlock = unwrap(c: block);
975 if (cppBlock->empty())
976 return wrap(cpp: static_cast<Operation *>(nullptr));
977 return wrap(cpp: &cppBlock->front());
978}
979
980MlirOperation mlirBlockGetTerminator(MlirBlock block) {
981 Block *cppBlock = unwrap(c: block);
982 if (cppBlock->empty())
983 return wrap(cpp: static_cast<Operation *>(nullptr));
984 Operation &back = cppBlock->back();
985 if (!back.hasTrait<OpTrait::IsTerminator>())
986 return wrap(cpp: static_cast<Operation *>(nullptr));
987 return wrap(cpp: &back);
988}
989
990void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
991 unwrap(c: block)->push_back(op: unwrap(c: operation));
992}
993
994void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
995 MlirOperation operation) {
996 auto &opList = unwrap(c: block)->getOperations();
997 opList.insert(where: std::next(x: opList.begin(), n: pos), New: unwrap(c: operation));
998}
999
1000void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
1001 MlirOperation reference,
1002 MlirOperation operation) {
1003 Block *cppBlock = unwrap(c: block);
1004 if (mlirOperationIsNull(op: reference)) {
1005 cppBlock->getOperations().insert(where: cppBlock->begin(), New: unwrap(c: operation));
1006 return;
1007 }
1008
1009 assert(unwrap(reference)->getBlock() == unwrap(block) &&
1010 "expected reference operation to belong to the block");
1011 cppBlock->getOperations().insertAfter(where: Block::iterator(unwrap(c: reference)),
1012 New: unwrap(c: operation));
1013}
1014
1015void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
1016 MlirOperation reference,
1017 MlirOperation operation) {
1018 if (mlirOperationIsNull(op: reference))
1019 return mlirBlockAppendOwnedOperation(block, operation);
1020
1021 assert(unwrap(reference)->getBlock() == unwrap(block) &&
1022 "expected reference operation to belong to the block");
1023 unwrap(c: block)->getOperations().insert(where: Block::iterator(unwrap(c: reference)),
1024 New: unwrap(c: operation));
1025}
1026
1027void mlirBlockDestroy(MlirBlock block) { delete unwrap(c: block); }
1028
1029void mlirBlockDetach(MlirBlock block) {
1030 Block *b = unwrap(c: block);
1031 b->getParent()->getBlocks().remove(IT: b);
1032}
1033
1034intptr_t mlirBlockGetNumArguments(MlirBlock block) {
1035 return static_cast<intptr_t>(unwrap(c: block)->getNumArguments());
1036}
1037
1038MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
1039 MlirLocation loc) {
1040 return wrap(cpp: unwrap(c: block)->addArgument(type: unwrap(c: type), loc: unwrap(c: loc)));
1041}
1042
1043void mlirBlockEraseArgument(MlirBlock block, unsigned index) {
1044 return unwrap(c: block)->eraseArgument(index);
1045}
1046
1047MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type,
1048 MlirLocation loc) {
1049 return wrap(cpp: unwrap(c: block)->insertArgument(index: pos, type: unwrap(c: type), loc: unwrap(c: loc)));
1050}
1051
1052MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
1053 return wrap(cpp: unwrap(c: block)->getArgument(i: static_cast<unsigned>(pos)));
1054}
1055
1056void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
1057 void *userData) {
1058 detail::CallbackOstream stream(callback, userData);
1059 unwrap(c: block)->print(os&: stream);
1060}
1061
1062intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
1063 return static_cast<intptr_t>(unwrap(c: block)->getNumSuccessors());
1064}
1065
1066MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
1067 return wrap(cpp: unwrap(c: block)->getSuccessor(i: static_cast<unsigned>(pos)));
1068}
1069
1070intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
1071 Block *b = unwrap(c: block);
1072 return static_cast<intptr_t>(std::distance(first: b->pred_begin(), last: b->pred_end()));
1073}
1074
1075MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
1076 Block *b = unwrap(c: block);
1077 Block::pred_iterator it = b->pred_begin();
1078 std::advance(i&: it, n: pos);
1079 return wrap(cpp: *it);
1080}
1081
1082//===----------------------------------------------------------------------===//
1083// Value API.
1084//===----------------------------------------------------------------------===//
1085
1086bool mlirValueEqual(MlirValue value1, MlirValue value2) {
1087 return unwrap(c: value1) == unwrap(c: value2);
1088}
1089
1090bool mlirValueIsABlockArgument(MlirValue value) {
1091 return llvm::isa<BlockArgument>(Val: unwrap(c: value));
1092}
1093
1094bool mlirValueIsAOpResult(MlirValue value) {
1095 return llvm::isa<OpResult>(Val: unwrap(c: value));
1096}
1097
1098MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
1099 return wrap(cpp: llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getOwner());
1100}
1101
1102intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
1103 return static_cast<intptr_t>(
1104 llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getArgNumber());
1105}
1106
1107void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
1108 if (auto blockArg = llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)))
1109 blockArg.setType(unwrap(c: type));
1110}
1111
1112MlirOperation mlirOpResultGetOwner(MlirValue value) {
1113 return wrap(cpp: llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getOwner());
1114}
1115
1116intptr_t mlirOpResultGetResultNumber(MlirValue value) {
1117 return static_cast<intptr_t>(
1118 llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getResultNumber());
1119}
1120
1121MlirType mlirValueGetType(MlirValue value) {
1122 return wrap(cpp: unwrap(c: value).getType());
1123}
1124
1125void mlirValueSetType(MlirValue value, MlirType type) {
1126 unwrap(c: value).setType(unwrap(c: type));
1127}
1128
1129void mlirValueDump(MlirValue value) { unwrap(c: value).dump(); }
1130
1131void mlirValuePrint(MlirValue value, MlirStringCallback callback,
1132 void *userData) {
1133 detail::CallbackOstream stream(callback, userData);
1134 unwrap(c: value).print(os&: stream);
1135}
1136
1137void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
1138 MlirStringCallback callback, void *userData) {
1139 detail::CallbackOstream stream(callback, userData);
1140 Value cppValue = unwrap(c: value);
1141 cppValue.printAsOperand(os&: stream, state&: *unwrap(c: state));
1142}
1143
1144MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
1145 Value cppValue = unwrap(c: value);
1146 if (cppValue.use_empty())
1147 return {};
1148
1149 OpOperand *opOperand = cppValue.use_begin().getOperand();
1150
1151 return wrap(cpp: opOperand);
1152}
1153
1154void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
1155 unwrap(c: oldValue).replaceAllUsesWith(newValue: unwrap(c: newValue));
1156}
1157
1158void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
1159 intptr_t numExceptions,
1160 MlirOperation *exceptions) {
1161 Value oldValueCpp = unwrap(c: oldValue);
1162 Value newValueCpp = unwrap(c: newValue);
1163
1164 llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
1165 for (intptr_t i = 0; i < numExceptions; ++i) {
1166 exceptionSet.insert(Ptr: unwrap(c: exceptions[i]));
1167 }
1168
1169 oldValueCpp.replaceAllUsesExcept(newValue: newValueCpp, exceptions: exceptionSet);
1170}
1171
1172MlirLocation mlirValueGetLocation(MlirValue v) {
1173 return wrap(cpp: unwrap(c: v).getLoc());
1174}
1175
1176MlirContext mlirValueGetContext(MlirValue v) {
1177 return wrap(cpp: unwrap(c: v).getContext());
1178}
1179
1180//===----------------------------------------------------------------------===//
1181// OpOperand API.
1182//===----------------------------------------------------------------------===//
1183
1184bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; }
1185
1186MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) {
1187 return wrap(cpp: unwrap(c: opOperand)->getOwner());
1188}
1189
1190MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) {
1191 return wrap(cpp: unwrap(c: opOperand)->get());
1192}
1193
1194unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) {
1195 return unwrap(c: opOperand)->getOperandNumber();
1196}
1197
1198MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) {
1199 if (mlirOpOperandIsNull(opOperand))
1200 return {};
1201
1202 OpOperand *nextOpOperand = static_cast<OpOperand *>(
1203 unwrap(c: opOperand)->getNextOperandUsingThisValue());
1204
1205 if (!nextOpOperand)
1206 return {};
1207
1208 return wrap(cpp: nextOpOperand);
1209}
1210
1211//===----------------------------------------------------------------------===//
1212// Type API.
1213//===----------------------------------------------------------------------===//
1214
1215MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
1216 return wrap(cpp: mlir::parseType(typeStr: unwrap(ref: type), context: unwrap(c: context)));
1217}
1218
1219MlirContext mlirTypeGetContext(MlirType type) {
1220 return wrap(cpp: unwrap(c: type).getContext());
1221}
1222
1223MlirTypeID mlirTypeGetTypeID(MlirType type) {
1224 return wrap(cpp: unwrap(c: type).getTypeID());
1225}
1226
1227MlirDialect mlirTypeGetDialect(MlirType type) {
1228 return wrap(cpp: &unwrap(c: type).getDialect());
1229}
1230
1231bool mlirTypeEqual(MlirType t1, MlirType t2) {
1232 return unwrap(c: t1) == unwrap(c: t2);
1233}
1234
1235void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
1236 detail::CallbackOstream stream(callback, userData);
1237 unwrap(c: type).print(os&: stream);
1238}
1239
1240void mlirTypeDump(MlirType type) { unwrap(c: type).dump(); }
1241
1242//===----------------------------------------------------------------------===//
1243// Attribute API.
1244//===----------------------------------------------------------------------===//
1245
1246MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
1247 return wrap(cpp: mlir::parseAttribute(attrStr: unwrap(ref: attr), context: unwrap(c: context)));
1248}
1249
1250MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
1251 return wrap(cpp: unwrap(c: attribute).getContext());
1252}
1253
1254MlirType mlirAttributeGetType(MlirAttribute attribute) {
1255 Attribute attr = unwrap(c: attribute);
1256 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(Val&: attr))
1257 return wrap(cpp: typedAttr.getType());
1258 return wrap(cpp: NoneType::get(context: attr.getContext()));
1259}
1260
1261MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
1262 return wrap(cpp: unwrap(c: attr).getTypeID());
1263}
1264
1265MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
1266 return wrap(cpp: &unwrap(c: attr).getDialect());
1267}
1268
1269bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
1270 return unwrap(c: a1) == unwrap(c: a2);
1271}
1272
1273void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
1274 void *userData) {
1275 detail::CallbackOstream stream(callback, userData);
1276 unwrap(c: attr).print(os&: stream);
1277}
1278
1279void mlirAttributeDump(MlirAttribute attr) { unwrap(c: attr).dump(); }
1280
1281MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
1282 MlirAttribute attr) {
1283 return MlirNamedAttribute{.name: name, .attribute: attr};
1284}
1285
1286//===----------------------------------------------------------------------===//
1287// Identifier API.
1288//===----------------------------------------------------------------------===//
1289
1290MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
1291 return wrap(cpp: StringAttr::get(context: unwrap(c: context), bytes: unwrap(ref: str)));
1292}
1293
1294MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
1295 return wrap(cpp: unwrap(c: ident).getContext());
1296}
1297
1298bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
1299 return unwrap(c: ident) == unwrap(c: other);
1300}
1301
1302MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
1303 return wrap(ref: unwrap(c: ident).strref());
1304}
1305
1306//===----------------------------------------------------------------------===//
1307// Symbol and SymbolTable API.
1308//===----------------------------------------------------------------------===//
1309
1310MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
1311 return wrap(ref: SymbolTable::getSymbolAttrName());
1312}
1313
1314MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
1315 return wrap(ref: SymbolTable::getVisibilityAttrName());
1316}
1317
1318MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
1319 if (!unwrap(c: operation)->hasTrait<OpTrait::SymbolTable>())
1320 return wrap(cpp: static_cast<SymbolTable *>(nullptr));
1321 return wrap(cpp: new SymbolTable(unwrap(c: operation)));
1322}
1323
1324void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
1325 delete unwrap(c: symbolTable);
1326}
1327
1328MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
1329 MlirStringRef name) {
1330 return wrap(cpp: unwrap(c: symbolTable)->lookup(name: StringRef(name.data, name.length)));
1331}
1332
1333MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
1334 MlirOperation operation) {
1335 return wrap(cpp: (Attribute)unwrap(c: symbolTable)->insert(symbol: unwrap(c: operation)));
1336}
1337
1338void mlirSymbolTableErase(MlirSymbolTable symbolTable,
1339 MlirOperation operation) {
1340 unwrap(c: symbolTable)->erase(symbol: unwrap(c: operation));
1341}
1342
1343MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
1344 MlirStringRef newSymbol,
1345 MlirOperation from) {
1346 auto *cppFrom = unwrap(c: from);
1347 auto *context = cppFrom->getContext();
1348 auto oldSymbolAttr = StringAttr::get(context, bytes: unwrap(ref: oldSymbol));
1349 auto newSymbolAttr = StringAttr::get(context, bytes: unwrap(ref: newSymbol));
1350 return wrap(res: SymbolTable::replaceAllSymbolUses(oldSymbol: oldSymbolAttr, newSymbol: newSymbolAttr,
1351 from: unwrap(c: from)));
1352}
1353
1354void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
1355 void (*callback)(MlirOperation, bool,
1356 void *userData),
1357 void *userData) {
1358 SymbolTable::walkSymbolTables(op: unwrap(c: from), allSymUsesVisible,
1359 callback: [&](Operation *foundOpCpp, bool isVisible) {
1360 callback(wrap(cpp: foundOpCpp), isVisible,
1361 userData);
1362 });
1363}
1364

source code of mlir/lib/CAPI/IR/IR.cpp