| 1 | //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ThreadPool.h"

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <optional>
| 38 | |
| 39 | using namespace mlir; |
| 40 | |
| 41 | //===----------------------------------------------------------------------===// |
| 42 | // Context API. |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | |
| 45 | MlirContext mlirContextCreate() { |
| 46 | auto *context = new MLIRContext; |
| 47 | return wrap(cpp: context); |
| 48 | } |
| 49 | |
| 50 | static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { |
| 51 | return threadingEnabled ? MLIRContext::Threading::ENABLED |
| 52 | : MLIRContext::Threading::DISABLED; |
| 53 | } |
| 54 | |
| 55 | MlirContext mlirContextCreateWithThreading(bool threadingEnabled) { |
| 56 | auto *context = new MLIRContext(toThreadingEnum(threadingEnabled)); |
| 57 | return wrap(cpp: context); |
| 58 | } |
| 59 | |
| 60 | MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry, |
| 61 | bool threadingEnabled) { |
| 62 | auto *context = |
| 63 | new MLIRContext(*unwrap(c: registry), toThreadingEnum(threadingEnabled)); |
| 64 | return wrap(cpp: context); |
| 65 | } |
| 66 | |
| 67 | bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { |
| 68 | return unwrap(c: ctx1) == unwrap(c: ctx2); |
| 69 | } |
| 70 | |
| 71 | void mlirContextDestroy(MlirContext context) { delete unwrap(c: context); } |
| 72 | |
| 73 | void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { |
| 74 | unwrap(c: context)->allowUnregisteredDialects(allow); |
| 75 | } |
| 76 | |
| 77 | bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { |
| 78 | return unwrap(c: context)->allowsUnregisteredDialects(); |
| 79 | } |
| 80 | intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { |
| 81 | return static_cast<intptr_t>(unwrap(c: context)->getAvailableDialects().size()); |
| 82 | } |
| 83 | |
| 84 | void mlirContextAppendDialectRegistry(MlirContext ctx, |
| 85 | MlirDialectRegistry registry) { |
| 86 | unwrap(c: ctx)->appendDialectRegistry(registry: *unwrap(c: registry)); |
| 87 | } |
| 88 | |
| 89 | // TODO: expose a cheaper way than constructing + sorting a vector only to take |
| 90 | // its size. |
| 91 | intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { |
| 92 | return static_cast<intptr_t>(unwrap(c: context)->getLoadedDialects().size()); |
| 93 | } |
| 94 | |
| 95 | MlirDialect mlirContextGetOrLoadDialect(MlirContext context, |
| 96 | MlirStringRef name) { |
| 97 | return wrap(cpp: unwrap(c: context)->getOrLoadDialect(name: unwrap(ref: name))); |
| 98 | } |
| 99 | |
| 100 | bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { |
| 101 | return unwrap(c: context)->isOperationRegistered(name: unwrap(ref: name)); |
| 102 | } |
| 103 | |
| 104 | void mlirContextEnableMultithreading(MlirContext context, bool enable) { |
| 105 | return unwrap(c: context)->enableMultithreading(enable); |
| 106 | } |
| 107 | |
| 108 | void mlirContextLoadAllAvailableDialects(MlirContext context) { |
| 109 | unwrap(c: context)->loadAllAvailableDialects(); |
| 110 | } |
| 111 | |
| 112 | void mlirContextSetThreadPool(MlirContext context, |
| 113 | MlirLlvmThreadPool threadPool) { |
| 114 | unwrap(c: context)->setThreadPool(*unwrap(c: threadPool)); |
| 115 | } |
| 116 | |
| 117 | unsigned mlirContextGetNumThreads(MlirContext context) { |
| 118 | return unwrap(c: context)->getNumThreads(); |
| 119 | } |
| 120 | |
| 121 | MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) { |
| 122 | return wrap(cpp: &unwrap(c: context)->getThreadPool()); |
| 123 | } |
| 124 | |
| 125 | //===----------------------------------------------------------------------===// |
| 126 | // Dialect API. |
| 127 | //===----------------------------------------------------------------------===// |
| 128 | |
| 129 | MlirContext mlirDialectGetContext(MlirDialect dialect) { |
| 130 | return wrap(cpp: unwrap(c: dialect)->getContext()); |
| 131 | } |
| 132 | |
| 133 | bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { |
| 134 | return unwrap(c: dialect1) == unwrap(c: dialect2); |
| 135 | } |
| 136 | |
| 137 | MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { |
| 138 | return wrap(ref: unwrap(c: dialect)->getNamespace()); |
| 139 | } |
| 140 | |
| 141 | //===----------------------------------------------------------------------===// |
| 142 | // DialectRegistry API. |
| 143 | //===----------------------------------------------------------------------===// |
| 144 | |
| 145 | MlirDialectRegistry mlirDialectRegistryCreate() { |
| 146 | return wrap(cpp: new DialectRegistry()); |
| 147 | } |
| 148 | |
| 149 | void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { |
| 150 | delete unwrap(c: registry); |
| 151 | } |
| 152 | |
| 153 | //===----------------------------------------------------------------------===// |
| 154 | // AsmState API. |
| 155 | //===----------------------------------------------------------------------===// |
| 156 | |
| 157 | MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, |
| 158 | MlirOpPrintingFlags flags) { |
| 159 | return wrap(cpp: new AsmState(unwrap(c: op), *unwrap(c: flags))); |
| 160 | } |
| 161 | |
| 162 | static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
| 163 | do { |
| 164 | // If we are printing local scope, stop at the first operation that is |
| 165 | // isolated from above. |
| 166 | if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| 167 | break; |
| 168 | |
| 169 | // Otherwise, traverse up to the next parent. |
| 170 | Operation *parentOp = op->getParentOp(); |
| 171 | if (!parentOp) |
| 172 | break; |
| 173 | op = parentOp; |
| 174 | } while (true); |
| 175 | return op; |
| 176 | } |
| 177 | |
| 178 | MlirAsmState mlirAsmStateCreateForValue(MlirValue value, |
| 179 | MlirOpPrintingFlags flags) { |
| 180 | Operation *op; |
| 181 | mlir::Value val = unwrap(c: value); |
| 182 | if (auto result = llvm::dyn_cast<OpResult>(Val&: val)) { |
| 183 | op = result.getOwner(); |
| 184 | } else { |
| 185 | op = llvm::cast<BlockArgument>(Val&: val).getOwner()->getParentOp(); |
| 186 | if (!op) { |
| 187 | emitError(loc: val.getLoc()) << "<<UNKNOWN SSA VALUE>>" ; |
| 188 | return {.ptr: nullptr}; |
| 189 | } |
| 190 | } |
| 191 | op = findParent(op, shouldUseLocalScope: unwrap(c: flags)->shouldUseLocalScope()); |
| 192 | return wrap(cpp: new AsmState(op, *unwrap(c: flags))); |
| 193 | } |
| 194 | |
| 195 | /// Destroys printing flags created with mlirAsmStateCreate. |
| 196 | void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(c: state); } |
| 197 | |
| 198 | //===----------------------------------------------------------------------===// |
| 199 | // Printing flags API. |
| 200 | //===----------------------------------------------------------------------===// |
| 201 | |
| 202 | MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { |
| 203 | return wrap(cpp: new OpPrintingFlags()); |
| 204 | } |
| 205 | |
| 206 | void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { |
| 207 | delete unwrap(c: flags); |
| 208 | } |
| 209 | |
| 210 | void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, |
| 211 | intptr_t largeElementLimit) { |
| 212 | unwrap(c: flags)->elideLargeElementsAttrs(largeElementLimit); |
| 213 | } |
| 214 | |
| 215 | void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, |
| 216 | intptr_t largeResourceLimit) { |
| 217 | unwrap(c: flags)->elideLargeResourceString(largeResourceLimit); |
| 218 | } |
| 219 | |
| 220 | void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, |
| 221 | bool prettyForm) { |
| 222 | unwrap(c: flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); |
| 223 | } |
| 224 | |
| 225 | void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { |
| 226 | unwrap(c: flags)->printGenericOpForm(); |
| 227 | } |
| 228 | |
| 229 | void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) { |
| 230 | unwrap(c: flags)->printNameLocAsPrefix(); |
| 231 | } |
| 232 | |
| 233 | void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { |
| 234 | unwrap(c: flags)->useLocalScope(); |
| 235 | } |
| 236 | |
| 237 | void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { |
| 238 | unwrap(c: flags)->assumeVerified(); |
| 239 | } |
| 240 | |
| 241 | void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) { |
| 242 | unwrap(c: flags)->skipRegions(); |
| 243 | } |
| 244 | //===----------------------------------------------------------------------===// |
| 245 | // Bytecode printing flags API. |
| 246 | //===----------------------------------------------------------------------===// |
| 247 | |
| 248 | MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() { |
| 249 | return wrap(cpp: new BytecodeWriterConfig()); |
| 250 | } |
| 251 | |
| 252 | void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) { |
| 253 | delete unwrap(c: config); |
| 254 | } |
| 255 | |
| 256 | void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, |
| 257 | int64_t version) { |
| 258 | unwrap(c: flags)->setDesiredBytecodeVersion(version); |
| 259 | } |
| 260 | |
| 261 | //===----------------------------------------------------------------------===// |
| 262 | // Location API. |
| 263 | //===----------------------------------------------------------------------===// |
| 264 | |
| 265 | MlirAttribute mlirLocationGetAttribute(MlirLocation location) { |
| 266 | return wrap(cpp: LocationAttr(unwrap(c: location))); |
| 267 | } |
| 268 | |
| 269 | MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { |
| 270 | return wrap(cpp: Location(llvm::dyn_cast<LocationAttr>(Val: unwrap(c: attribute)))); |
| 271 | } |
| 272 | |
| 273 | MlirLocation mlirLocationFileLineColGet(MlirContext context, |
| 274 | MlirStringRef filename, unsigned line, |
| 275 | unsigned col) { |
| 276 | return wrap(cpp: Location( |
| 277 | FileLineColLoc::get(context: unwrap(c: context), fileName: unwrap(ref: filename), line, column: col))); |
| 278 | } |
| 279 | |
| 280 | MlirLocation |
| 281 | mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, |
| 282 | unsigned startLine, unsigned startCol, |
| 283 | unsigned endLine, unsigned endCol) { |
| 284 | return wrap( |
| 285 | Location(FileLineColRange::get(unwrap(context), unwrap(filename), |
| 286 | startLine, startCol, endLine, endCol))); |
| 287 | } |
| 288 | |
| 289 | MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) { |
| 290 | return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(c: location)).getFilename()); |
| 291 | } |
| 292 | |
| 293 | int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) { |
| 294 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
| 295 | return loc.getStartLine(); |
| 296 | return -1; |
| 297 | } |
| 298 | |
| 299 | int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) { |
| 300 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
| 301 | return loc.getStartColumn(); |
| 302 | return -1; |
| 303 | } |
| 304 | |
| 305 | int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) { |
| 306 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
| 307 | return loc.getEndLine(); |
| 308 | return -1; |
| 309 | } |
| 310 | |
| 311 | int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) { |
| 312 | if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location))) |
| 313 | return loc.getEndColumn(); |
| 314 | return -1; |
| 315 | } |
| 316 | |
| 317 | MlirTypeID mlirLocationFileLineColRangeGetTypeID() { |
| 318 | return wrap(FileLineColRange::getTypeID()); |
| 319 | } |
| 320 | |
| 321 | bool mlirLocationIsAFileLineColRange(MlirLocation location) { |
| 322 | return isa<FileLineColRange>(Val: unwrap(c: location)); |
| 323 | } |
| 324 | |
| 325 | MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { |
| 326 | return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); |
| 327 | } |
| 328 | |
| 329 | MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) { |
| 330 | return wrap( |
| 331 | Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCallee())); |
| 332 | } |
| 333 | |
| 334 | MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) { |
| 335 | return wrap( |
| 336 | Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCaller())); |
| 337 | } |
| 338 | |
| 339 | MlirTypeID mlirLocationCallSiteGetTypeID() { |
| 340 | return wrap(CallSiteLoc::getTypeID()); |
| 341 | } |
| 342 | |
| 343 | bool mlirLocationIsACallSite(MlirLocation location) { |
| 344 | return isa<CallSiteLoc>(unwrap(location)); |
| 345 | } |
| 346 | |
| 347 | MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, |
| 348 | MlirLocation const *locations, |
| 349 | MlirAttribute metadata) { |
| 350 | SmallVector<Location, 4> locs; |
| 351 | ArrayRef<Location> unwrappedLocs = unwrapList(size: nLocations, first: locations, storage&: locs); |
| 352 | return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); |
| 353 | } |
| 354 | |
| 355 | unsigned mlirLocationFusedGetNumLocations(MlirLocation location) { |
| 356 | if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) |
| 357 | return locationsArrRef.getLocations().size(); |
| 358 | return 0; |
| 359 | } |
| 360 | |
| 361 | void mlirLocationFusedGetLocations(MlirLocation location, |
| 362 | MlirLocation *locationsCPtr) { |
| 363 | if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) { |
| 364 | for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations())) |
| 365 | locationsCPtr[i] = wrap(location); |
| 366 | } |
| 367 | } |
| 368 | |
| 369 | MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) { |
| 370 | return wrap(llvm::dyn_cast<FusedLoc>(unwrap(location)).getMetadata()); |
| 371 | } |
| 372 | |
| 373 | MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); } |
| 374 | |
| 375 | bool mlirLocationIsAFused(MlirLocation location) { |
| 376 | return isa<FusedLoc>(unwrap(location)); |
| 377 | } |
| 378 | |
| 379 | MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, |
| 380 | MlirLocation childLoc) { |
| 381 | if (mlirLocationIsNull(childLoc)) |
| 382 | return wrap( |
| 383 | Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); |
| 384 | return wrap(Location(NameLoc::get( |
| 385 | StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); |
| 386 | } |
| 387 | |
| 388 | MlirIdentifier mlirLocationNameGetName(MlirLocation location) { |
| 389 | return wrap((llvm::dyn_cast<NameLoc>(unwrap(location)).getName())); |
| 390 | } |
| 391 | |
| 392 | MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) { |
| 393 | return wrap( |
| 394 | Location(llvm::dyn_cast<NameLoc>(unwrap(location)).getChildLoc())); |
| 395 | } |
| 396 | |
| 397 | MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); } |
| 398 | |
| 399 | bool mlirLocationIsAName(MlirLocation location) { |
| 400 | return isa<NameLoc>(unwrap(location)); |
| 401 | } |
| 402 | |
| 403 | MlirLocation mlirLocationUnknownGet(MlirContext context) { |
| 404 | return wrap(Location(UnknownLoc::get(unwrap(context)))); |
| 405 | } |
| 406 | |
| 407 | bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { |
| 408 | return unwrap(c: l1) == unwrap(c: l2); |
| 409 | } |
| 410 | |
| 411 | MlirContext mlirLocationGetContext(MlirLocation location) { |
| 412 | return wrap(cpp: unwrap(c: location).getContext()); |
| 413 | } |
| 414 | |
| 415 | void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, |
| 416 | void *userData) { |
| 417 | detail::CallbackOstream stream(callback, userData); |
| 418 | unwrap(c: location).print(os&: stream); |
| 419 | } |
| 420 | |
| 421 | //===----------------------------------------------------------------------===// |
| 422 | // Module API. |
| 423 | //===----------------------------------------------------------------------===// |
| 424 | |
| 425 | MlirModule mlirModuleCreateEmpty(MlirLocation location) { |
| 426 | return wrap(ModuleOp::create(unwrap(location))); |
| 427 | } |
| 428 | |
| 429 | MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { |
| 430 | OwningOpRef<ModuleOp> owning = |
| 431 | parseSourceString<ModuleOp>(unwrap(ref: module), unwrap(c: context)); |
| 432 | if (!owning) |
| 433 | return MlirModule{.ptr: nullptr}; |
| 434 | return MlirModule{owning.release().getOperation()}; |
| 435 | } |
| 436 | |
| 437 | MlirModule mlirModuleCreateParseFromFile(MlirContext context, |
| 438 | MlirStringRef fileName) { |
| 439 | OwningOpRef<ModuleOp> owning = |
| 440 | parseSourceFile<ModuleOp>(filename: unwrap(ref: fileName), config: unwrap(c: context)); |
| 441 | if (!owning) |
| 442 | return MlirModule{.ptr: nullptr}; |
| 443 | return MlirModule{owning.release().getOperation()}; |
| 444 | } |
| 445 | |
| 446 | MlirContext mlirModuleGetContext(MlirModule module) { |
| 447 | return wrap(unwrap(module).getContext()); |
| 448 | } |
| 449 | |
| 450 | MlirBlock mlirModuleGetBody(MlirModule module) { |
| 451 | return wrap(unwrap(module).getBody()); |
| 452 | } |
| 453 | |
| 454 | void mlirModuleDestroy(MlirModule module) { |
| 455 | // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is |
| 456 | // called. |
| 457 | OwningOpRef<ModuleOp>(unwrap(module)); |
| 458 | } |
| 459 | |
| 460 | MlirOperation mlirModuleGetOperation(MlirModule module) { |
| 461 | return wrap(unwrap(module).getOperation()); |
| 462 | } |
| 463 | |
| 464 | MlirModule mlirModuleFromOperation(MlirOperation op) { |
| 465 | return wrap(dyn_cast<ModuleOp>(unwrap(c: op))); |
| 466 | } |
| 467 | |
| 468 | //===----------------------------------------------------------------------===// |
| 469 | // Operation state API. |
| 470 | //===----------------------------------------------------------------------===// |
| 471 | |
| 472 | MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { |
| 473 | MlirOperationState state; |
| 474 | state.name = name; |
| 475 | state.location = loc; |
| 476 | state.nResults = 0; |
| 477 | state.results = nullptr; |
| 478 | state.nOperands = 0; |
| 479 | state.operands = nullptr; |
| 480 | state.nRegions = 0; |
| 481 | state.regions = nullptr; |
| 482 | state.nSuccessors = 0; |
| 483 | state.successors = nullptr; |
| 484 | state.nAttributes = 0; |
| 485 | state.attributes = nullptr; |
| 486 | state.enableResultTypeInference = false; |
| 487 | return state; |
| 488 | } |
| 489 | |
// Appends the `n` elements pointed to by the local `elemName` to the
// same-named realloc-grown array in `*state`, then increments
// `state->sizeName` by `n`. The expansion site must have `state` (pointer to
// the struct being grown) and `n` (element count) in scope.
// - A non-positive `n` is a no-op, so memcpy is never called with a possibly
//   null source.
// - The realloc result is checked before it replaces the old pointer. On
//   failure the process aborts instead of leaking the old buffer and writing
//   through a null pointer.
// - The do/while(false) wrapper makes the expansion a single statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *grown = static_cast<type *>(                                         \
        std::realloc(state->elemName, (state->sizeName + n) * sizeof(type)));  \
    if (!grown)                                                                \
      std::abort();                                                            \
    std::memcpy(grown + state->sizeName, elemName, n * sizeof(type));          \
    state->elemName = grown;                                                   \
    state->sizeName += n;                                                      \
  } while (false)
| 495 | |
// The mlirOperationStateAdd* functions append `n` entries to the matching
// array of `*state` via APPEND_ELEMS. That macro refers to `state`, `n`, and
// the parameter whose name equals the state field, so these parameter names
// must stay identical to the MlirOperationState field names.

/// Appends `n` result types from `results` to `state`.
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
                                  MlirType const *results) {
  APPEND_ELEMS(MlirType, nResults, results);
}

/// Appends `n` operand values from `operands` to `state`.
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
                                   MlirValue const *operands) {
  APPEND_ELEMS(MlirValue, nOperands, operands);
}
/// Appends `n` region handles from `regions` to `state`; mlirOperationCreate
/// later takes ownership of these regions.
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
                                       MlirRegion const *regions) {
  APPEND_ELEMS(MlirRegion, nRegions, regions);
}
/// Appends `n` successor blocks from `successors` to `state`.
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
                                     MlirBlock const *successors) {
  APPEND_ELEMS(MlirBlock, nSuccessors, successors);
}
/// Appends `n` named attributes from `attributes` to `state`.
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
                                     MlirNamedAttribute const *attributes) {
  APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
}

/// Makes mlirOperationCreate infer result types instead of using explicitly
/// added ones.
void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
  state->enableResultTypeInference = true;
}
| 521 | |
| 522 | //===----------------------------------------------------------------------===// |
| 523 | // Operation API. |
| 524 | //===----------------------------------------------------------------------===// |
| 525 | |
| 526 | static LogicalResult inferOperationTypes(OperationState &state) { |
| 527 | MLIRContext *context = state.getContext(); |
| 528 | std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); |
| 529 | if (!info) { |
| 530 | emitError(loc: state.location) |
| 531 | << "type inference was requested for the operation " << state.name |
| 532 | << ", but the operation was not registered; ensure that the dialect " |
| 533 | "containing the operation is linked into MLIR and registered with " |
| 534 | "the context" ; |
| 535 | return failure(); |
| 536 | } |
| 537 | |
| 538 | auto *inferInterface = info->getInterface<InferTypeOpInterface>(); |
| 539 | if (!inferInterface) { |
| 540 | emitError(loc: state.location) |
| 541 | << "type inference was requested for the operation " << state.name |
| 542 | << ", but the operation does not support type inference; result " |
| 543 | "types must be specified explicitly" ; |
| 544 | return failure(); |
| 545 | } |
| 546 | |
| 547 | DictionaryAttr attributes = state.attributes.getDictionary(context); |
| 548 | OpaqueProperties properties = state.getRawProperties(); |
| 549 | |
| 550 | if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { |
| 551 | auto prop = std::make_unique<char[]>(num: info->getOpPropertyByteSize()); |
| 552 | properties = OpaqueProperties(prop.get()); |
| 553 | if (properties) { |
| 554 | auto emitError = [&]() { |
| 555 | return mlir::emitError(state.location) |
| 556 | << " failed properties conversion while building " |
| 557 | << state.name.getStringRef() << " with `" << attributes << "`: " ; |
| 558 | }; |
| 559 | if (failed(info->setOpPropertiesFromAttribute(opName: state.name, properties, |
| 560 | attr: attributes, emitError))) |
| 561 | return failure(); |
| 562 | } |
| 563 | if (succeeded(inferInterface->inferReturnTypes( |
| 564 | context, state.location, state.operands, attributes, properties, |
| 565 | state.regions, state.types))) { |
| 566 | return success(); |
| 567 | } |
| 568 | // Diagnostic emitted by interface. |
| 569 | return failure(); |
| 570 | } |
| 571 | |
| 572 | if (succeeded(inferInterface->inferReturnTypes( |
| 573 | context, state.location, state.operands, attributes, properties, |
| 574 | state.regions, state.types))) |
| 575 | return success(); |
| 576 | |
| 577 | // Diagnostic emitted by interface. |
| 578 | return failure(); |
| 579 | } |
| 580 | |
| 581 | MlirOperation mlirOperationCreate(MlirOperationState *state) { |
| 582 | assert(state); |
| 583 | OperationState cppState(unwrap(c: state->location), unwrap(ref: state->name)); |
| 584 | SmallVector<Type, 4> resultStorage; |
| 585 | SmallVector<Value, 8> operandStorage; |
| 586 | SmallVector<Block *, 2> successorStorage; |
| 587 | cppState.addTypes(newTypes: unwrapList(size: state->nResults, first: state->results, storage&: resultStorage)); |
| 588 | cppState.addOperands( |
| 589 | newOperands: unwrapList(size: state->nOperands, first: state->operands, storage&: operandStorage)); |
| 590 | cppState.addSuccessors( |
| 591 | newSuccessors: unwrapList(size: state->nSuccessors, first: state->successors, storage&: successorStorage)); |
| 592 | |
| 593 | cppState.attributes.reserve(N: state->nAttributes); |
| 594 | for (intptr_t i = 0; i < state->nAttributes; ++i) |
| 595 | cppState.addAttribute(unwrap(state->attributes[i].name), |
| 596 | unwrap(c: state->attributes[i].attribute)); |
| 597 | |
| 598 | for (intptr_t i = 0; i < state->nRegions; ++i) |
| 599 | cppState.addRegion(region: std::unique_ptr<Region>(unwrap(c: state->regions[i]))); |
| 600 | |
| 601 | free(ptr: state->results); |
| 602 | free(ptr: state->operands); |
| 603 | free(ptr: state->successors); |
| 604 | free(ptr: state->regions); |
| 605 | free(ptr: state->attributes); |
| 606 | |
| 607 | // Infer result types. |
| 608 | if (state->enableResultTypeInference) { |
| 609 | assert(cppState.types.empty() && |
| 610 | "result type inference enabled and result types provided" ); |
| 611 | if (failed(Result: inferOperationTypes(state&: cppState))) |
| 612 | return {.ptr: nullptr}; |
| 613 | } |
| 614 | |
| 615 | return wrap(cpp: Operation::create(state: cppState)); |
| 616 | } |
| 617 | |
| 618 | MlirOperation mlirOperationCreateParse(MlirContext context, |
| 619 | MlirStringRef sourceStr, |
| 620 | MlirStringRef sourceName) { |
| 621 | |
| 622 | return wrap( |
| 623 | cpp: parseSourceString(sourceStr: unwrap(ref: sourceStr), config: unwrap(c: context), sourceName: unwrap(ref: sourceName)) |
| 624 | .release()); |
| 625 | } |
| 626 | |
| 627 | MlirOperation mlirOperationClone(MlirOperation op) { |
| 628 | return wrap(cpp: unwrap(c: op)->clone()); |
| 629 | } |
| 630 | |
| 631 | void mlirOperationDestroy(MlirOperation op) { unwrap(c: op)->erase(); } |
| 632 | |
| 633 | void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(c: op)->remove(); } |
| 634 | |
| 635 | bool mlirOperationEqual(MlirOperation op, MlirOperation other) { |
| 636 | return unwrap(c: op) == unwrap(c: other); |
| 637 | } |
| 638 | |
| 639 | MlirContext mlirOperationGetContext(MlirOperation op) { |
| 640 | return wrap(cpp: unwrap(c: op)->getContext()); |
| 641 | } |
| 642 | |
| 643 | MlirLocation mlirOperationGetLocation(MlirOperation op) { |
| 644 | return wrap(cpp: unwrap(c: op)->getLoc()); |
| 645 | } |
| 646 | |
| 647 | MlirTypeID mlirOperationGetTypeID(MlirOperation op) { |
| 648 | if (auto info = unwrap(c: op)->getRegisteredInfo()) |
| 649 | return wrap(cpp: info->getTypeID()); |
| 650 | return {.ptr: nullptr}; |
| 651 | } |
| 652 | |
| 653 | MlirIdentifier mlirOperationGetName(MlirOperation op) { |
| 654 | return wrap(unwrap(c: op)->getName().getIdentifier()); |
| 655 | } |
| 656 | |
| 657 | MlirBlock mlirOperationGetBlock(MlirOperation op) { |
| 658 | return wrap(cpp: unwrap(c: op)->getBlock()); |
| 659 | } |
| 660 | |
| 661 | MlirOperation mlirOperationGetParentOperation(MlirOperation op) { |
| 662 | return wrap(cpp: unwrap(c: op)->getParentOp()); |
| 663 | } |
| 664 | |
| 665 | intptr_t mlirOperationGetNumRegions(MlirOperation op) { |
| 666 | return static_cast<intptr_t>(unwrap(c: op)->getNumRegions()); |
| 667 | } |
| 668 | |
| 669 | MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { |
| 670 | return wrap(cpp: &unwrap(c: op)->getRegion(index: static_cast<unsigned>(pos))); |
| 671 | } |
| 672 | |
| 673 | MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { |
| 674 | Operation *cppOp = unwrap(c: op); |
| 675 | if (cppOp->getNumRegions() == 0) |
| 676 | return wrap(cpp: static_cast<Region *>(nullptr)); |
| 677 | return wrap(cpp: &cppOp->getRegion(index: 0)); |
| 678 | } |
| 679 | |
| 680 | MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { |
| 681 | Region *cppRegion = unwrap(c: region); |
| 682 | Operation *parent = cppRegion->getParentOp(); |
| 683 | intptr_t next = cppRegion->getRegionNumber() + 1; |
| 684 | if (parent->getNumRegions() > next) |
| 685 | return wrap(cpp: &parent->getRegion(index: next)); |
| 686 | return wrap(cpp: static_cast<Region *>(nullptr)); |
| 687 | } |
| 688 | |
| 689 | MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { |
| 690 | return wrap(unwrap(c: op)->getNextNode()); |
| 691 | } |
| 692 | |
| 693 | intptr_t mlirOperationGetNumOperands(MlirOperation op) { |
| 694 | return static_cast<intptr_t>(unwrap(c: op)->getNumOperands()); |
| 695 | } |
| 696 | |
| 697 | MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { |
| 698 | return wrap(cpp: unwrap(c: op)->getOperand(idx: static_cast<unsigned>(pos))); |
| 699 | } |
| 700 | |
| 701 | void mlirOperationSetOperand(MlirOperation op, intptr_t pos, |
| 702 | MlirValue newValue) { |
| 703 | unwrap(c: op)->setOperand(idx: static_cast<unsigned>(pos), value: unwrap(c: newValue)); |
| 704 | } |
| 705 | |
| 706 | void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands, |
| 707 | MlirValue const *operands) { |
| 708 | SmallVector<Value> ops; |
| 709 | unwrap(c: op)->setOperands(unwrapList(size: nOperands, first: operands, storage&: ops)); |
| 710 | } |
| 711 | |
| 712 | intptr_t mlirOperationGetNumResults(MlirOperation op) { |
| 713 | return static_cast<intptr_t>(unwrap(c: op)->getNumResults()); |
| 714 | } |
| 715 | |
| 716 | MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { |
| 717 | return wrap(cpp: unwrap(c: op)->getResult(idx: static_cast<unsigned>(pos))); |
| 718 | } |
| 719 | |
| 720 | intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { |
| 721 | return static_cast<intptr_t>(unwrap(c: op)->getNumSuccessors()); |
| 722 | } |
| 723 | |
| 724 | MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { |
| 725 | return wrap(cpp: unwrap(c: op)->getSuccessor(index: static_cast<unsigned>(pos))); |
| 726 | } |
| 727 | |
| 728 | MLIR_CAPI_EXPORTED bool |
| 729 | mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { |
| 730 | std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name)); |
| 731 | return attr.has_value(); |
| 732 | } |
| 733 | |
| 734 | MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, |
| 735 | MlirStringRef name) { |
| 736 | std::optional<Attribute> attr = unwrap(c: op)->getInherentAttr(name: unwrap(ref: name)); |
| 737 | if (attr.has_value()) |
| 738 | return wrap(cpp: *attr); |
| 739 | return {}; |
| 740 | } |
| 741 | |
| 742 | void mlirOperationSetInherentAttributeByName(MlirOperation op, |
| 743 | MlirStringRef name, |
| 744 | MlirAttribute attr) { |
| 745 | unwrap(op)->setInherentAttr( |
| 746 | StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); |
| 747 | } |
| 748 | |
| 749 | intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { |
| 750 | return static_cast<intptr_t>( |
| 751 | llvm::range_size(unwrap(c: op)->getDiscardableAttrs())); |
| 752 | } |
| 753 | |
| 754 | MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, |
| 755 | intptr_t pos) { |
| 756 | NamedAttribute attr = |
| 757 | *std::next(unwrap(c: op)->getDiscardableAttrs().begin(), pos); |
| 758 | return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())}; |
| 759 | } |
| 760 | |
| 761 | MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, |
| 762 | MlirStringRef name) { |
| 763 | return wrap(cpp: unwrap(c: op)->getDiscardableAttr(name: unwrap(ref: name))); |
| 764 | } |
| 765 | |
| 766 | void mlirOperationSetDiscardableAttributeByName(MlirOperation op, |
| 767 | MlirStringRef name, |
| 768 | MlirAttribute attr) { |
| 769 | unwrap(c: op)->setDiscardableAttr(name: unwrap(ref: name), value: unwrap(c: attr)); |
| 770 | } |
| 771 | |
| 772 | bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, |
| 773 | MlirStringRef name) { |
| 774 | return !!unwrap(c: op)->removeDiscardableAttr(name: unwrap(ref: name)); |
| 775 | } |
| 776 | |
| 777 | void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, |
| 778 | MlirBlock block) { |
| 779 | unwrap(c: op)->setSuccessor(block: unwrap(c: block), index: static_cast<unsigned>(pos)); |
| 780 | } |
| 781 | |
| 782 | intptr_t mlirOperationGetNumAttributes(MlirOperation op) { |
| 783 | return static_cast<intptr_t>(unwrap(c: op)->getAttrs().size()); |
| 784 | } |
| 785 | |
| 786 | MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { |
| 787 | NamedAttribute attr = unwrap(c: op)->getAttrs()[pos]; |
| 788 | return MlirNamedAttribute{wrap(attr.getName()), wrap(cpp: attr.getValue())}; |
| 789 | } |
| 790 | |
| 791 | MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, |
| 792 | MlirStringRef name) { |
| 793 | return wrap(cpp: unwrap(c: op)->getAttr(name: unwrap(ref: name))); |
| 794 | } |
| 795 | |
| 796 | void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, |
| 797 | MlirAttribute attr) { |
| 798 | unwrap(c: op)->setAttr(name: unwrap(ref: name), value: unwrap(c: attr)); |
| 799 | } |
| 800 | |
| 801 | bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { |
| 802 | return !!unwrap(c: op)->removeAttr(name: unwrap(ref: name)); |
| 803 | } |
| 804 | |
| 805 | void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, |
| 806 | void *userData) { |
| 807 | detail::CallbackOstream stream(callback, userData); |
| 808 | unwrap(c: op)->print(os&: stream); |
| 809 | } |
| 810 | |
| 811 | void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, |
| 812 | MlirStringCallback callback, void *userData) { |
| 813 | detail::CallbackOstream stream(callback, userData); |
| 814 | unwrap(c: op)->print(os&: stream, flags: *unwrap(c: flags)); |
| 815 | } |
| 816 | |
| 817 | void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, |
| 818 | MlirStringCallback callback, void *userData) { |
| 819 | detail::CallbackOstream stream(callback, userData); |
| 820 | if (state.ptr) |
| 821 | unwrap(c: op)->print(os&: stream, state&: *unwrap(c: state)); |
| 822 | unwrap(c: op)->print(os&: stream); |
| 823 | } |
| 824 | |
| 825 | void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, |
| 826 | void *userData) { |
| 827 | detail::CallbackOstream stream(callback, userData); |
| 828 | // As no desired version is set, no failure can occur. |
| 829 | (void)writeBytecodeToFile(op: unwrap(c: op), os&: stream); |
| 830 | } |
| 831 | |
| 832 | MlirLogicalResult mlirOperationWriteBytecodeWithConfig( |
| 833 | MlirOperation op, MlirBytecodeWriterConfig config, |
| 834 | MlirStringCallback callback, void *userData) { |
| 835 | detail::CallbackOstream stream(callback, userData); |
| 836 | return wrap(res: writeBytecodeToFile(op: unwrap(c: op), os&: stream, config: *unwrap(c: config))); |
| 837 | } |
| 838 | |
| 839 | void mlirOperationDump(MlirOperation op) { return unwrap(c: op)->dump(); } |
| 840 | |
| 841 | bool mlirOperationVerify(MlirOperation op) { |
| 842 | return succeeded(Result: verify(op: unwrap(c: op))); |
| 843 | } |
| 844 | |
| 845 | void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { |
| 846 | return unwrap(c: op)->moveAfter(existingOp: unwrap(c: other)); |
| 847 | } |
| 848 | |
| 849 | void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { |
| 850 | return unwrap(c: op)->moveBefore(existingOp: unwrap(c: other)); |
| 851 | } |
| 852 | |
| 853 | static mlir::WalkResult unwrap(MlirWalkResult result) { |
| 854 | switch (result) { |
| 855 | case MlirWalkResultAdvance: |
| 856 | return mlir::WalkResult::advance(); |
| 857 | |
| 858 | case MlirWalkResultInterrupt: |
| 859 | return mlir::WalkResult::interrupt(); |
| 860 | |
| 861 | case MlirWalkResultSkip: |
| 862 | return mlir::WalkResult::skip(); |
| 863 | } |
| 864 | llvm_unreachable("unknown result in WalkResult::unwrap" ); |
| 865 | } |
| 866 | |
| 867 | void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, |
| 868 | void *userData, MlirWalkOrder walkOrder) { |
| 869 | switch (walkOrder) { |
| 870 | |
| 871 | case MlirWalkPreOrder: |
| 872 | unwrap(c: op)->walk<mlir::WalkOrder::PreOrder>( |
| 873 | callback: [callback, userData](Operation *op) { |
| 874 | return unwrap(result: callback(wrap(cpp: op), userData)); |
| 875 | }); |
| 876 | break; |
| 877 | case MlirWalkPostOrder: |
| 878 | unwrap(c: op)->walk<mlir::WalkOrder::PostOrder>( |
| 879 | callback: [callback, userData](Operation *op) { |
| 880 | return unwrap(result: callback(wrap(cpp: op), userData)); |
| 881 | }); |
| 882 | } |
| 883 | } |
| 884 | |
| 885 | //===----------------------------------------------------------------------===// |
| 886 | // Region API. |
| 887 | //===----------------------------------------------------------------------===// |
| 888 | |
| 889 | MlirRegion mlirRegionCreate() { return wrap(cpp: new Region); } |
| 890 | |
| 891 | bool mlirRegionEqual(MlirRegion region, MlirRegion other) { |
| 892 | return unwrap(c: region) == unwrap(c: other); |
| 893 | } |
| 894 | |
| 895 | MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { |
| 896 | Region *cppRegion = unwrap(c: region); |
| 897 | if (cppRegion->empty()) |
| 898 | return wrap(cpp: static_cast<Block *>(nullptr)); |
| 899 | return wrap(cpp: &cppRegion->front()); |
| 900 | } |
| 901 | |
| 902 | void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { |
| 903 | unwrap(c: region)->push_back(block: unwrap(c: block)); |
| 904 | } |
| 905 | |
| 906 | void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, |
| 907 | MlirBlock block) { |
| 908 | auto &blockList = unwrap(c: region)->getBlocks(); |
| 909 | blockList.insert(where: std::next(x: blockList.begin(), n: pos), New: unwrap(c: block)); |
| 910 | } |
| 911 | |
| 912 | void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, |
| 913 | MlirBlock block) { |
| 914 | Region *cppRegion = unwrap(c: region); |
| 915 | if (mlirBlockIsNull(block: reference)) { |
| 916 | cppRegion->getBlocks().insert(where: cppRegion->begin(), New: unwrap(c: block)); |
| 917 | return; |
| 918 | } |
| 919 | |
| 920 | assert(unwrap(reference)->getParent() == unwrap(region) && |
| 921 | "expected reference block to belong to the region" ); |
| 922 | cppRegion->getBlocks().insertAfter(where: Region::iterator(unwrap(c: reference)), |
| 923 | New: unwrap(c: block)); |
| 924 | } |
| 925 | |
| 926 | void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, |
| 927 | MlirBlock block) { |
| 928 | if (mlirBlockIsNull(block: reference)) |
| 929 | return mlirRegionAppendOwnedBlock(region, block); |
| 930 | |
| 931 | assert(unwrap(reference)->getParent() == unwrap(region) && |
| 932 | "expected reference block to belong to the region" ); |
| 933 | unwrap(c: region)->getBlocks().insert(where: Region::iterator(unwrap(c: reference)), |
| 934 | New: unwrap(c: block)); |
| 935 | } |
| 936 | |
| 937 | void mlirRegionDestroy(MlirRegion region) { |
| 938 | delete static_cast<Region *>(region.ptr); |
| 939 | } |
| 940 | |
| 941 | void mlirRegionTakeBody(MlirRegion target, MlirRegion source) { |
| 942 | unwrap(c: target)->takeBody(other&: *unwrap(c: source)); |
| 943 | } |
| 944 | |
| 945 | //===----------------------------------------------------------------------===// |
| 946 | // Block API. |
| 947 | //===----------------------------------------------------------------------===// |
| 948 | |
| 949 | MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, |
| 950 | MlirLocation const *locs) { |
| 951 | Block *b = new Block; |
| 952 | for (intptr_t i = 0; i < nArgs; ++i) |
| 953 | b->addArgument(type: unwrap(c: args[i]), loc: unwrap(c: locs[i])); |
| 954 | return wrap(cpp: b); |
| 955 | } |
| 956 | |
| 957 | bool mlirBlockEqual(MlirBlock block, MlirBlock other) { |
| 958 | return unwrap(c: block) == unwrap(c: other); |
| 959 | } |
| 960 | |
| 961 | MlirOperation mlirBlockGetParentOperation(MlirBlock block) { |
| 962 | return wrap(cpp: unwrap(c: block)->getParentOp()); |
| 963 | } |
| 964 | |
| 965 | MlirRegion mlirBlockGetParentRegion(MlirBlock block) { |
| 966 | return wrap(cpp: unwrap(c: block)->getParent()); |
| 967 | } |
| 968 | |
| 969 | MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { |
| 970 | return wrap(cpp: unwrap(c: block)->getNextNode()); |
| 971 | } |
| 972 | |
| 973 | MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { |
| 974 | Block *cppBlock = unwrap(c: block); |
| 975 | if (cppBlock->empty()) |
| 976 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
| 977 | return wrap(cpp: &cppBlock->front()); |
| 978 | } |
| 979 | |
| 980 | MlirOperation mlirBlockGetTerminator(MlirBlock block) { |
| 981 | Block *cppBlock = unwrap(c: block); |
| 982 | if (cppBlock->empty()) |
| 983 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
| 984 | Operation &back = cppBlock->back(); |
| 985 | if (!back.hasTrait<OpTrait::IsTerminator>()) |
| 986 | return wrap(cpp: static_cast<Operation *>(nullptr)); |
| 987 | return wrap(cpp: &back); |
| 988 | } |
| 989 | |
| 990 | void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { |
| 991 | unwrap(c: block)->push_back(op: unwrap(c: operation)); |
| 992 | } |
| 993 | |
| 994 | void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, |
| 995 | MlirOperation operation) { |
| 996 | auto &opList = unwrap(c: block)->getOperations(); |
| 997 | opList.insert(where: std::next(x: opList.begin(), n: pos), New: unwrap(c: operation)); |
| 998 | } |
| 999 | |
| 1000 | void mlirBlockInsertOwnedOperationAfter(MlirBlock block, |
| 1001 | MlirOperation reference, |
| 1002 | MlirOperation operation) { |
| 1003 | Block *cppBlock = unwrap(c: block); |
| 1004 | if (mlirOperationIsNull(op: reference)) { |
| 1005 | cppBlock->getOperations().insert(where: cppBlock->begin(), New: unwrap(c: operation)); |
| 1006 | return; |
| 1007 | } |
| 1008 | |
| 1009 | assert(unwrap(reference)->getBlock() == unwrap(block) && |
| 1010 | "expected reference operation to belong to the block" ); |
| 1011 | cppBlock->getOperations().insertAfter(where: Block::iterator(unwrap(c: reference)), |
| 1012 | New: unwrap(c: operation)); |
| 1013 | } |
| 1014 | |
| 1015 | void mlirBlockInsertOwnedOperationBefore(MlirBlock block, |
| 1016 | MlirOperation reference, |
| 1017 | MlirOperation operation) { |
| 1018 | if (mlirOperationIsNull(op: reference)) |
| 1019 | return mlirBlockAppendOwnedOperation(block, operation); |
| 1020 | |
| 1021 | assert(unwrap(reference)->getBlock() == unwrap(block) && |
| 1022 | "expected reference operation to belong to the block" ); |
| 1023 | unwrap(c: block)->getOperations().insert(where: Block::iterator(unwrap(c: reference)), |
| 1024 | New: unwrap(c: operation)); |
| 1025 | } |
| 1026 | |
| 1027 | void mlirBlockDestroy(MlirBlock block) { delete unwrap(c: block); } |
| 1028 | |
| 1029 | void mlirBlockDetach(MlirBlock block) { |
| 1030 | Block *b = unwrap(c: block); |
| 1031 | b->getParent()->getBlocks().remove(IT: b); |
| 1032 | } |
| 1033 | |
| 1034 | intptr_t mlirBlockGetNumArguments(MlirBlock block) { |
| 1035 | return static_cast<intptr_t>(unwrap(c: block)->getNumArguments()); |
| 1036 | } |
| 1037 | |
| 1038 | MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, |
| 1039 | MlirLocation loc) { |
| 1040 | return wrap(cpp: unwrap(c: block)->addArgument(type: unwrap(c: type), loc: unwrap(c: loc))); |
| 1041 | } |
| 1042 | |
| 1043 | void mlirBlockEraseArgument(MlirBlock block, unsigned index) { |
| 1044 | return unwrap(c: block)->eraseArgument(index); |
| 1045 | } |
| 1046 | |
| 1047 | MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, |
| 1048 | MlirLocation loc) { |
| 1049 | return wrap(cpp: unwrap(c: block)->insertArgument(index: pos, type: unwrap(c: type), loc: unwrap(c: loc))); |
| 1050 | } |
| 1051 | |
| 1052 | MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { |
| 1053 | return wrap(cpp: unwrap(c: block)->getArgument(i: static_cast<unsigned>(pos))); |
| 1054 | } |
| 1055 | |
| 1056 | void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, |
| 1057 | void *userData) { |
| 1058 | detail::CallbackOstream stream(callback, userData); |
| 1059 | unwrap(c: block)->print(os&: stream); |
| 1060 | } |
| 1061 | |
| 1062 | //===----------------------------------------------------------------------===// |
| 1063 | // Value API. |
| 1064 | //===----------------------------------------------------------------------===// |
| 1065 | |
| 1066 | bool mlirValueEqual(MlirValue value1, MlirValue value2) { |
| 1067 | return unwrap(c: value1) == unwrap(c: value2); |
| 1068 | } |
| 1069 | |
| 1070 | bool mlirValueIsABlockArgument(MlirValue value) { |
| 1071 | return llvm::isa<BlockArgument>(Val: unwrap(c: value)); |
| 1072 | } |
| 1073 | |
| 1074 | bool mlirValueIsAOpResult(MlirValue value) { |
| 1075 | return llvm::isa<OpResult>(Val: unwrap(c: value)); |
| 1076 | } |
| 1077 | |
| 1078 | MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { |
| 1079 | return wrap(cpp: llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getOwner()); |
| 1080 | } |
| 1081 | |
| 1082 | intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { |
| 1083 | return static_cast<intptr_t>( |
| 1084 | llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value)).getArgNumber()); |
| 1085 | } |
| 1086 | |
| 1087 | void mlirBlockArgumentSetType(MlirValue value, MlirType type) { |
| 1088 | if (auto blockArg = llvm::dyn_cast<BlockArgument>(Val: unwrap(c: value))) |
| 1089 | blockArg.setType(unwrap(c: type)); |
| 1090 | } |
| 1091 | |
| 1092 | MlirOperation mlirOpResultGetOwner(MlirValue value) { |
| 1093 | return wrap(cpp: llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getOwner()); |
| 1094 | } |
| 1095 | |
| 1096 | intptr_t mlirOpResultGetResultNumber(MlirValue value) { |
| 1097 | return static_cast<intptr_t>( |
| 1098 | llvm::dyn_cast<OpResult>(Val: unwrap(c: value)).getResultNumber()); |
| 1099 | } |
| 1100 | |
| 1101 | MlirType mlirValueGetType(MlirValue value) { |
| 1102 | return wrap(cpp: unwrap(c: value).getType()); |
| 1103 | } |
| 1104 | |
| 1105 | void mlirValueSetType(MlirValue value, MlirType type) { |
| 1106 | unwrap(c: value).setType(unwrap(c: type)); |
| 1107 | } |
| 1108 | |
| 1109 | void mlirValueDump(MlirValue value) { unwrap(c: value).dump(); } |
| 1110 | |
| 1111 | void mlirValuePrint(MlirValue value, MlirStringCallback callback, |
| 1112 | void *userData) { |
| 1113 | detail::CallbackOstream stream(callback, userData); |
| 1114 | unwrap(c: value).print(os&: stream); |
| 1115 | } |
| 1116 | |
| 1117 | void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, |
| 1118 | MlirStringCallback callback, void *userData) { |
| 1119 | detail::CallbackOstream stream(callback, userData); |
| 1120 | Value cppValue = unwrap(c: value); |
| 1121 | cppValue.printAsOperand(os&: stream, state&: *unwrap(c: state)); |
| 1122 | } |
| 1123 | |
| 1124 | MlirOpOperand mlirValueGetFirstUse(MlirValue value) { |
| 1125 | Value cppValue = unwrap(c: value); |
| 1126 | if (cppValue.use_empty()) |
| 1127 | return {}; |
| 1128 | |
| 1129 | OpOperand *opOperand = cppValue.use_begin().getOperand(); |
| 1130 | |
| 1131 | return wrap(cpp: opOperand); |
| 1132 | } |
| 1133 | |
| 1134 | void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { |
| 1135 | unwrap(c: oldValue).replaceAllUsesWith(newValue: unwrap(c: newValue)); |
| 1136 | } |
| 1137 | |
| 1138 | void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue, |
| 1139 | intptr_t numExceptions, |
| 1140 | MlirOperation *exceptions) { |
| 1141 | Value oldValueCpp = unwrap(c: oldValue); |
| 1142 | Value newValueCpp = unwrap(c: newValue); |
| 1143 | |
| 1144 | llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet; |
| 1145 | for (intptr_t i = 0; i < numExceptions; ++i) { |
| 1146 | exceptionSet.insert(Ptr: unwrap(c: exceptions[i])); |
| 1147 | } |
| 1148 | |
| 1149 | oldValueCpp.replaceAllUsesExcept(newValue: newValueCpp, exceptions: exceptionSet); |
| 1150 | } |
| 1151 | |
| 1152 | MlirLocation mlirValueGetLocation(MlirValue v) { |
| 1153 | return wrap(cpp: unwrap(c: v).getLoc()); |
| 1154 | } |
| 1155 | |
| 1156 | MlirContext mlirValueGetContext(MlirValue v) { |
| 1157 | return wrap(cpp: unwrap(c: v).getContext()); |
| 1158 | } |
| 1159 | |
| 1160 | //===----------------------------------------------------------------------===// |
| 1161 | // OpOperand API. |
| 1162 | //===----------------------------------------------------------------------===// |
| 1163 | |
| 1164 | bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; } |
| 1165 | |
| 1166 | MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { |
| 1167 | return wrap(cpp: unwrap(c: opOperand)->getOwner()); |
| 1168 | } |
| 1169 | |
| 1170 | MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) { |
| 1171 | return wrap(cpp: unwrap(c: opOperand)->get()); |
| 1172 | } |
| 1173 | |
| 1174 | unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { |
| 1175 | return unwrap(c: opOperand)->getOperandNumber(); |
| 1176 | } |
| 1177 | |
| 1178 | MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) { |
| 1179 | if (mlirOpOperandIsNull(opOperand)) |
| 1180 | return {}; |
| 1181 | |
| 1182 | OpOperand *nextOpOperand = static_cast<OpOperand *>( |
| 1183 | unwrap(c: opOperand)->getNextOperandUsingThisValue()); |
| 1184 | |
| 1185 | if (!nextOpOperand) |
| 1186 | return {}; |
| 1187 | |
| 1188 | return wrap(cpp: nextOpOperand); |
| 1189 | } |
| 1190 | |
| 1191 | //===----------------------------------------------------------------------===// |
| 1192 | // Type API. |
| 1193 | //===----------------------------------------------------------------------===// |
| 1194 | |
| 1195 | MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { |
| 1196 | return wrap(cpp: mlir::parseType(typeStr: unwrap(ref: type), context: unwrap(c: context))); |
| 1197 | } |
| 1198 | |
| 1199 | MlirContext mlirTypeGetContext(MlirType type) { |
| 1200 | return wrap(cpp: unwrap(c: type).getContext()); |
| 1201 | } |
| 1202 | |
| 1203 | MlirTypeID mlirTypeGetTypeID(MlirType type) { |
| 1204 | return wrap(cpp: unwrap(c: type).getTypeID()); |
| 1205 | } |
| 1206 | |
| 1207 | MlirDialect mlirTypeGetDialect(MlirType type) { |
| 1208 | return wrap(cpp: &unwrap(c: type).getDialect()); |
| 1209 | } |
| 1210 | |
| 1211 | bool mlirTypeEqual(MlirType t1, MlirType t2) { |
| 1212 | return unwrap(c: t1) == unwrap(c: t2); |
| 1213 | } |
| 1214 | |
| 1215 | void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { |
| 1216 | detail::CallbackOstream stream(callback, userData); |
| 1217 | unwrap(c: type).print(os&: stream); |
| 1218 | } |
| 1219 | |
| 1220 | void mlirTypeDump(MlirType type) { unwrap(c: type).dump(); } |
| 1221 | |
| 1222 | //===----------------------------------------------------------------------===// |
| 1223 | // Attribute API. |
| 1224 | //===----------------------------------------------------------------------===// |
| 1225 | |
| 1226 | MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { |
| 1227 | return wrap(cpp: mlir::parseAttribute(attrStr: unwrap(ref: attr), context: unwrap(c: context))); |
| 1228 | } |
| 1229 | |
| 1230 | MlirContext mlirAttributeGetContext(MlirAttribute attribute) { |
| 1231 | return wrap(cpp: unwrap(c: attribute).getContext()); |
| 1232 | } |
| 1233 | |
| 1234 | MlirType mlirAttributeGetType(MlirAttribute attribute) { |
| 1235 | Attribute attr = unwrap(c: attribute); |
| 1236 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) |
| 1237 | return wrap(typedAttr.getType()); |
| 1238 | return wrap(NoneType::get(attr.getContext())); |
| 1239 | } |
| 1240 | |
| 1241 | MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { |
| 1242 | return wrap(cpp: unwrap(c: attr).getTypeID()); |
| 1243 | } |
| 1244 | |
| 1245 | MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { |
| 1246 | return wrap(cpp: &unwrap(c: attr).getDialect()); |
| 1247 | } |
| 1248 | |
| 1249 | bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { |
| 1250 | return unwrap(c: a1) == unwrap(c: a2); |
| 1251 | } |
| 1252 | |
| 1253 | void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, |
| 1254 | void *userData) { |
| 1255 | detail::CallbackOstream stream(callback, userData); |
| 1256 | unwrap(c: attr).print(os&: stream); |
| 1257 | } |
| 1258 | |
| 1259 | void mlirAttributeDump(MlirAttribute attr) { unwrap(c: attr).dump(); } |
| 1260 | |
| 1261 | MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, |
| 1262 | MlirAttribute attr) { |
| 1263 | return MlirNamedAttribute{.name: name, .attribute: attr}; |
| 1264 | } |
| 1265 | |
| 1266 | //===----------------------------------------------------------------------===// |
| 1267 | // Identifier API. |
| 1268 | //===----------------------------------------------------------------------===// |
| 1269 | |
| 1270 | MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { |
| 1271 | return wrap(StringAttr::get(unwrap(context), unwrap(str))); |
| 1272 | } |
| 1273 | |
| 1274 | MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { |
| 1275 | return wrap(unwrap(ident).getContext()); |
| 1276 | } |
| 1277 | |
| 1278 | bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { |
| 1279 | return unwrap(ident) == unwrap(other); |
| 1280 | } |
| 1281 | |
| 1282 | MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { |
| 1283 | return wrap(unwrap(ident).strref()); |
| 1284 | } |
| 1285 | |
| 1286 | //===----------------------------------------------------------------------===// |
| 1287 | // Symbol and SymbolTable API. |
| 1288 | //===----------------------------------------------------------------------===// |
| 1289 | |
| 1290 | MlirStringRef mlirSymbolTableGetSymbolAttributeName() { |
| 1291 | return wrap(ref: SymbolTable::getSymbolAttrName()); |
| 1292 | } |
| 1293 | |
| 1294 | MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { |
| 1295 | return wrap(ref: SymbolTable::getVisibilityAttrName()); |
| 1296 | } |
| 1297 | |
| 1298 | MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { |
| 1299 | if (!unwrap(c: operation)->hasTrait<OpTrait::SymbolTable>()) |
| 1300 | return wrap(cpp: static_cast<SymbolTable *>(nullptr)); |
| 1301 | return wrap(cpp: new SymbolTable(unwrap(c: operation))); |
| 1302 | } |
| 1303 | |
| 1304 | void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { |
| 1305 | delete unwrap(c: symbolTable); |
| 1306 | } |
| 1307 | |
| 1308 | MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, |
| 1309 | MlirStringRef name) { |
| 1310 | return wrap(cpp: unwrap(c: symbolTable)->lookup(name: StringRef(name.data, name.length))); |
| 1311 | } |
| 1312 | |
| 1313 | MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, |
| 1314 | MlirOperation operation) { |
| 1315 | return wrap(cpp: (Attribute)unwrap(c: symbolTable)->insert(unwrap(c: operation))); |
| 1316 | } |
| 1317 | |
| 1318 | void mlirSymbolTableErase(MlirSymbolTable symbolTable, |
| 1319 | MlirOperation operation) { |
| 1320 | unwrap(c: symbolTable)->erase(symbol: unwrap(c: operation)); |
| 1321 | } |
| 1322 | |
| 1323 | MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, |
| 1324 | MlirStringRef newSymbol, |
| 1325 | MlirOperation from) { |
| 1326 | auto *cppFrom = unwrap(c: from); |
| 1327 | auto *context = cppFrom->getContext(); |
| 1328 | auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); |
| 1329 | auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); |
| 1330 | return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, |
| 1331 | unwrap(c: from))); |
| 1332 | } |
| 1333 | |
| 1334 | void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, |
| 1335 | void (*callback)(MlirOperation, bool, |
| 1336 | void *userData), |
| 1337 | void *userData) { |
| 1338 | SymbolTable::walkSymbolTables(op: unwrap(c: from), allSymUsesVisible, |
| 1339 | callback: [&](Operation *foundOpCpp, bool isVisible) { |
| 1340 | callback(wrap(cpp: foundOpCpp), isVisible, |
| 1341 | userData); |
| 1342 | }); |
| 1343 | } |
| 1344 | |