1//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ThreadPool.h"

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <optional>
36
37using namespace mlir;
38
39//===----------------------------------------------------------------------===//
40// Context API.
41//===----------------------------------------------------------------------===//
42
43MlirContext mlirContextCreate() {
44 auto *context = new MLIRContext;
45 return wrap(cpp: context);
46}
47
48static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
49 return threadingEnabled ? MLIRContext::Threading::ENABLED
50 : MLIRContext::Threading::DISABLED;
51}
52
53MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
54 auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
55 return wrap(cpp: context);
56}
57
58MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
59 bool threadingEnabled) {
60 auto *context =
61 new MLIRContext(*unwrap(c: registry), toThreadingEnum(threadingEnabled));
62 return wrap(cpp: context);
63}
64
65bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
66 return unwrap(c: ctx1) == unwrap(c: ctx2);
67}
68
69void mlirContextDestroy(MlirContext context) { delete unwrap(c: context); }
70
71void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
72 unwrap(c: context)->allowUnregisteredDialects(allow);
73}
74
75bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
76 return unwrap(c: context)->allowsUnregisteredDialects();
77}
78intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
79 return static_cast<intptr_t>(unwrap(c: context)->getAvailableDialects().size());
80}
81
82void mlirContextAppendDialectRegistry(MlirContext ctx,
83 MlirDialectRegistry registry) {
84 unwrap(c: ctx)->appendDialectRegistry(registry: *unwrap(c: registry));
85}
86
87// TODO: expose a cheaper way than constructing + sorting a vector only to take
88// its size.
89intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
90 return static_cast<intptr_t>(unwrap(c: context)->getLoadedDialects().size());
91}
92
93MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
94 MlirStringRef name) {
95 return wrap(cpp: unwrap(c: context)->getOrLoadDialect(name: unwrap(ref: name)));
96}
97
98bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
99 return unwrap(c: context)->isOperationRegistered(name: unwrap(ref: name));
100}
101
102void mlirContextEnableMultithreading(MlirContext context, bool enable) {
103 return unwrap(c: context)->enableMultithreading(enable);
104}
105
106void mlirContextLoadAllAvailableDialects(MlirContext context) {
107 unwrap(c: context)->loadAllAvailableDialects();
108}
109
110void mlirContextSetThreadPool(MlirContext context,
111 MlirLlvmThreadPool threadPool) {
112 unwrap(c: context)->setThreadPool(*unwrap(c: threadPool));
113}
114
115//===----------------------------------------------------------------------===//
116// Dialect API.
117//===----------------------------------------------------------------------===//
118
119MlirContext mlirDialectGetContext(MlirDialect dialect) {
120 return wrap(cpp: unwrap(c: dialect)->getContext());
121}
122
123bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
124 return unwrap(c: dialect1) == unwrap(c: dialect2);
125}
126
127MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
128 return wrap(ref: unwrap(c: dialect)->getNamespace());
129}
130
131//===----------------------------------------------------------------------===//
132// DialectRegistry API.
133//===----------------------------------------------------------------------===//
134
135MlirDialectRegistry mlirDialectRegistryCreate() {
136 return wrap(cpp: new DialectRegistry());
137}
138
139void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
140 delete unwrap(c: registry);
141}
142
143//===----------------------------------------------------------------------===//
144// AsmState API.
145//===----------------------------------------------------------------------===//
146
147MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
148 MlirOpPrintingFlags flags) {
149 return wrap(cpp: new AsmState(unwrap(c: op), *unwrap(c: flags)));
150}
151
152static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
153 do {
154 // If we are printing local scope, stop at the first operation that is
155 // isolated from above.
156 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
157 break;
158
159 // Otherwise, traverse up to the next parent.
160 Operation *parentOp = op->getParentOp();
161 if (!parentOp)
162 break;
163 op = parentOp;
164 } while (true);
165 return op;
166}
167
168MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
169 MlirOpPrintingFlags flags) {
170 Operation *op;
171 mlir::Value val = unwrap(c: value);
172 if (auto result = llvm::dyn_cast<OpResult>(Val&: val)) {
173 op = result.getOwner();
174 } else {
175 op = llvm::cast<BlockArgument>(Val&: val).getOwner()->getParentOp();
176 if (!op) {
177 emitError(loc: val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
178 return {.ptr: nullptr};
179 }
180 }
181 op = findParent(op, shouldUseLocalScope: unwrap(c: flags)->shouldUseLocalScope());
182 return wrap(cpp: new AsmState(op, *unwrap(c: flags)));
183}
184
185/// Destroys printing flags created with mlirAsmStateCreate.
186void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(c: state); }
187
188//===----------------------------------------------------------------------===//
189// Printing flags API.
190//===----------------------------------------------------------------------===//
191
192MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
193 return wrap(cpp: new OpPrintingFlags());
194}
195
196void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
197 delete unwrap(c: flags);
198}
199
200void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
201 intptr_t largeElementLimit) {
202 unwrap(c: flags)->elideLargeElementsAttrs(largeElementLimit);
203}
204
205void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable,
206 bool prettyForm) {
207 unwrap(c: flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm);
208}
209
210void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
211 unwrap(c: flags)->printGenericOpForm();
212}
213
214void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
215 unwrap(c: flags)->useLocalScope();
216}
217
218void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
219 unwrap(c: flags)->assumeVerified();
220}
221
222//===----------------------------------------------------------------------===//
223// Bytecode printing flags API.
224//===----------------------------------------------------------------------===//
225
226MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() {
227 return wrap(cpp: new BytecodeWriterConfig());
228}
229
230void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) {
231 delete unwrap(c: config);
232}
233
234void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
235 int64_t version) {
236 unwrap(c: flags)->setDesiredBytecodeVersion(version);
237}
238
239//===----------------------------------------------------------------------===//
240// Location API.
241//===----------------------------------------------------------------------===//
242
243MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
244 return wrap(cpp: LocationAttr(unwrap(c: location)));
245}
246
247MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
248 return wrap(cpp: Location(llvm::cast<LocationAttr>(Val: unwrap(c: attribute))));
249}
250
251MlirLocation mlirLocationFileLineColGet(MlirContext context,
252 MlirStringRef filename, unsigned line,
253 unsigned col) {
254 return wrap(Location(
255 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
256}
257
258MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
259 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
260}
261
262MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
263 MlirLocation const *locations,
264 MlirAttribute metadata) {
265 SmallVector<Location, 4> locs;
266 ArrayRef<Location> unwrappedLocs = unwrapList(size: nLocations, first: locations, storage&: locs);
267 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
268}
269
270MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
271 MlirLocation childLoc) {
272 if (mlirLocationIsNull(childLoc))
273 return wrap(
274 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
275 return wrap(Location(NameLoc::get(
276 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
277}
278
279MlirLocation mlirLocationUnknownGet(MlirContext context) {
280 return wrap(Location(UnknownLoc::get(unwrap(context))));
281}
282
283bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
284 return unwrap(c: l1) == unwrap(c: l2);
285}
286
287MlirContext mlirLocationGetContext(MlirLocation location) {
288 return wrap(cpp: unwrap(c: location).getContext());
289}
290
291void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
292 void *userData) {
293 detail::CallbackOstream stream(callback, userData);
294 unwrap(c: location).print(os&: stream);
295}
296
297//===----------------------------------------------------------------------===//
298// Module API.
299//===----------------------------------------------------------------------===//
300
301MlirModule mlirModuleCreateEmpty(MlirLocation location) {
302 return wrap(ModuleOp::create(unwrap(location)));
303}
304
305MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
306 OwningOpRef<ModuleOp> owning =
307 parseSourceString<ModuleOp>(unwrap(ref: module), unwrap(c: context));
308 if (!owning)
309 return MlirModule{.ptr: nullptr};
310 return MlirModule{owning.release().getOperation()};
311}
312
313MlirContext mlirModuleGetContext(MlirModule module) {
314 return wrap(unwrap(module).getContext());
315}
316
317MlirBlock mlirModuleGetBody(MlirModule module) {
318 return wrap(unwrap(module).getBody());
319}
320
321void mlirModuleDestroy(MlirModule module) {
322 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
323 // called.
324 OwningOpRef<ModuleOp>(unwrap(module));
325}
326
327MlirOperation mlirModuleGetOperation(MlirModule module) {
328 return wrap(unwrap(module).getOperation());
329}
330
331MlirModule mlirModuleFromOperation(MlirOperation op) {
332 return wrap(dyn_cast<ModuleOp>(unwrap(c: op)));
333}
334
335//===----------------------------------------------------------------------===//
336// Operation state API.
337//===----------------------------------------------------------------------===//
338
339MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
340 MlirOperationState state;
341 state.name = name;
342 state.location = loc;
343 state.nResults = 0;
344 state.results = nullptr;
345 state.nOperands = 0;
346 state.operands = nullptr;
347 state.nRegions = 0;
348 state.regions = nullptr;
349 state.nSuccessors = 0;
350 state.successors = nullptr;
351 state.nAttributes = 0;
352 state.attributes = nullptr;
353 state.enableResultTypeInference = false;
354 return state;
355}
356
/// Appends the `n` elements pointed to by `elemName` to the malloc-managed
/// array `state->elemName` and bumps the count `state->sizeName`. The use site
/// must have `state` (the operation state) and `n` (element count) in scope.
/// The arrays are later released with free() in mlirOperationCreate.
///
/// - Wrapped in do/while so that the macro acts as one statement (e.g. under an
///   unbraced `if`).
/// - A non-positive `n` is a no-op. This avoids realloc(ptr, 0) and a memcpy
///   from a possibly-null source.
/// - Allocation failure aborts, rather than losing the old buffer and writing
///   through a null pointer.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *grown = static_cast<type *>(                                         \
        std::realloc(state->elemName, (state->sizeName + n) * sizeof(type)));  \
    if (!grown)                                                                \
      std::abort();                                                            \
    std::memcpy(grown + state->sizeName, elemName, n * sizeof(type));          \
    state->elemName = grown;                                                   \
    state->sizeName += n;                                                      \
  } while (false)
362
363void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
364 MlirType const *results) {
365 APPEND_ELEMS(MlirType, nResults, results);
366}
367
368void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
369 MlirValue const *operands) {
370 APPEND_ELEMS(MlirValue, nOperands, operands);
371}
372void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
373 MlirRegion const *regions) {
374 APPEND_ELEMS(MlirRegion, nRegions, regions);
375}
376void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
377 MlirBlock const *successors) {
378 APPEND_ELEMS(MlirBlock, nSuccessors, successors);
379}
380void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
381 MlirNamedAttribute const *attributes) {
382 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
383}
384
385void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
386 state->enableResultTypeInference = true;
387}
388
389//===----------------------------------------------------------------------===//
390// Operation API.
391//===----------------------------------------------------------------------===//
392
393static LogicalResult inferOperationTypes(OperationState &state) {
394 MLIRContext *context = state.getContext();
395 std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
396 if (!info) {
397 emitError(loc: state.location)
398 << "type inference was requested for the operation " << state.name
399 << ", but the operation was not registered; ensure that the dialect "
400 "containing the operation is linked into MLIR and registered with "
401 "the context";
402 return failure();
403 }
404
405 auto *inferInterface = info->getInterface<InferTypeOpInterface>();
406 if (!inferInterface) {
407 emitError(loc: state.location)
408 << "type inference was requested for the operation " << state.name
409 << ", but the operation does not support type inference; result "
410 "types must be specified explicitly";
411 return failure();
412 }
413
414 DictionaryAttr attributes = state.attributes.getDictionary(context);
415 OpaqueProperties properties = state.getRawProperties();
416
417 if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
418 auto prop = std::make_unique<char[]>(num: info->getOpPropertyByteSize());
419 properties = OpaqueProperties(prop.get());
420 if (properties) {
421 auto emitError = [&]() {
422 return mlir::emitError(state.location)
423 << " failed properties conversion while building "
424 << state.name.getStringRef() << " with `" << attributes << "`: ";
425 };
426 if (failed(info->setOpPropertiesFromAttribute(opName: state.name, properties,
427 attr: attributes, emitError)))
428 return failure();
429 }
430 if (succeeded(inferInterface->inferReturnTypes(
431 context, state.location, state.operands, attributes, properties,
432 state.regions, state.types))) {
433 return success();
434 }
435 // Diagnostic emitted by interface.
436 return failure();
437 }
438
439 if (succeeded(inferInterface->inferReturnTypes(
440 context, state.location, state.operands, attributes, properties,
441 state.regions, state.types)))
442 return success();
443
444 // Diagnostic emitted by interface.
445 return failure();
446}
447
448MlirOperation mlirOperationCreate(MlirOperationState *state) {
449 assert(state);
450 OperationState cppState(unwrap(c: state->location), unwrap(ref: state->name));
451 SmallVector<Type, 4> resultStorage;
452 SmallVector<Value, 8> operandStorage;
453 SmallVector<Block *, 2> successorStorage;
454 cppState.addTypes(newTypes: unwrapList(size: state->nResults, first: state->results, storage&: resultStorage));
455 cppState.addOperands(
456 newOperands: unwrapList(size: state->nOperands, first: state->operands, storage&: operandStorage));
457 cppState.addSuccessors(
458 newSuccessors: unwrapList(size: state->nSuccessors, first: state->successors, storage&: successorStorage));
459
460 cppState.attributes.reserve(N: state->nAttributes);
461 for (intptr_t i = 0; i < state->nAttributes; ++i)
462 cppState.addAttribute(unwrap(state->attributes[i].name),
463 unwrap(c: state->attributes[i].attribute));
464
465 for (intptr_t i = 0; i < state->nRegions; ++i)
466 cppState.addRegion(region: std::unique_ptr<Region>(unwrap(c: state->regions[i])));
467
468 free(ptr: state->results);
469 free(ptr: state->operands);
470 free(ptr: state->successors);
471 free(ptr: state->regions);
472 free(ptr: state->attributes);
473
474 // Infer result types.
475 if (state->enableResultTypeInference) {
476 assert(cppState.types.empty() &&
477 "result type inference enabled and result types provided");
478 if (failed(result: inferOperationTypes(state&: cppState)))
479 return {.ptr: nullptr};
480 }
481
482 return wrap(cpp: Operation::create(state: cppState));
483}
484
485MlirOperation mlirOperationCreateParse(MlirContext context,
486 MlirStringRef sourceStr,
487 MlirStringRef sourceName) {
488
489 return wrap(
490 cpp: parseSourceString(sourceStr: unwrap(ref: sourceStr), config: unwrap(c: context), sourceName: unwrap(ref: sourceName))
491 .release());
492}
493
494MlirOperation mlirOperationClone(MlirOperation op) {
495 return wrap(cpp: unwrap(c: op)->clone());
496}
497
498void mlirOperationDestroy(MlirOperation op) { unwrap(c: op)->erase(); }
499
500void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(c: op)->remove(); }
501
502bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
503 return unwrap(c: op) == unwrap(c: other);
504}
505
506MlirContext mlirOperationGetContext(MlirOperation op) {
507 return wrap(cpp: unwrap(c: op)->getContext());
508}
509
510MlirLocation mlirOperationGetLocation(MlirOperation op) {
511 return wrap(cpp: unwrap(c: op)->getLoc());
512}
513
514MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
515 if (auto info = unwrap(c: op)->getRegisteredInfo())
516 return wrap(cpp: info->getTypeID());
517 return {.ptr: nullptr};
518}
519
520MlirIdentifier mlirOperationGetName(MlirOperation op) {
521 return wrap(unwrap(c: op)->getName().getIdentifier());
522}
523
524MlirBlock mlirOperationGetBlock(MlirOperation op) {
525 return wrap(cpp: unwrap(c: op)->getBlock());
526}
527
528MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
529 return wrap(cpp: unwrap(c: op)->getParentOp());
530}
531
532intptr_t mlirOperationGetNumRegions(MlirOperation op) {
533 return static_cast<intptr_t>(unwrap(c: op)->getNumRegions());
534}
535
536MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
537 return wrap(cpp: &unwrap(c: op)->getRegion(index: static_cast<unsigned>(pos)));
538}
539
540MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
541 Operation *cppOp = unwrap(c: op);
542 if (cppOp->getNumRegions() == 0)
543 return wrap(cpp: static_cast<Region *>(nullptr));
544 return wrap(cpp: &cppOp->getRegion(index: 0));
545}
546
547MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
548 Region *cppRegion = unwrap(c: region);
549 Operation *parent = cppRegion->getParentOp();
550 intptr_t next = cppRegion->getRegionNumber() + 1;
551 if (parent->getNumRegions() > next)
552 return wrap(cpp: &parent->getRegion(index: next));
553 return wrap(cpp: static_cast<Region *>(nullptr));
554}
555
556MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
557 return wrap(unwrap(c: op)->getNextNode());
558}
559
560intptr_t mlirOperationGetNumOperands(MlirOperation op) {
561 return static_cast<intptr_t>(unwrap(c: op)->getNumOperands());
562}
563
564MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
565 return wrap(cpp: unwrap(c: op)->getOperand(idx: static_cast<unsigned>(pos)));
566}
567
568void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
569 MlirValue newValue) {
570 unwrap(c: op)->setOperand(idx: static_cast<unsigned>(pos), value: unwrap(c: newValue));
571}
572
573void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands,
574 MlirValue const *operands) {
575 SmallVector<Value> ops;
576 unwrap(c: op)->setOperands(unwrapList(size: nOperands, first: operands, storage&: ops));
577}
578
579intptr_t mlirOperationGetNumResults(MlirOperation op) {
580 return static_cast<intptr_t>(unwrap(c: op)->getNumResults());
581}
582
583MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
584 return wrap(cpp: unwrap(c: op)->getResult(idx: static_cast<unsigned>(pos)));
585}
586
587intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
588 return static_cast<intptr_t>(unwrap(c: op)->getNumSuccessors());
589}
590
591MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
592 return wrap(cpp: unwrap(c: op)->getSuccessor(index: static_cast<unsigned>(pos)));
593}
594
595MLIR_CAPI_EXPORTED bool
596mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
597 std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name));
598 return attr.has_value();
599}
600
601MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
602 MlirStringRef name) {
603 std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name));
604 if (attr.has_value())
605 return wrap(cpp: *attr);
606 return {};
607}
608
609void mlirOperationSetInherentAttributeByName(MlirOperation op,
610 MlirStringRef name,
611 MlirAttribute attr) {
612 unwrap(op)->setInherentAttr(
613 StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
614}
615
616intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
617 return static_cast<intptr_t>(
618 llvm::range_size(unwrap(c: op)->getDiscardableAttrs()));
619}
620
621MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
622 intptr_t pos) {
623 NamedAttribute attr =
624 *std::next(unwrap(c: op)->getDiscardableAttrs().begin(), pos);
625 return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())};
626}
627
628MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
629 MlirStringRef name) {
630 return wrap(cpp: unwrap(c: op)->getDiscardableAttr(name: unwrap(ref: name)));
631}
632
633void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
634 MlirStringRef name,
635 MlirAttribute attr) {
636 unwrap(c: op)->setDiscardableAttr(name: unwrap(ref: name), value: unwrap(c: attr));
637}
638
639bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
640 MlirStringRef name) {
641 return !!unwrap(c: op)->removeDiscardableAttr(name: unwrap(ref: name));
642}
643
644void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos,
645 MlirBlock block) {
646 unwrap(c: op)->setSuccessor(block: unwrap(c: block), index: static_cast<unsigned>(pos));
647}
648
649intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
650 return static_cast<intptr_t>(unwrap(c: op)->getAttrs().size());
651}
652
653MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
654 NamedAttribute attr = unwrap(c: op)->getAttrs()[pos];
655 return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())};
656}
657
658MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
659 MlirStringRef name) {
660 return wrap(cpp: unwrap(c: op)->getAttr(name: unwrap(ref: name)));
661}
662
663void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
664 MlirAttribute attr) {
665 unwrap(c: op)->setAttr(name: unwrap(ref: name), value: unwrap(c: attr));
666}
667
668bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
669 return !!unwrap(c: op)->removeAttr(name: unwrap(ref: name));
670}
671
672void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
673 void *userData) {
674 detail::CallbackOstream stream(callback, userData);
675 unwrap(c: op)->print(os&: stream);
676}
677
678void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
679 MlirStringCallback callback, void *userData) {
680 detail::CallbackOstream stream(callback, userData);
681 unwrap(c: op)->print(os&: stream, flags: *unwrap(c: flags));
682}
683
684void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state,
685 MlirStringCallback callback, void *userData) {
686 detail::CallbackOstream stream(callback, userData);
687 if (state.ptr)
688 unwrap(c: op)->print(os&: stream, state&: *unwrap(c: state));
689 unwrap(c: op)->print(os&: stream);
690}
691
692void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
693 void *userData) {
694 detail::CallbackOstream stream(callback, userData);
695 // As no desired version is set, no failure can occur.
696 (void)writeBytecodeToFile(op: unwrap(c: op), os&: stream);
697}
698
699MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
700 MlirOperation op, MlirBytecodeWriterConfig config,
701 MlirStringCallback callback, void *userData) {
702 detail::CallbackOstream stream(callback, userData);
703 return wrap(res: writeBytecodeToFile(op: unwrap(c: op), os&: stream, config: *unwrap(c: config)));
704}
705
706void mlirOperationDump(MlirOperation op) { return unwrap(c: op)->dump(); }
707
708bool mlirOperationVerify(MlirOperation op) {
709 return succeeded(result: verify(op: unwrap(c: op)));
710}
711
712void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
713 return unwrap(c: op)->moveAfter(existingOp: unwrap(c: other));
714}
715
716void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
717 return unwrap(c: op)->moveBefore(existingOp: unwrap(c: other));
718}
719
720static mlir::WalkResult unwrap(MlirWalkResult result) {
721 switch (result) {
722 case MlirWalkResultAdvance:
723 return mlir::WalkResult::advance();
724
725 case MlirWalkResultInterrupt:
726 return mlir::WalkResult::interrupt();
727
728 case MlirWalkResultSkip:
729 return mlir::WalkResult::skip();
730 }
731}
732
733void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
734 void *userData, MlirWalkOrder walkOrder) {
735 switch (walkOrder) {
736
737 case MlirWalkPreOrder:
738 unwrap(c: op)->walk<mlir::WalkOrder::PreOrder>(
739 callback: [callback, userData](Operation *op) {
740 return unwrap(result: callback(wrap(cpp: op), userData));
741 });
742 break;
743 case MlirWalkPostOrder:
744 unwrap(c: op)->walk<mlir::WalkOrder::PostOrder>(
745 callback: [callback, userData](Operation *op) {
746 return unwrap(result: callback(wrap(cpp: op), userData));
747 });
748 }
749}
750
751//===----------------------------------------------------------------------===//
752// Region API.
753//===----------------------------------------------------------------------===//
754
755MlirRegion mlirRegionCreate() { return wrap(cpp: new Region); }
756
757bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
758 return unwrap(c: region) == unwrap(c: other);
759}
760
761MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
762 Region *cppRegion = unwrap(c: region);
763 if (cppRegion->empty())
764 return wrap(cpp: static_cast<Block *>(nullptr));
765 return wrap(cpp: &cppRegion->front());
766}
767
768void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
769 unwrap(c: region)->push_back(block: unwrap(c: block));
770}
771
772void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
773 MlirBlock block) {
774 auto &blockList = unwrap(c: region)->getBlocks();
775 blockList.insert(where: std::next(x: blockList.begin(), n: pos), New: unwrap(c: block));
776}
777
778void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
779 MlirBlock block) {
780 Region *cppRegion = unwrap(c: region);
781 if (mlirBlockIsNull(block: reference)) {
782 cppRegion->getBlocks().insert(where: cppRegion->begin(), New: unwrap(c: block));
783 return;
784 }
785
786 assert(unwrap(reference)->getParent() == unwrap(region) &&
787 "expected reference block to belong to the region");
788 cppRegion->getBlocks().insertAfter(where: Region::iterator(unwrap(c: reference)),
789 New: unwrap(c: block));
790}
791
792void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
793 MlirBlock block) {
794 if (mlirBlockIsNull(block: reference))
795 return mlirRegionAppendOwnedBlock(region, block);
796
797 assert(unwrap(reference)->getParent() == unwrap(region) &&
798 "expected reference block to belong to the region");
799 unwrap(c: region)->getBlocks().insert(where: Region::iterator(unwrap(c: reference)),
800 New: unwrap(c: block));
801}
802
803void mlirRegionDestroy(MlirRegion region) {
804 delete static_cast<Region *>(region.ptr);
805}
806
807void mlirRegionTakeBody(MlirRegion target, MlirRegion source) {
808 unwrap(c: target)->takeBody(other&: *unwrap(c: source));
809}
810
811//===----------------------------------------------------------------------===//
812// Block API.
813//===----------------------------------------------------------------------===//
814
815MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
816 MlirLocation const *locs) {
817 Block *b = new Block;
818 for (intptr_t i = 0; i < nArgs; ++i)
819 b->addArgument(type: unwrap(c: args[i]), loc: unwrap(c: locs[i]));
820 return wrap(cpp: b);
821}
822
823bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
824 return unwrap(c: block) == unwrap(c: other);
825}
826
827MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
828 return wrap(cpp: unwrap(c: block)->getParentOp());
829}
830
831MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
832 return wrap(cpp: unwrap(c: block)->getParent());
833}
834
835MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
836 return wrap(cpp: unwrap(c: block)->getNextNode());
837}
838
839MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
840 Block *cppBlock = unwrap(c: block);
841 if (cppBlock->empty())
842 return wrap(cpp: static_cast<Operation *>(nullptr));
843 return wrap(cpp: &cppBlock->front());
844}
845
846MlirOperation mlirBlockGetTerminator(MlirBlock block) {
847 Block *cppBlock = unwrap(c: block);
848 if (cppBlock->empty())
849 return wrap(cpp: static_cast<Operation *>(nullptr));
850 Operation &back = cppBlock->back();
851 if (!back.hasTrait<OpTrait::IsTerminator>())
852 return wrap(cpp: static_cast<Operation *>(nullptr));
853 return wrap(cpp: &back);
854}
855
856void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
857 unwrap(c: block)->push_back(op: unwrap(c: operation));
858}
859
860void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
861 MlirOperation operation) {
862 auto &opList = unwrap(c: block)->getOperations();
863 opList.insert(where: std::next(x: opList.begin(), n: pos), New: unwrap(c: operation));
864}
865
866void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
867 MlirOperation reference,
868 MlirOperation operation) {
869 Block *cppBlock = unwrap(c: block);
870 if (mlirOperationIsNull(op: reference)) {
871 cppBlock->getOperations().insert(where: cppBlock->begin(), New: unwrap(c: operation));
872 return;
873 }
874
875 assert(unwrap(reference)->getBlock() == unwrap(block) &&
876 "expected reference operation to belong to the block");
877 cppBlock->getOperations().insertAfter(where: Block::iterator(unwrap(c: reference)),
878 New: unwrap(c: operation));
879}
880
881void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
882 MlirOperation reference,
883 MlirOperation operation) {
884 if (mlirOperationIsNull(op: reference))
885 return mlirBlockAppendOwnedOperation(block, operation);
886
887 assert(unwrap(reference)->getBlock() == unwrap(block) &&
888 "expected reference operation to belong to the block");
889 unwrap(c: block)->getOperations().insert(where: Block::iterator(unwrap(c: reference)),
890 New: unwrap(c: operation));
891}
892
893void mlirBlockDestroy(MlirBlock block) { delete unwrap(c: block); }
894
895void mlirBlockDetach(MlirBlock block) {
896 Block *b = unwrap(c: block);
897 b->getParent()->getBlocks().remove(IT: b);
898}
899
900intptr_t mlirBlockGetNumArguments(MlirBlock block) {
901 return static_cast<intptr_t>(unwrap(c: block)->getNumArguments());
902}
903
904MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
905 MlirLocation loc) {
906 return wrap(cpp: unwrap(c: block)->addArgument(type: unwrap(c: type), loc: unwrap(c: loc)));
907}
908
909MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type,
910 MlirLocation loc) {
911 return wrap(cpp: unwrap(c: block)->insertArgument(index: pos, type: unwrap(c: type), loc: unwrap(c: loc)));
912}
913
914MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
915 return wrap(cpp: unwrap(c: block)->getArgument(i: static_cast<unsigned>(pos)));
916}
917
918void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
919 void *userData) {
920 detail::CallbackOstream stream(callback, userData);
921 unwrap(c: block)->print(os&: stream);
922}
923
924//===----------------------------------------------------------------------===//
925// Value API.
926//===----------------------------------------------------------------------===//
927
928bool mlirValueEqual(MlirValue value1, MlirValue value2) {
929 return unwrap(c: value1) == unwrap(c: value2);
930}
931
932bool mlirValueIsABlockArgument(MlirValue value) {
933 return llvm::isa<BlockArgument>(Val: unwrap(c: value));
934}
935
936bool mlirValueIsAOpResult(MlirValue value) {
937 return llvm::isa<OpResult>(Val: unwrap(c: value));
938}
939
940MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
941 return wrap(cpp: llvm::cast<BlockArgument>(Val: unwrap(c: value)).getOwner());
942}
943
944intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
945 return static_cast<intptr_t>(
946 llvm::cast<BlockArgument>(Val: unwrap(c: value)).getArgNumber());
947}
948
949void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
950 llvm::cast<BlockArgument>(Val: unwrap(c: value)).setType(unwrap(c: type));
951}
952
953MlirOperation mlirOpResultGetOwner(MlirValue value) {
954 return wrap(cpp: llvm::cast<OpResult>(Val: unwrap(c: value)).getOwner());
955}
956
957intptr_t mlirOpResultGetResultNumber(MlirValue value) {
958 return static_cast<intptr_t>(
959 llvm::cast<OpResult>(Val: unwrap(c: value)).getResultNumber());
960}
961
962MlirType mlirValueGetType(MlirValue value) {
963 return wrap(cpp: unwrap(c: value).getType());
964}
965
966void mlirValueSetType(MlirValue value, MlirType type) {
967 unwrap(c: value).setType(unwrap(c: type));
968}
969
970void mlirValueDump(MlirValue value) { unwrap(c: value).dump(); }
971
972void mlirValuePrint(MlirValue value, MlirStringCallback callback,
973 void *userData) {
974 detail::CallbackOstream stream(callback, userData);
975 unwrap(c: value).print(os&: stream);
976}
977
978void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
979 MlirStringCallback callback, void *userData) {
980 detail::CallbackOstream stream(callback, userData);
981 Value cppValue = unwrap(c: value);
982 cppValue.printAsOperand(os&: stream, state&: *unwrap(c: state));
983}
984
985MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
986 Value cppValue = unwrap(c: value);
987 if (cppValue.use_empty())
988 return {};
989
990 OpOperand *opOperand = cppValue.use_begin().getOperand();
991
992 return wrap(cpp: opOperand);
993}
994
995void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
996 unwrap(c: oldValue).replaceAllUsesWith(newValue: unwrap(c: newValue));
997}
998
999//===----------------------------------------------------------------------===//
1000// OpOperand API.
1001//===----------------------------------------------------------------------===//
1002
1003bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; }
1004
1005MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) {
1006 return wrap(cpp: unwrap(c: opOperand)->getOwner());
1007}
1008
1009MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) {
1010 return wrap(cpp: unwrap(c: opOperand)->get());
1011}
1012
1013unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) {
1014 return unwrap(c: opOperand)->getOperandNumber();
1015}
1016
1017MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) {
1018 if (mlirOpOperandIsNull(opOperand))
1019 return {};
1020
1021 OpOperand *nextOpOperand = static_cast<OpOperand *>(
1022 unwrap(c: opOperand)->getNextOperandUsingThisValue());
1023
1024 if (!nextOpOperand)
1025 return {};
1026
1027 return wrap(cpp: nextOpOperand);
1028}
1029
1030//===----------------------------------------------------------------------===//
1031// Type API.
1032//===----------------------------------------------------------------------===//
1033
1034MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
1035 return wrap(cpp: mlir::parseType(typeStr: unwrap(ref: type), context: unwrap(c: context)));
1036}
1037
1038MlirContext mlirTypeGetContext(MlirType type) {
1039 return wrap(cpp: unwrap(c: type).getContext());
1040}
1041
1042MlirTypeID mlirTypeGetTypeID(MlirType type) {
1043 return wrap(cpp: unwrap(c: type).getTypeID());
1044}
1045
1046MlirDialect mlirTypeGetDialect(MlirType type) {
1047 return wrap(cpp: &unwrap(c: type).getDialect());
1048}
1049
1050bool mlirTypeEqual(MlirType t1, MlirType t2) {
1051 return unwrap(c: t1) == unwrap(c: t2);
1052}
1053
1054void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
1055 detail::CallbackOstream stream(callback, userData);
1056 unwrap(c: type).print(os&: stream);
1057}
1058
1059void mlirTypeDump(MlirType type) { unwrap(c: type).dump(); }
1060
1061//===----------------------------------------------------------------------===//
1062// Attribute API.
1063//===----------------------------------------------------------------------===//
1064
1065MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
1066 return wrap(cpp: mlir::parseAttribute(attrStr: unwrap(ref: attr), context: unwrap(c: context)));
1067}
1068
1069MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
1070 return wrap(cpp: unwrap(c: attribute).getContext());
1071}
1072
1073MlirType mlirAttributeGetType(MlirAttribute attribute) {
1074 Attribute attr = unwrap(c: attribute);
1075 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
1076 return wrap(typedAttr.getType());
1077 return wrap(NoneType::get(attr.getContext()));
1078}
1079
1080MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
1081 return wrap(cpp: unwrap(c: attr).getTypeID());
1082}
1083
1084MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
1085 return wrap(cpp: &unwrap(c: attr).getDialect());
1086}
1087
1088bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
1089 return unwrap(c: a1) == unwrap(c: a2);
1090}
1091
1092void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
1093 void *userData) {
1094 detail::CallbackOstream stream(callback, userData);
1095 unwrap(c: attr).print(os&: stream);
1096}
1097
1098void mlirAttributeDump(MlirAttribute attr) { unwrap(c: attr).dump(); }
1099
1100MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
1101 MlirAttribute attr) {
1102 return MlirNamedAttribute{.name: name, .attribute: attr};
1103}
1104
1105//===----------------------------------------------------------------------===//
1106// Identifier API.
1107//===----------------------------------------------------------------------===//
1108
1109MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
1110 return wrap(StringAttr::get(unwrap(context), unwrap(str)));
1111}
1112
1113MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
1114 return wrap(unwrap(ident).getContext());
1115}
1116
1117bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
1118 return unwrap(ident) == unwrap(other);
1119}
1120
1121MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
1122 return wrap(unwrap(ident).strref());
1123}
1124
1125//===----------------------------------------------------------------------===//
1126// Symbol and SymbolTable API.
1127//===----------------------------------------------------------------------===//
1128
1129MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
1130 return wrap(ref: SymbolTable::getSymbolAttrName());
1131}
1132
1133MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
1134 return wrap(ref: SymbolTable::getVisibilityAttrName());
1135}
1136
1137MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
1138 if (!unwrap(c: operation)->hasTrait<OpTrait::SymbolTable>())
1139 return wrap(cpp: static_cast<SymbolTable *>(nullptr));
1140 return wrap(cpp: new SymbolTable(unwrap(c: operation)));
1141}
1142
1143void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
1144 delete unwrap(c: symbolTable);
1145}
1146
1147MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
1148 MlirStringRef name) {
1149 return wrap(cpp: unwrap(c: symbolTable)->lookup(name: StringRef(name.data, name.length)));
1150}
1151
1152MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
1153 MlirOperation operation) {
1154 return wrap(cpp: (Attribute)unwrap(c: symbolTable)->insert(unwrap(c: operation)));
1155}
1156
1157void mlirSymbolTableErase(MlirSymbolTable symbolTable,
1158 MlirOperation operation) {
1159 unwrap(c: symbolTable)->erase(symbol: unwrap(c: operation));
1160}
1161
1162MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
1163 MlirStringRef newSymbol,
1164 MlirOperation from) {
1165 auto *cppFrom = unwrap(c: from);
1166 auto *context = cppFrom->getContext();
1167 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
1168 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
1169 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
1170 unwrap(c: from)));
1171}
1172
1173void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
1174 void (*callback)(MlirOperation, bool,
1175 void *userData),
1176 void *userData) {
1177 SymbolTable::walkSymbolTables(op: unwrap(c: from), allSymUsesVisible,
1178 callback: [&](Operation *foundOpCpp, bool isVisible) {
1179 callback(wrap(cpp: foundOpCpp), isVisible,
1180 userData);
1181 });
1182}
1183

source code of mlir/lib/CAPI/IR/IR.cpp