//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ThreadPool.h"

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <optional>

using namespace mlir;
40 | |
41 | //===----------------------------------------------------------------------===// |
42 | // Context API. |
43 | //===----------------------------------------------------------------------===// |
44 | |
45 | MlirContext mlirContextCreate() { |
46 | auto *context = new MLIRContext; |
47 | return wrap(cpp: context); |
48 | } |
49 | |
50 | static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { |
51 | return threadingEnabled ? MLIRContext::Threading::ENABLED |
52 | : MLIRContext::Threading::DISABLED; |
53 | } |
54 | |
55 | MlirContext mlirContextCreateWithThreading(bool threadingEnabled) { |
56 | auto *context = new MLIRContext(toThreadingEnum(threadingEnabled)); |
57 | return wrap(cpp: context); |
58 | } |
59 | |
60 | MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry, |
61 | bool threadingEnabled) { |
62 | auto *context = |
63 | new MLIRContext(*unwrap(c: registry), toThreadingEnum(threadingEnabled)); |
64 | return wrap(cpp: context); |
65 | } |
66 | |
67 | bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { |
68 | return unwrap(c: ctx1) == unwrap(c: ctx2); |
69 | } |
70 | |
71 | void mlirContextDestroy(MlirContext context) { delete unwrap(c: context); } |
72 | |
73 | void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { |
74 | unwrap(c: context)->allowUnregisteredDialects(allow); |
75 | } |
76 | |
77 | bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { |
78 | return unwrap(c: context)->allowsUnregisteredDialects(); |
79 | } |
80 | intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { |
81 | return static_cast<intptr_t>(unwrap(c: context)->getAvailableDialects().size()); |
82 | } |
83 | |
84 | void mlirContextAppendDialectRegistry(MlirContext ctx, |
85 | MlirDialectRegistry registry) { |
86 | unwrap(c: ctx)->appendDialectRegistry(registry: *unwrap(c: registry)); |
87 | } |
88 | |
89 | // TODO: expose a cheaper way than constructing + sorting a vector only to take |
90 | // its size. |
91 | intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { |
92 | return static_cast<intptr_t>(unwrap(c: context)->getLoadedDialects().size()); |
93 | } |
94 | |
95 | MlirDialect mlirContextGetOrLoadDialect(MlirContext context, |
96 | MlirStringRef name) { |
97 | return wrap(cpp: unwrap(c: context)->getOrLoadDialect(name: unwrap(ref: name))); |
98 | } |
99 | |
100 | bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { |
101 | return unwrap(c: context)->isOperationRegistered(name: unwrap(ref: name)); |
102 | } |
103 | |
104 | void mlirContextEnableMultithreading(MlirContext context, bool enable) { |
105 | return unwrap(c: context)->enableMultithreading(enable); |
106 | } |
107 | |
108 | void mlirContextLoadAllAvailableDialects(MlirContext context) { |
109 | unwrap(c: context)->loadAllAvailableDialects(); |
110 | } |
111 | |
112 | void mlirContextSetThreadPool(MlirContext context, |
113 | MlirLlvmThreadPool threadPool) { |
114 | unwrap(c: context)->setThreadPool(*unwrap(c: threadPool)); |
115 | } |
116 | |
117 | unsigned mlirContextGetNumThreads(MlirContext context) { |
118 | return unwrap(c: context)->getNumThreads(); |
119 | } |
120 | |
121 | MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) { |
122 | return wrap(cpp: &unwrap(c: context)->getThreadPool()); |
123 | } |
124 | |
125 | //===----------------------------------------------------------------------===// |
126 | // Dialect API. |
127 | //===----------------------------------------------------------------------===// |
128 | |
129 | MlirContext mlirDialectGetContext(MlirDialect dialect) { |
130 | return wrap(cpp: unwrap(c: dialect)->getContext()); |
131 | } |
132 | |
133 | bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { |
134 | return unwrap(c: dialect1) == unwrap(c: dialect2); |
135 | } |
136 | |
137 | MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { |
138 | return wrap(ref: unwrap(c: dialect)->getNamespace()); |
139 | } |
140 | |
141 | //===----------------------------------------------------------------------===// |
142 | // DialectRegistry API. |
143 | //===----------------------------------------------------------------------===// |
144 | |
145 | MlirDialectRegistry mlirDialectRegistryCreate() { |
146 | return wrap(cpp: new DialectRegistry()); |
147 | } |
148 | |
149 | void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { |
150 | delete unwrap(c: registry); |
151 | } |
152 | |
153 | //===----------------------------------------------------------------------===// |
154 | // AsmState API. |
155 | //===----------------------------------------------------------------------===// |
156 | |
157 | MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, |
158 | MlirOpPrintingFlags flags) { |
159 | return wrap(cpp: new AsmState(unwrap(c: op), *unwrap(c: flags))); |
160 | } |
161 | |
162 | static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
163 | do { |
164 | // If we are printing local scope, stop at the first operation that is |
165 | // isolated from above. |
166 | if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
167 | break; |
168 | |
169 | // Otherwise, traverse up to the next parent. |
170 | Operation *parentOp = op->getParentOp(); |
171 | if (!parentOp) |
172 | break; |
173 | op = parentOp; |
174 | } while (true); |
175 | return op; |
176 | } |
177 | |
178 | MlirAsmState mlirAsmStateCreateForValue(MlirValue value, |
179 | MlirOpPrintingFlags flags) { |
180 | Operation *op; |
181 | mlir::Value val = unwrap(c: value); |
182 | if (auto result = llvm::dyn_cast<OpResult>(Val&: val)) { |
183 | op = result.getOwner(); |
184 | } else { |
185 | op = llvm::cast<BlockArgument>(Val&: val).getOwner()->getParentOp(); |
186 | if (!op) { |
187 | emitError(loc: val.getLoc()) << "<<UNKNOWN SSA VALUE>>"; |
188 | return {.ptr: nullptr}; |
189 | } |
190 | } |
191 | op = findParent(op, shouldUseLocalScope: unwrap(c: flags)->shouldUseLocalScope()); |
192 | return wrap(cpp: new AsmState(op, *unwrap(c: flags))); |
193 | } |
194 | |
195 | /// Destroys printing flags created with mlirAsmStateCreate. |
196 | void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(c: state); } |
197 | |
198 | //===----------------------------------------------------------------------===// |
199 | // Printing flags API. |
200 | //===----------------------------------------------------------------------===// |
201 | |
202 | MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { |
203 | return wrap(cpp: new OpPrintingFlags()); |
204 | } |
205 | |
206 | void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { |
207 | delete unwrap(c: flags); |
208 | } |
209 | |
210 | void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, |
211 | intptr_t largeElementLimit) { |
212 | unwrap(c: flags)->elideLargeElementsAttrs(largeElementLimit); |
213 | } |
214 | |
215 | void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, |
216 | intptr_t largeResourceLimit) { |
217 | unwrap(c: flags)->elideLargeResourceString(largeResourceLimit); |
218 | } |
219 | |
220 | void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, |
221 | bool prettyForm) { |
222 | unwrap(c: flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); |
223 | } |
224 | |
225 | void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { |
226 | unwrap(c: flags)->printGenericOpForm(); |
227 | } |
228 | |
229 | void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) { |
230 | unwrap(c: flags)->printNameLocAsPrefix(); |
231 | } |
232 | |
233 | void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { |
234 | unwrap(c: flags)->useLocalScope(); |
235 | } |
236 | |
237 | void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { |
238 | unwrap(c: flags)->assumeVerified(); |
239 | } |
240 | |
241 | void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) { |
242 | unwrap(c: flags)->skipRegions(); |
243 | } |
244 | //===----------------------------------------------------------------------===// |
245 | // Bytecode printing flags API. |
246 | //===----------------------------------------------------------------------===// |
247 | |
248 | MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() { |
249 | return wrap(cpp: new BytecodeWriterConfig()); |
250 | } |
251 | |
252 | void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) { |
253 | delete unwrap(c: config); |
254 | } |
255 | |
256 | void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, |
257 | int64_t version) { |
258 | unwrap(c: flags)->setDesiredBytecodeVersion(version); |
259 | } |
260 | |
261 | //===----------------------------------------------------------------------===// |
262 | // Location API. |
263 | //===----------------------------------------------------------------------===// |
264 | |
265 | MlirAttribute mlirLocationGetAttribute(MlirLocation location) { |
266 | return wrap(cpp: LocationAttr(unwrap(c: location))); |
267 | } |
268 | |
269 | MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { |
270 | return wrap(cpp: Location(llvm::dyn_cast<LocationAttr>(Val: unwrap(c: attribute)))); |
271 | } |
272 | |
273 | MlirLocation mlirLocationFileLineColGet(MlirContext context, |
274 | MlirStringRef filename, unsigned line, |
275 | unsigned col) { |
276 | return wrap(cpp: Location( |
277 | FileLineColLoc::get(context: unwrap(c: context), fileName: unwrap(ref: filename), line, column: col))); |
278 | } |
279 | |
280 | MlirLocation |
281 | mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, |
282 | unsigned startLine, unsigned startCol, |
283 | unsigned endLine, unsigned endCol) { |
284 | return wrap( |
285 | Location(FileLineColRange::get(unwrap(context), unwrap(filename), |
286 | startLine, startCol, endLine, endCol))); |
287 | } |
288 | |
289 | MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) { |
290 | return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(c: location)).getFilename()); |
291 | } |
292 | |
293 | int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) { |
294 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
295 | return loc.getStartLine(); |
296 | return -1; |
297 | } |
298 | |
299 | int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) { |
300 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
301 | return loc.getStartColumn(); |
302 | return -1; |
303 | } |
304 | |
305 | int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) { |
306 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
307 | return loc.getEndLine(); |
308 | return -1; |
309 | } |
310 | |
311 | int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) { |
312 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
313 | return loc.getEndColumn(); |
314 | return -1; |
315 | } |
316 | |
317 | MlirTypeID mlirLocationFileLineColRangeGetTypeID() { |
318 | return wrap(FileLineColRange::getTypeID()); |
319 | } |
320 | |
321 | bool mlirLocationIsAFileLineColRange(MlirLocation location) { |
322 | return isa<FileLineColRange>(Val: unwrap(c: location)); |
323 | } |
324 | |
325 | MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { |
326 | return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); |
327 | } |
328 | |
329 | MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) { |
330 | return wrap( |
331 | Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCallee())); |
332 | } |
333 | |
334 | MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) { |
335 | return wrap( |
336 | Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCaller())); |
337 | } |
338 | |
339 | MlirTypeID mlirLocationCallSiteGetTypeID() { |
340 | return wrap(CallSiteLoc::getTypeID()); |
341 | } |
342 | |
343 | bool mlirLocationIsACallSite(MlirLocation location) { |
344 | return isa<CallSiteLoc>(unwrap(location)); |
345 | } |
346 | |
347 | MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, |
348 | MlirLocation const *locations, |
349 | MlirAttribute metadata) { |
350 | SmallVector<Location, 4> locs; |
351 | ArrayRef<Location> unwrappedLocs = unwrapList(size: nLocations, first: locations, storage&: locs); |
352 | return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); |
353 | } |
354 | |
355 | unsigned mlirLocationFusedGetNumLocations(MlirLocation location) { |
356 | if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) |
357 | return locationsArrRef.getLocations().size(); |
358 | return 0; |
359 | } |
360 | |
361 | void mlirLocationFusedGetLocations(MlirLocation location, |
362 | MlirLocation *locationsCPtr) { |
363 | if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) { |
364 | for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations())) |
365 | locationsCPtr[i] = wrap(location); |
366 | } |
367 | } |
368 | |
369 | MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) { |
370 | return wrap(llvm::dyn_cast<FusedLoc>(unwrap(location)).getMetadata()); |
371 | } |
372 | |
373 | MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); } |
374 | |
375 | bool mlirLocationIsAFused(MlirLocation location) { |
376 | return isa<FusedLoc>(unwrap(location)); |
377 | } |
378 | |
379 | MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, |
380 | MlirLocation childLoc) { |
381 | if (mlirLocationIsNull(childLoc)) |
382 | return wrap( |
383 | Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); |
384 | return wrap(Location(NameLoc::get( |
385 | StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); |
386 | } |
387 | |
388 | MlirIdentifier mlirLocationNameGetName(MlirLocation location) { |
389 | return wrap((llvm::dyn_cast<NameLoc>(unwrap(location)).getName())); |
390 | } |
391 | |
392 | MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) { |
393 | return wrap( |
394 | Location(llvm::dyn_cast<NameLoc>(unwrap(location)).getChildLoc())); |
395 | } |
396 | |
397 | MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); } |
398 | |
399 | bool mlirLocationIsAName(MlirLocation location) { |
400 | return isa<NameLoc>(unwrap(location)); |
401 | } |
402 | |
403 | MlirLocation mlirLocationUnknownGet(MlirContext context) { |
404 | return wrap(Location(UnknownLoc::get(unwrap(context)))); |
405 | } |
406 | |
407 | bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { |
408 | return unwrap(c: l1) == unwrap(c: l2); |
409 | } |
410 | |
411 | MlirContext mlirLocationGetContext(MlirLocation location) { |
412 | return wrap(cpp: unwrap(c: location).getContext()); |
413 | } |
414 | |
415 | void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, |
416 | void *userData) { |
417 | detail::CallbackOstream stream(callback, userData); |
418 | unwrap(c: location).print(os&: stream); |
419 | } |
420 | |
421 | //===----------------------------------------------------------------------===// |
422 | // Module API. |
423 | //===----------------------------------------------------------------------===// |
424 | |
425 | MlirModule mlirModuleCreateEmpty(MlirLocation location) { |
426 | return wrap(ModuleOp::create(unwrap(location))); |
427 | } |
428 | |
429 | MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { |
430 | OwningOpRef<ModuleOp> owning = |
431 | parseSourceString<ModuleOp>(unwrap(ref: module), unwrap(c: context)); |
432 | if (!owning) |
433 | return MlirModule{.ptr: nullptr}; |
434 | return MlirModule{owning.release().getOperation()}; |
435 | } |
436 | |
437 | MlirModule mlirModuleCreateParseFromFile(MlirContext context, |
438 | MlirStringRef fileName) { |
439 | OwningOpRef<ModuleOp> owning = |
440 | parseSourceFile<ModuleOp>(filename: unwrap(ref: fileName), config: unwrap(c: context)); |
441 | if (!owning) |
442 | return MlirModule{.ptr: nullptr}; |
443 | return MlirModule{owning.release().getOperation()}; |
444 | } |
445 | |
446 | MlirContext mlirModuleGetContext(MlirModule module) { |
447 | return wrap(unwrap(module).getContext()); |
448 | } |
449 | |
450 | MlirBlock mlirModuleGetBody(MlirModule module) { |
451 | return wrap(unwrap(module).getBody()); |
452 | } |
453 | |
454 | void mlirModuleDestroy(MlirModule module) { |
455 | // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is |
456 | // called. |
457 | OwningOpRef<ModuleOp>(unwrap(module)); |
458 | } |
459 | |
460 | MlirOperation mlirModuleGetOperation(MlirModule module) { |
461 | return wrap(unwrap(module).getOperation()); |
462 | } |
463 | |
464 | MlirModule mlirModuleFromOperation(MlirOperation op) { |
465 | return wrap(dyn_cast<ModuleOp>(unwrap(c: op))); |
466 | } |
467 | |
468 | //===----------------------------------------------------------------------===// |
469 | // Operation state API. |
470 | //===----------------------------------------------------------------------===// |
471 | |
472 | MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { |
473 | MlirOperationState state; |
474 | state.name = name; |
475 | state.location = loc; |
476 | state.nResults = 0; |
477 | state.results = nullptr; |
478 | state.nOperands = 0; |
479 | state.operands = nullptr; |
480 | state.nRegions = 0; |
481 | state.regions = nullptr; |
482 | state.nSuccessors = 0; |
483 | state.successors = nullptr; |
484 | state.nAttributes = 0; |
485 | state.attributes = nullptr; |
486 | state.enableResultTypeInference = false; |
487 | return state; |
488 | } |
489 | |
490 | #define APPEND_ELEMS(type, sizeName, elemName) \ |
491 | state->elemName = \ |
492 | (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ |
493 | memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ |
494 | state->sizeName += n; |
495 | |
496 | void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, |
497 | MlirType const *results) { |
498 | APPEND_ELEMS(MlirType, nResults, results); |
499 | } |
500 | |
501 | void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, |
502 | MlirValue const *operands) { |
503 | APPEND_ELEMS(MlirValue, nOperands, operands); |
504 | } |
505 | void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, |
506 | MlirRegion const *regions) { |
507 | APPEND_ELEMS(MlirRegion, nRegions, regions); |
508 | } |
509 | void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, |
510 | MlirBlock const *successors) { |
511 | APPEND_ELEMS(MlirBlock, nSuccessors, successors); |
512 | } |
513 | void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, |
514 | MlirNamedAttribute const *attributes) { |
515 | APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); |
516 | } |
517 | |
518 | void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { |
519 | state->enableResultTypeInference = true; |
520 | } |
521 | |
522 | //===----------------------------------------------------------------------===// |
523 | // Operation API. |
524 | //===----------------------------------------------------------------------===// |
525 | |
526 | static LogicalResult inferOperationTypes(OperationState &state) { |
527 | MLIRContext *context = state.getContext(); |
528 | std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); |
529 | if (!info) { |
530 | emitError(loc: state.location) |
531 | << "type inference was requested for the operation "<< state.name |
532 | << ", but the operation was not registered; ensure that the dialect " |
533 | "containing the operation is linked into MLIR and registered with " |
534 | "the context"; |
535 | return failure(); |
536 | } |
537 | |
538 | auto *inferInterface = info->getInterface<InferTypeOpInterface>(); |
539 | if (!inferInterface) { |
540 | emitError(loc: state.location) |
541 | << "type inference was requested for the operation "<< state.name |
542 | << ", but the operation does not support type inference; result " |
543 | "types must be specified explicitly"; |
544 | return failure(); |
545 | } |
546 | |
547 | DictionaryAttr attributes = state.attributes.getDictionary(context); |
548 | OpaqueProperties properties = state.getRawProperties(); |
549 | |
550 | if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { |
551 | auto prop = std::make_unique<char[]>(num: info->getOpPropertyByteSize()); |
552 | properties = OpaqueProperties(prop.get()); |
553 | if (properties) { |
554 | auto emitError = [&]() { |
555 | return mlir::emitError(state.location) |
556 | << " failed properties conversion while building " |
557 | << state.name.getStringRef() << " with `"<< attributes << "`: "; |
558 | }; |
559 | if (failed(info->setOpPropertiesFromAttribute(opName: state.name, properties, |
560 | attr: attributes, emitError))) |
561 | return failure(); |
562 | } |
563 | if (succeeded(inferInterface->inferReturnTypes( |
564 | context, state.location, state.operands, attributes, properties, |
565 | state.regions, state.types))) { |
566 | return success(); |
567 | } |
568 | // Diagnostic emitted by interface. |
569 | return failure(); |
570 | } |
571 | |
572 | if (succeeded(inferInterface->inferReturnTypes( |
573 | context, state.location, state.operands, attributes, properties, |
574 | state.regions, state.types))) |
575 | return success(); |
576 | |
577 | // Diagnostic emitted by interface. |
578 | return failure(); |
579 | } |
580 | |
581 | MlirOperation mlirOperationCreate(MlirOperationState *state) { |
582 | assert(state); |
583 | OperationState cppState(unwrap(c: state->location), unwrap(ref: state->name)); |
584 | SmallVector<Type, 4> resultStorage; |
585 | SmallVector<Value, 8> operandStorage; |
586 | SmallVector<Block *, 2> successorStorage; |
587 | cppState.addTypes(newTypes: unwrapList(size: state->nResults, first: state->results, storage&: resultStorage)); |
588 | cppState.addOperands( |
589 | newOperands: unwrapList(size: state->nOperands, first: state->operands, storage&: operandStorage)); |
590 | cppState.addSuccessors( |
591 | newSuccessors: unwrapList(size: state->nSuccessors, first: state->successors, storage&: successorStorage)); |
592 | |
593 | cppState.attributes.reserve(N: state->nAttributes); |
594 | for (intptr_t i = 0; i < state->nAttributes; ++i) |
595 | cppState.addAttribute(unwrap(state->attributes[i].name), |
596 | unwrap(c: state->attributes[i].attribute)); |
597 | |
598 | for (intptr_t i = 0; i < state->nRegions; ++i) |
599 | cppState.addRegion(region: std::unique_ptr<Region>(unwrap(c: state->regions[i]))); |
600 | |
601 | free(ptr: state->results); |
602 | free(ptr: state->operands); |
603 | free(ptr: state->successors); |
604 | free(ptr: state->regions); |
605 | free(ptr: state->attributes); |
606 | |
607 | // Infer result types. |
608 | if (state->enableResultTypeInference) { |
609 | assert(cppState.types.empty() && |
610 | "result type inference enabled and result types provided"); |
611 | if (failed(Result: inferOperationTypes(state&: cppState))) |
612 | return {.ptr: nullptr}; |
613 | } |
614 | |
615 | return wrap(cpp: Operation::create(state: cppState)); |
616 | } |
617 | |
618 | MlirOperation mlirOperationCreateParse(MlirContext context, |
619 | MlirStringRef sourceStr, |
620 | MlirStringRef sourceName) { |
621 | |
622 | return wrap( |
623 | cpp: parseSourceString(sourceStr: unwrap(ref: sourceStr), config: unwrap(c: context), sourceName: unwrap(ref: sourceName)) |
624 | .release()); |
625 | } |
626 | |
627 | MlirOperation mlirOperationClone(MlirOperation op) { |
628 | return wrap(cpp: unwrap(c: op)->clone()); |
629 | } |
630 | |
631 | void mlirOperationDestroy(MlirOperation op) { unwrap(c: op)->erase(); } |
632 | |
633 | void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(c: op)->remove(); } |
634 | |
635 | bool mlirOperationEqual(MlirOperation op, MlirOperation other) { |
636 | return unwrap(c: op) == unwrap(c: other); |
637 | } |
638 | |
639 | MlirContext mlirOperationGetContext(MlirOperation op) { |
640 | return wrap(cpp: unwrap(c: op)->getContext()); |
641 | } |
642 | |
643 | MlirLocation mlirOperationGetLocation(MlirOperation op) { |
644 | return wrap(cpp: unwrap(c: op)->getLoc()); |
645 | } |
646 | |
647 | MlirTypeID mlirOperationGetTypeID(MlirOperation op) { |
648 | if (auto info = unwrap(c: op)->getRegisteredInfo()) |
649 | return wrap(cpp: info->getTypeID()); |
650 | return {.ptr: nullptr}; |
651 | } |
652 | |
653 | MlirIdentifier mlirOperationGetName(MlirOperation op) { |
654 | return wrap(unwrap(c: op)->getName().getIdentifier()); |
655 | } |
656 | |
657 | MlirBlock mlirOperationGetBlock(MlirOperation op) { |
658 | return wrap(cpp: unwrap(c: op)->getBlock()); |
659 | } |
660 | |
661 | MlirOperation mlirOperationGetParentOperation(MlirOperation op) { |
662 | return wrap(cpp: unwrap(c: op)->getParentOp()); |
663 | } |
664 | |
665 | intptr_t mlirOperationGetNumRegions(MlirOperation op) { |
666 | return static_cast<intptr_t>(unwrap(c: op)->getNumRegions()); |
667 | } |
668 | |
669 | MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { |
670 | return wrap(cpp: &unwrap(c: op)->getRegion(index: static_cast<unsigned>(pos))); |
671 | } |
672 | |
673 | MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { |
674 | Operation *cppOp = unwrap(c: op); |
675 | if (cppOp->getNumRegions() == 0) |
676 | return wrap(cpp: static_cast<Region *>(nullptr)); |
677 | return wrap(cpp: &cppOp->getRegion(index: 0)); |
678 | } |
679 | |
680 | MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { |
681 | Region *cppRegion = unwrap(c: region); |
682 | Operation *parent = cppRegion->getParentOp(); |
683 | intptr_t next = cppRegion->getRegionNumber() + 1; |
684 | if (parent->getNumRegions() > next) |
685 | return wrap(cpp: &parent->getRegion(index: next)); |
686 | return wrap(cpp: static_cast<Region *>(nullptr)); |
687 | } |
688 | |
689 | MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { |
690 | return wrap(unwrap(c: op)->getNextNode()); |
691 | } |
692 | |
693 | intptr_t mlirOperationGetNumOperands(MlirOperation op) { |
694 | return static_cast<intptr_t>(unwrap(c: op)->getNumOperands()); |
695 | } |
696 | |
697 | MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { |
698 | return wrap(cpp: unwrap(c: op)->getOperand(idx: static_cast<unsigned>(pos))); |
699 | } |
700 | |
701 | void mlirOperationSetOperand(MlirOperation op, intptr_t pos, |
702 | MlirValue newValue) { |
703 | unwrap(c: op)->setOperand(idx: static_cast<unsigned>(pos), value: unwrap(c: newValue)); |
704 | } |
705 | |
706 | void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands, |
707 | MlirValue const *operands) { |
708 | SmallVector<Value> ops; |
709 | unwrap(c: op)->setOperands(unwrapList(size: nOperands, first: operands, storage&: ops)); |
710 | } |
711 | |
712 | intptr_t mlirOperationGetNumResults(MlirOperation op) { |
713 | return static_cast<intptr_t>(unwrap(c: op)->getNumResults()); |
714 | } |
715 | |
716 | MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { |
717 | return wrap(cpp: unwrap(c: op)->getResult(idx: static_cast<unsigned>(pos))); |
718 | } |
719 | |
720 | intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { |
721 | return static_cast<intptr_t>(unwrap(c: op)->getNumSuccessors()); |
722 | } |
723 | |
724 | MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { |
725 | return wrap(cpp: unwrap(c: op)->getSuccessor(index: static_cast<unsigned>(pos))); |
726 | } |
727 | |
728 | MLIR_CAPI_EXPORTED bool |
729 | mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { |
730 | std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name)); |
731 | return attr.has_value(); |
732 | } |
733 | |
734 | MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, |
735 | MlirStringRef name) { |
736 | std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name)); |
737 | if (attr.has_value()) |
738 | return wrap(cpp: *attr); |
739 | return {}; |
740 | } |
741 | |
742 | void mlirOperationSetInherentAttributeByName(MlirOperation op, |
743 | MlirStringRef name, |
744 | MlirAttribute attr) { |
745 | unwrap(op)->setInherentAttr( |
746 | StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); |
747 | } |
748 | |
749 | intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { |
750 | return static_cast<intptr_t>( |
751 | llvm::range_size(unwrap(c: op)->getDiscardableAttrs())); |
752 | } |
753 | |
754 | MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, |
755 | intptr_t pos) { |
756 | NamedAttribute attr = |
757 | *std::next(unwrap(c: op)->getDiscardableAttrs().begin(), pos); |
758 | return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())}; |
759 | } |
760 | |
761 | MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, |
762 | MlirStringRef name) { |
763 | return wrap(cpp: unwrap(c: op)->getDiscardableAttr(name: unwrap(ref: name))); |
764 | } |
765 | |
766 | void mlirOperationSetDiscardableAttributeByName(MlirOperation op, |
767 | MlirStringRef name, |
768 | MlirAttribute attr) { |
769 | unwrap(c: op)->setDiscardableAttr(name: unwrap(ref: name), value: unwrap(c: attr)); |
770 | } |
771 | |
772 | bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, |
773 | MlirStringRef name) { |
774 | return !!unwrap(c: op)->removeDiscardableAttr(name: unwrap(ref: name)); |
775 | } |
776 | |
777 | void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, |
778 | MlirBlock block) { |
779 | unwrap(c: op)->setSuccessor(block: unwrap(c: block), index: static_cast<unsigned>(pos)); |
780 | } |
781 | |
782 | intptr_t mlirOperationGetNumAttributes(MlirOperation op) { |
783 | return static_cast<intptr_t>(unwrap(c: op)->getAttrs().size()); |
784 | } |
785 | |
786 | MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { |
787 | NamedAttribute attr = unwrap(c: op)->getAttrs()[pos]; |
788 | return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())}; |
789 | } |
790 | |
791 | MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, |
792 | MlirStringRef name) { |
793 | return wrap(cpp: unwrap(c: op)->getAttr(name: unwrap(ref: name))); |
794 | } |
795 | |
796 | void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, |
797 | MlirAttribute attr) { |
798 | unwrap(c: op)->setAttr(name: unwrap(ref: name), value: unwrap(c: attr)); |
799 | } |
800 | |
801 | bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { |
802 | return !!unwrap(c: op)->removeAttr(name: unwrap(ref: name)); |
803 | } |
804 | |
805 | void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, |
806 | void *userData) { |
807 | detail::CallbackOstream stream(callback, userData); |
808 | unwrap(c: op)->print(os&: stream); |
809 | } |
810 | |
811 | void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, |
812 | MlirStringCallback callback, void *userData) { |
813 | detail::CallbackOstream stream(callback, userData); |
814 | unwrap(c: op)->print(os&: stream, flags: *unwrap(c: flags)); |
815 | } |
816 | |
817 | void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, |
818 | MlirStringCallback callback, void *userData) { |
819 | detail::CallbackOstream stream(callback, userData); |
820 | if (state.ptr) |
821 | unwrap(c: op)->print(os&: stream, state&: *unwrap(c: state)); |
822 | unwrap(c: op)->print(os&: stream); |
823 | } |
824 | |
825 | void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, |
826 | void *userData) { |
827 | detail::CallbackOstream stream(callback, userData); |
828 | // As no desired version is set, no failure can occur. |
829 | (void)writeBytecodeToFile(op: unwrap(c: op), os&: stream); |
830 | } |
831 | |
832 | MlirLogicalResult mlirOperationWriteBytecodeWithConfig( |
833 | MlirOperation op, MlirBytecodeWriterConfig config, |
834 | MlirStringCallback callback, void *userData) { |
835 | detail::CallbackOstream stream(callback, userData); |
836 | return wrap(res: writeBytecodeToFile(op: unwrap(c: op), os&: stream, config: *unwrap(c: config))); |
837 | } |
838 | |
839 | void mlirOperationDump(MlirOperation op) { return unwrap(c: op)->dump(); } |
840 | |
841 | bool mlirOperationVerify(MlirOperation op) { |
842 | return succeeded(Result: verify(op: unwrap(c: op))); |
843 | } |
844 | |
845 | void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { |
846 | return unwrap(c: op)->moveAfter(existingOp: unwrap(c: other)); |
847 | } |
848 | |
849 | void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { |
850 | return unwrap(c: op)->moveBefore(existingOp: unwrap(c: other)); |
851 | } |
852 | |
853 | static mlir::WalkResult unwrap(MlirWalkResult result) { |
854 | switch (result) { |
855 | case MlirWalkResultAdvance: |
856 | return mlir::WalkResult::advance(); |
857 | |
858 | case MlirWalkResultInterrupt: |
859 | return mlir::WalkResult::interrupt(); |
860 | |
861 | case MlirWalkResultSkip: |
862 | return mlir::WalkResult::skip(); |
863 | } |
864 | llvm_unreachable("unknown result in WalkResult::unwrap"); |
865 | } |
866 | |
867 | void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, |
868 | void *userData, MlirWalkOrder walkOrder) { |
869 | switch (walkOrder) { |
870 | |
871 | case MlirWalkPreOrder: |
872 | unwrap(c: op)->walk<mlir::WalkOrder::PreOrder>( |
873 | callback: [callback, userData](Operation *op) { |
874 | return unwrap(result: callback(wrap(cpp: op), userData)); |
875 | }); |
876 | break; |
877 | case MlirWalkPostOrder: |
878 | unwrap(c: op)->walk<mlir::WalkOrder::PostOrder>( |
879 | callback: [callback, userData](Operation *op) { |
880 | return unwrap(result: callback(wrap(cpp: op), userData)); |
881 | }); |
882 | } |
883 | } |
884 | |
885 | //===----------------------------------------------------------------------===// |
886 | // Region API. |
887 | //===----------------------------------------------------------------------===// |
888 | |
889 | MlirRegion mlirRegionCreate() { return wrap(cpp: new Region); } |
890 | |
891 | bool mlirRegionEqual(MlirRegion region, MlirRegion other) { |
892 | return unwrap(c: region) == unwrap(c: other); |
893 | } |
894 | |
895 | MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { |
896 | Region *cppRegion = unwrap(c: region); |
897 | if (cppRegion->empty()) |
898 | return wrap(cpp: static_cast<Block *>(nullptr)); |
899 | return wrap(cpp: &cppRegion->front()); |
900 | } |
901 | |
902 | void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { |
903 | unwrap(c: region)->push_back(block: unwrap(c: block)); |
904 | } |
905 | |
906 | void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, |
907 | MlirBlock block) { |
908 | auto &blockList = unwrap(c: region)->getBlocks(); |
909 | blockList.insert(where: std::next(x: blockList.begin(), n: pos), New: unwrap(c: block)); |
910 | } |
911 | |
912 | void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, |
913 | MlirBlock block) { |
914 | Region *cppRegion = unwrap(c: region); |
915 | if (mlirBlockIsNull(block: reference)) { |
916 | cppRegion->getBlocks().insert(where: cppRegion->begin(), New: unwrap(c: block)); |
917 | return; |
918 | } |
919 | |
920 | assert(unwrap(reference)->getParent() == unwrap(region) && |
921 | "expected reference block to belong to the region"); |
922 | cppRegion->getBlocks().insertAfter(where: Region::iterator(unwrap(c: reference)), |
923 | New: unwrap(c: block)); |
924 | } |
925 | |
926 | void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, |
927 | MlirBlock block) { |
928 | if (mlirBlockIsNull(block: reference)) |
929 | return mlirRegionAppendOwnedBlock(region, block); |
930 | |
931 | assert(unwrap(reference)->getParent() == unwrap(region) && |
932 | "expected reference block to belong to the region"); |
933 | unwrap(c: region)->getBlocks().insert(where: Region::iterator(unwrap(c: reference)), |
934 | New: unwrap(c: block)); |
935 | } |
936 | |
937 | void mlirRegionDestroy(MlirRegion region) { |
938 | delete static_cast<Region *>(region.ptr); |
939 | } |
940 | |
941 | void mlirRegionTakeBody(MlirRegion target, MlirRegion source) { |
942 | unwrap(c: target)->takeBody(other&: *unwrap(c: source)); |
943 | } |
944 | |
945 | //===----------------------------------------------------------------------===// |
946 | // Block API. |
947 | //===----------------------------------------------------------------------===// |
948 | |
949 | MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, |
950 | MlirLocation const *locs) { |
951 | Block *b = new Block; |
952 | for (intptr_t i = 0; i < nArgs; ++i) |
953 | b->addArgument(type: unwrap(c: args[i]), loc: unwrap(c: locs[i])); |
954 | return wrap(cpp: b); |
955 | } |
956 | |
957 | bool mlirBlockEqual(MlirBlock block, MlirBlock other) { |
958 | return unwrap(c: block) == unwrap(c: other); |
959 | } |
960 | |
961 | MlirOperation mlirBlockGetParentOperation(MlirBlock block) { |
962 | return wrap(cpp: unwrap(c: block)->getParentOp()); |
963 | } |
964 | |
965 | MlirRegion mlirBlockGetParentRegion(MlirBlock block) { |
966 | return wrap(cpp: unwrap(c: block)->getParent()); |
967 | } |
968 | |
969 | MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { |
970 | return wrap(cpp: unwrap(c: block)->getNextNode()); |
971 | } |
972 | |
973 | MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { |
974 | Block *cppBlock = unwrap(c: block); |
975 | if (cppBlock->empty()) |
976 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
977 | return wrap(cpp: &cppBlock->front()); |
978 | } |
979 | |
980 | MlirOperation mlirBlockGetTerminator(MlirBlock block) { |
981 | Block *cppBlock = unwrap(c: block); |
982 | if (cppBlock->empty()) |
983 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
984 | Operation &back = cppBlock->back(); |
985 | if (!back.hasTrait<OpTrait::IsTerminator>()) |
986 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
987 | return wrap(cpp: &back); |
988 | } |
989 | |
990 | void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { |
991 | unwrap(c: block)->push_back(op: unwrap(c: operation)); |
992 | } |
993 | |
994 | void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, |
995 | MlirOperation operation) { |
996 | auto &opList = unwrap(c: block)->getOperations(); |
997 | opList.insert(where: std::next(x: opList.begin(), n: pos), New: unwrap(c: operation)); |
998 | } |
999 | |
1000 | void mlirBlockInsertOwnedOperationAfter(MlirBlock block, |
1001 | MlirOperation reference, |
1002 | MlirOperation operation) { |
1003 | Block *cppBlock = unwrap(c: block); |
1004 | if (mlirOperationIsNull(op: reference)) { |
1005 | cppBlock->getOperations().insert(where: cppBlock->begin(), New: unwrap(c: operation)); |
1006 | return; |
1007 | } |
1008 | |
1009 | assert(unwrap(reference)->getBlock() == unwrap(block) && |
1010 | "expected reference operation to belong to the block"); |
1011 | cppBlock->getOperations().insertAfter(where: Block::iterator(unwrap(c: reference)), |
1012 | New: unwrap(c: operation)); |
1013 | } |
1014 | |
1015 | void mlirBlockInsertOwnedOperationBefore(MlirBlock block, |
1016 | MlirOperation reference, |
1017 | MlirOperation operation) { |
1018 | if (mlirOperationIsNull(op: reference)) |
1019 | return mlirBlockAppendOwnedOperation(block, operation); |
1020 | |
1021 | assert(unwrap(reference)->getBlock() == unwrap(block) && |
1022 | "expected reference operation to belong to the block"); |
1023 | unwrap(c: block)->getOperations().insert(where: Block::iterator(unwrap(c: reference)), |
1024 | New: unwrap(c: operation)); |
1025 | } |
1026 | |
1027 | void mlirBlockDestroy(MlirBlock block) { delete unwrap(c: block); } |
1028 | |
1029 | void mlirBlockDetach(MlirBlock block) { |
1030 | Block *b = unwrap(c: block); |
1031 | b->getParent()->getBlocks().remove(IT: b); |
1032 | } |
1033 | |
1034 | intptr_t mlirBlockGetNumArguments(MlirBlock block) { |
1035 | return static_cast<intptr_t>(unwrap(c: block)->getNumArguments()); |
1036 | } |
1037 | |
1038 | MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, |
1039 | MlirLocation loc) { |
1040 | return wrap(cpp: unwrap(c: block)->addArgument(type: unwrap(c: type), loc: unwrap(c: loc))); |
1041 | } |
1042 | |
1043 | void mlirBlockEraseArgument(MlirBlock block, unsigned index) { |
1044 | return unwrap(c: block)->eraseArgument(index); |
1045 | } |
1046 | |
1047 | MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, |
1048 | MlirLocation loc) { |
1049 | return wrap(cpp: unwrap(c: block)->insertArgument(index: pos, type: unwrap(c: type), loc: unwrap(c: loc))); |
1050 | } |
1051 | |
1052 | MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { |
1053 | return wrap(cpp: unwrap(c: block)->getArgument(i: static_cast<unsigned>(pos))); |
1054 | } |
1055 | |
1056 | void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, |
1057 | void *userData) { |
1058 | detail::CallbackOstream stream(callback, userData); |
1059 | unwrap(c: block)->print(os&: stream); |
1060 | } |
1061 | |
1062 | //===----------------------------------------------------------------------===// |
1063 | // Value API. |
1064 | //===----------------------------------------------------------------------===// |
1065 | |
1066 | bool mlirValueEqual(MlirValue value1, MlirValue value2) { |
1067 | return unwrap(c: value1) == unwrap(c: value2); |
1068 | } |
1069 | |
1070 | bool mlirValueIsABlockArgument(MlirValue value) { |
1071 | return llvm::isa<BlockArgument>(Val: unwrap(c: value)); |
1072 | } |
1073 | |
1074 | bool mlirValueIsAOpResult(MlirValue value) { |
1075 | return llvm::isa<OpResult>(Val: unwrap(c: value)); |
1076 | } |
1077 | |
1078 | MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { |
1079 | return wrap(cpp: llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getOwner()); |
1080 | } |
1081 | |
1082 | intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { |
1083 | return static_cast<intptr_t>( |
1084 | llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getArgNumber()); |
1085 | } |
1086 | |
1087 | void mlirBlockArgumentSetType(MlirValue value, MlirType type) { |
1088 | if (auto blockArg = llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value))) |
1089 | blockArg.setType(unwrap(c: type)); |
1090 | } |
1091 | |
1092 | MlirOperation mlirOpResultGetOwner(MlirValue value) { |
1093 | return wrap(cpp: llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getOwner()); |
1094 | } |
1095 | |
1096 | intptr_t mlirOpResultGetResultNumber(MlirValue value) { |
1097 | return static_cast<intptr_t>( |
1098 | llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getResultNumber()); |
1099 | } |
1100 | |
1101 | MlirType mlirValueGetType(MlirValue value) { |
1102 | return wrap(cpp: unwrap(c: value).getType()); |
1103 | } |
1104 | |
1105 | void mlirValueSetType(MlirValue value, MlirType type) { |
1106 | unwrap(c: value).setType(unwrap(c: type)); |
1107 | } |
1108 | |
1109 | void mlirValueDump(MlirValue value) { unwrap(c: value).dump(); } |
1110 | |
1111 | void mlirValuePrint(MlirValue value, MlirStringCallback callback, |
1112 | void *userData) { |
1113 | detail::CallbackOstream stream(callback, userData); |
1114 | unwrap(c: value).print(os&: stream); |
1115 | } |
1116 | |
1117 | void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, |
1118 | MlirStringCallback callback, void *userData) { |
1119 | detail::CallbackOstream stream(callback, userData); |
1120 | Value cppValue = unwrap(c: value); |
1121 | cppValue.printAsOperand(os&: stream, state&: *unwrap(c: state)); |
1122 | } |
1123 | |
1124 | MlirOpOperand mlirValueGetFirstUse(MlirValue value) { |
1125 | Value cppValue = unwrap(c: value); |
1126 | if (cppValue.use_empty()) |
1127 | return {}; |
1128 | |
1129 | OpOperand *opOperand = cppValue.use_begin().getOperand(); |
1130 | |
1131 | return wrap(cpp: opOperand); |
1132 | } |
1133 | |
1134 | void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { |
1135 | unwrap(c: oldValue).replaceAllUsesWith(newValue: unwrap(c: newValue)); |
1136 | } |
1137 | |
1138 | void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue, |
1139 | intptr_t numExceptions, |
1140 | MlirOperation *exceptions) { |
1141 | Value oldValueCpp = unwrap(c: oldValue); |
1142 | Value newValueCpp = unwrap(c: newValue); |
1143 | |
1144 | llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet; |
1145 | for (intptr_t i = 0; i < numExceptions; ++i) { |
1146 | exceptionSet.insert(Ptr: unwrap(c: exceptions[i])); |
1147 | } |
1148 | |
1149 | oldValueCpp.replaceAllUsesExcept(newValue: newValueCpp, exceptions: exceptionSet); |
1150 | } |
1151 | |
1152 | MlirLocation mlirValueGetLocation(MlirValue v) { |
1153 | return wrap(cpp: unwrap(c: v).getLoc()); |
1154 | } |
1155 | |
1156 | MlirContext mlirValueGetContext(MlirValue v) { |
1157 | return wrap(cpp: unwrap(c: v).getContext()); |
1158 | } |
1159 | |
1160 | //===----------------------------------------------------------------------===// |
1161 | // OpOperand API. |
1162 | //===----------------------------------------------------------------------===// |
1163 | |
1164 | bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; } |
1165 | |
1166 | MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { |
1167 | return wrap(cpp: unwrap(c: opOperand)->getOwner()); |
1168 | } |
1169 | |
1170 | MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) { |
1171 | return wrap(cpp: unwrap(c: opOperand)->get()); |
1172 | } |
1173 | |
1174 | unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { |
1175 | return unwrap(c: opOperand)->getOperandNumber(); |
1176 | } |
1177 | |
1178 | MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) { |
1179 | if (mlirOpOperandIsNull(opOperand)) |
1180 | return {}; |
1181 | |
1182 | OpOperand *nextOpOperand = static_cast<OpOperand *>( |
1183 | unwrap(c: opOperand)->getNextOperandUsingThisValue()); |
1184 | |
1185 | if (!nextOpOperand) |
1186 | return {}; |
1187 | |
1188 | return wrap(cpp: nextOpOperand); |
1189 | } |
1190 | |
1191 | //===----------------------------------------------------------------------===// |
1192 | // Type API. |
1193 | //===----------------------------------------------------------------------===// |
1194 | |
1195 | MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { |
1196 | return wrap(cpp: mlir::parseType(typeStr: unwrap(ref: type), context: unwrap(c: context))); |
1197 | } |
1198 | |
1199 | MlirContext mlirTypeGetContext(MlirType type) { |
1200 | return wrap(cpp: unwrap(c: type).getContext()); |
1201 | } |
1202 | |
1203 | MlirTypeID mlirTypeGetTypeID(MlirType type) { |
1204 | return wrap(cpp: unwrap(c: type).getTypeID()); |
1205 | } |
1206 | |
1207 | MlirDialect mlirTypeGetDialect(MlirType type) { |
1208 | return wrap(cpp: &unwrap(c: type).getDialect()); |
1209 | } |
1210 | |
1211 | bool mlirTypeEqual(MlirType t1, MlirType t2) { |
1212 | return unwrap(c: t1) == unwrap(c: t2); |
1213 | } |
1214 | |
1215 | void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { |
1216 | detail::CallbackOstream stream(callback, userData); |
1217 | unwrap(c: type).print(os&: stream); |
1218 | } |
1219 | |
1220 | void mlirTypeDump(MlirType type) { unwrap(c: type).dump(); } |
1221 | |
1222 | //===----------------------------------------------------------------------===// |
1223 | // Attribute API. |
1224 | //===----------------------------------------------------------------------===// |
1225 | |
1226 | MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { |
1227 | return wrap(cpp: mlir::parseAttribute(attrStr: unwrap(ref: attr), context: unwrap(c: context))); |
1228 | } |
1229 | |
1230 | MlirContext mlirAttributeGetContext(MlirAttribute attribute) { |
1231 | return wrap(cpp: unwrap(c: attribute).getContext()); |
1232 | } |
1233 | |
1234 | MlirType mlirAttributeGetType(MlirAttribute attribute) { |
1235 | Attribute attr = unwrap(c: attribute); |
1236 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) |
1237 | return wrap(typedAttr.getType()); |
1238 | return wrap(NoneType::get(attr.getContext())); |
1239 | } |
1240 | |
1241 | MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { |
1242 | return wrap(cpp: unwrap(c: attr).getTypeID()); |
1243 | } |
1244 | |
1245 | MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { |
1246 | return wrap(cpp: &unwrap(c: attr).getDialect()); |
1247 | } |
1248 | |
1249 | bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { |
1250 | return unwrap(c: a1) == unwrap(c: a2); |
1251 | } |
1252 | |
1253 | void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, |
1254 | void *userData) { |
1255 | detail::CallbackOstream stream(callback, userData); |
1256 | unwrap(c: attr).print(os&: stream); |
1257 | } |
1258 | |
1259 | void mlirAttributeDump(MlirAttribute attr) { unwrap(c: attr).dump(); } |
1260 | |
1261 | MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, |
1262 | MlirAttribute attr) { |
1263 | return MlirNamedAttribute{.name: name, .attribute: attr}; |
1264 | } |
1265 | |
1266 | //===----------------------------------------------------------------------===// |
1267 | // Identifier API. |
1268 | //===----------------------------------------------------------------------===// |
1269 | |
1270 | MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { |
1271 | return wrap(StringAttr::get(unwrap(context), unwrap(str))); |
1272 | } |
1273 | |
1274 | MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { |
1275 | return wrap(unwrap(ident).getContext()); |
1276 | } |
1277 | |
1278 | bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { |
1279 | return unwrap(ident) == unwrap(other); |
1280 | } |
1281 | |
1282 | MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { |
1283 | return wrap(unwrap(ident).strref()); |
1284 | } |
1285 | |
1286 | //===----------------------------------------------------------------------===// |
1287 | // Symbol and SymbolTable API. |
1288 | //===----------------------------------------------------------------------===// |
1289 | |
1290 | MlirStringRef mlirSymbolTableGetSymbolAttributeName() { |
1291 | return wrap(ref: SymbolTable::getSymbolAttrName()); |
1292 | } |
1293 | |
1294 | MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { |
1295 | return wrap(ref: SymbolTable::getVisibilityAttrName()); |
1296 | } |
1297 | |
1298 | MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { |
1299 | if (!unwrap(c: operation)->hasTrait<OpTrait::SymbolTable>()) |
1300 | return wrap(cpp: static_cast<SymbolTable *>(nullptr)); |
1301 | return wrap(cpp: new SymbolTable(unwrap(c: operation))); |
1302 | } |
1303 | |
1304 | void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { |
1305 | delete unwrap(c: symbolTable); |
1306 | } |
1307 | |
1308 | MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, |
1309 | MlirStringRef name) { |
1310 | return wrap(cpp: unwrap(c: symbolTable)->lookup(name: StringRef(name.data, name.length))); |
1311 | } |
1312 | |
1313 | MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, |
1314 | MlirOperation operation) { |
1315 | return wrap(cpp: (Attribute)unwrap(c: symbolTable)->insert(unwrap(c: operation))); |
1316 | } |
1317 | |
1318 | void mlirSymbolTableErase(MlirSymbolTable symbolTable, |
1319 | MlirOperation operation) { |
1320 | unwrap(c: symbolTable)->erase(symbol: unwrap(c: operation)); |
1321 | } |
1322 | |
1323 | MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, |
1324 | MlirStringRef newSymbol, |
1325 | MlirOperation from) { |
1326 | auto *cppFrom = unwrap(c: from); |
1327 | auto *context = cppFrom->getContext(); |
1328 | auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); |
1329 | auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); |
1330 | return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, |
1331 | unwrap(c: from))); |
1332 | } |
1333 | |
1334 | void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, |
1335 | void (*callback)(MlirOperation, bool, |
1336 | void *userData), |
1337 | void *userData) { |
1338 | SymbolTable::walkSymbolTables(op: unwrap(c: from), allSymUsesVisible, |
1339 | callback: [&](Operation *foundOpCpp, bool isVisible) { |
1340 | callback(wrap(cpp: foundOpCpp), isVisible, |
1341 | userData); |
1342 | }); |
1343 | } |
1344 |
Definitions
- mlirContextCreate
- toThreadingEnum
- mlirContextCreateWithThreading
- mlirContextCreateWithRegistry
- mlirContextEqual
- mlirContextDestroy
- mlirContextSetAllowUnregisteredDialects
- mlirContextGetAllowUnregisteredDialects
- mlirContextGetNumRegisteredDialects
- mlirContextAppendDialectRegistry
- mlirContextGetNumLoadedDialects
- mlirContextGetOrLoadDialect
- mlirContextIsRegisteredOperation
- mlirContextEnableMultithreading
- mlirContextLoadAllAvailableDialects
- mlirContextSetThreadPool
- mlirContextGetNumThreads
- mlirContextGetThreadPool
- mlirDialectGetContext
- mlirDialectEqual
- mlirDialectGetNamespace
- mlirDialectRegistryCreate
- mlirDialectRegistryDestroy
- mlirAsmStateCreateForOperation
- findParent
- mlirAsmStateCreateForValue
- mlirAsmStateDestroy
- mlirOpPrintingFlagsCreate
- mlirOpPrintingFlagsDestroy
- mlirOpPrintingFlagsElideLargeElementsAttrs
- mlirOpPrintingFlagsElideLargeResourceString
- mlirOpPrintingFlagsEnableDebugInfo
- mlirOpPrintingFlagsPrintGenericOpForm
- mlirOpPrintingFlagsPrintNameLocAsPrefix
- mlirOpPrintingFlagsUseLocalScope
- mlirOpPrintingFlagsAssumeVerified
- mlirOpPrintingFlagsSkipRegions
- mlirBytecodeWriterConfigCreate
- mlirBytecodeWriterConfigDestroy
- mlirBytecodeWriterConfigDesiredEmitVersion
- mlirLocationGetAttribute
- mlirLocationFromAttribute
- mlirLocationFileLineColGet
- mlirLocationFileLineColRangeGet
- mlirLocationFileLineColRangeGetFilename
- mlirLocationFileLineColRangeGetStartLine
- mlirLocationFileLineColRangeGetStartColumn
- mlirLocationFileLineColRangeGetEndLine
- mlirLocationFileLineColRangeGetEndColumn
- mlirLocationFileLineColRangeGetTypeID
- mlirLocationIsAFileLineColRange
- mlirLocationCallSiteGet
- mlirLocationCallSiteGetCallee
- mlirLocationCallSiteGetCaller
- mlirLocationCallSiteGetTypeID
- mlirLocationIsACallSite
- mlirLocationFusedGet
- mlirLocationFusedGetNumLocations
- mlirLocationFusedGetLocations
- mlirLocationFusedGetMetadata
- mlirLocationFusedGetTypeID
- mlirLocationIsAFused
- mlirLocationNameGet
- mlirLocationNameGetName
- mlirLocationNameGetChildLoc
- mlirLocationNameGetTypeID
- mlirLocationIsAName
- mlirLocationUnknownGet
- mlirLocationEqual
- mlirLocationGetContext
- mlirLocationPrint
- mlirModuleCreateEmpty
- mlirModuleCreateParse
- mlirModuleCreateParseFromFile
- mlirModuleGetContext
- mlirModuleGetBody
- mlirModuleDestroy
- mlirModuleGetOperation
- mlirModuleFromOperation
- mlirOperationStateGet
- mlirOperationStateAddResults
- mlirOperationStateAddOperands
- mlirOperationStateAddOwnedRegions
- mlirOperationStateAddSuccessors
- mlirOperationStateAddAttributes
- mlirOperationStateEnableResultTypeInference
- inferOperationTypes
- mlirOperationCreate
- mlirOperationCreateParse
- mlirOperationClone
- mlirOperationDestroy
- mlirOperationRemoveFromParent
- mlirOperationEqual
- mlirOperationGetContext
- mlirOperationGetLocation
- mlirOperationGetTypeID
- mlirOperationGetName
- mlirOperationGetBlock
- mlirOperationGetParentOperation
- mlirOperationGetNumRegions
- mlirOperationGetRegion
- mlirOperationGetFirstRegion
- mlirRegionGetNextInOperation
- mlirOperationGetNextInBlock
- mlirOperationGetNumOperands
- mlirOperationGetOperand
- mlirOperationSetOperand
- mlirOperationSetOperands
- mlirOperationGetNumResults
- mlirOperationGetResult
- mlirOperationGetNumSuccessors
- mlirOperationGetSuccessor
- mlirOperationHasInherentAttributeByName
- mlirOperationGetInherentAttributeByName
- mlirOperationSetInherentAttributeByName
- mlirOperationGetNumDiscardableAttributes
- mlirOperationGetDiscardableAttribute
- mlirOperationGetDiscardableAttributeByName
- mlirOperationSetDiscardableAttributeByName
- mlirOperationRemoveDiscardableAttributeByName
- mlirOperationSetSuccessor
- mlirOperationGetNumAttributes
- mlirOperationGetAttribute
- mlirOperationGetAttributeByName
- mlirOperationSetAttributeByName
- mlirOperationRemoveAttributeByName
- mlirOperationPrint
- mlirOperationPrintWithFlags
- mlirOperationPrintWithState
- mlirOperationWriteBytecode
- mlirOperationWriteBytecodeWithConfig
- mlirOperationDump
- mlirOperationVerify
- mlirOperationMoveAfter
- mlirOperationMoveBefore
- unwrap
- mlirOperationWalk
- mlirRegionCreate
- mlirRegionEqual
- mlirRegionGetFirstBlock
- mlirRegionAppendOwnedBlock
- mlirRegionInsertOwnedBlock
- mlirRegionInsertOwnedBlockAfter
- mlirRegionInsertOwnedBlockBefore
- mlirRegionDestroy
- mlirRegionTakeBody
- mlirBlockCreate
- mlirBlockEqual
- mlirBlockGetParentOperation
- mlirBlockGetParentRegion
- mlirBlockGetNextInRegion
- mlirBlockGetFirstOperation
- mlirBlockGetTerminator
- mlirBlockAppendOwnedOperation
- mlirBlockInsertOwnedOperation
- mlirBlockInsertOwnedOperationAfter
- mlirBlockInsertOwnedOperationBefore
- mlirBlockDestroy
- mlirBlockDetach
- mlirBlockGetNumArguments
- mlirBlockAddArgument
- mlirBlockEraseArgument
- mlirBlockInsertArgument
- mlirBlockGetArgument
- mlirBlockPrint
- mlirValueEqual
- mlirValueIsABlockArgument
- mlirValueIsAOpResult
- mlirBlockArgumentGetOwner
- mlirBlockArgumentGetArgNumber
- mlirBlockArgumentSetType
- mlirOpResultGetOwner
- mlirOpResultGetResultNumber
- mlirValueGetType
- mlirValueSetType
- mlirValueDump
- mlirValuePrint
- mlirValuePrintAsOperand
- mlirValueGetFirstUse
- mlirValueReplaceAllUsesOfWith
- mlirValueReplaceAllUsesExcept
- mlirValueGetLocation
- mlirValueGetContext
- mlirOpOperandIsNull
- mlirOpOperandGetOwner
- mlirOpOperandGetValue
- mlirOpOperandGetOperandNumber
- mlirOpOperandGetNextUse
- mlirTypeParseGet
- mlirTypeGetContext
- mlirTypeGetTypeID
- mlirTypeGetDialect
- mlirTypeEqual
- mlirTypePrint
- mlirTypeDump
- mlirAttributeParseGet
- mlirAttributeGetContext
- mlirAttributeGetType
- mlirAttributeGetTypeID
- mlirAttributeGetDialect
- mlirAttributeEqual
- mlirAttributePrint
- mlirAttributeDump
- mlirNamedAttributeGet
- mlirIdentifierGet
- mlirIdentifierGetContext
- mlirIdentifierEqual
- mlirIdentifierStr
- mlirSymbolTableGetSymbolAttributeName
- mlirSymbolTableGetVisibilityAttributeName
- mlirSymbolTableCreate
- mlirSymbolTableDestroy
- mlirSymbolTableLookup
- mlirSymbolTableInsert
- mlirSymbolTableErase
- mlirSymbolTableReplaceAllSymbolUses
Learn to use CMake with our Intro Training
Find out more