| 1 | //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir-c/Rewrite.h" |
| 10 | |
| 11 | #include "mlir-c/Transforms.h" |
| 12 | #include "mlir/CAPI/IR.h" |
| 13 | #include "mlir/CAPI/Rewrite.h" |
| 14 | #include "mlir/CAPI/Support.h" |
| 15 | #include "mlir/CAPI/Wrap.h" |
| 16 | #include "mlir/IR/PatternMatch.h" |
| 17 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | /// RewriterBase API inherited from OpBuilder |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | |
| 26 | MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { |
| 27 | return wrap(cpp: unwrap(c: rewriter)->getContext()); |
| 28 | } |
| 29 | |
| 30 | //===----------------------------------------------------------------------===// |
/// Insertion point methods
| 32 | //===----------------------------------------------------------------------===// |
| 33 | |
| 34 | void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { |
| 35 | unwrap(c: rewriter)->clearInsertionPoint(); |
| 36 | } |
| 37 | |
| 38 | void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, |
| 39 | MlirOperation op) { |
| 40 | unwrap(c: rewriter)->setInsertionPoint(unwrap(c: op)); |
| 41 | } |
| 42 | |
| 43 | void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, |
| 44 | MlirOperation op) { |
| 45 | unwrap(c: rewriter)->setInsertionPointAfter(unwrap(c: op)); |
| 46 | } |
| 47 | |
| 48 | void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, |
| 49 | MlirValue value) { |
| 50 | unwrap(c: rewriter)->setInsertionPointAfterValue(unwrap(c: value)); |
| 51 | } |
| 52 | |
| 53 | void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, |
| 54 | MlirBlock block) { |
| 55 | unwrap(c: rewriter)->setInsertionPointToStart(unwrap(c: block)); |
| 56 | } |
| 57 | |
| 58 | void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, |
| 59 | MlirBlock block) { |
| 60 | unwrap(c: rewriter)->setInsertionPointToEnd(unwrap(c: block)); |
| 61 | } |
| 62 | |
| 63 | MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { |
| 64 | return wrap(cpp: unwrap(c: rewriter)->getInsertionBlock()); |
| 65 | } |
| 66 | |
| 67 | MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { |
| 68 | return wrap(cpp: unwrap(c: rewriter)->getBlock()); |
| 69 | } |
| 70 | |
| 71 | //===----------------------------------------------------------------------===// |
| 72 | /// Block and operation creation/insertion/cloning |
| 73 | //===----------------------------------------------------------------------===// |
| 74 | |
| 75 | MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, |
| 76 | MlirBlock insertBefore, |
| 77 | intptr_t nArgTypes, |
| 78 | MlirType const *argTypes, |
| 79 | MlirLocation const *locations) { |
| 80 | SmallVector<Type, 4> args; |
| 81 | ArrayRef<Type> unwrappedArgs = unwrapList(size: nArgTypes, first: argTypes, storage&: args); |
| 82 | SmallVector<Location, 4> locs; |
| 83 | ArrayRef<Location> unwrappedLocs = unwrapList(size: nArgTypes, first: locations, storage&: locs); |
| 84 | return wrap(cpp: unwrap(c: rewriter)->createBlock(insertBefore: unwrap(c: insertBefore), argTypes: unwrappedArgs, |
| 85 | locs: unwrappedLocs)); |
| 86 | } |
| 87 | |
| 88 | MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, |
| 89 | MlirOperation op) { |
| 90 | return wrap(cpp: unwrap(c: rewriter)->insert(op: unwrap(c: op))); |
| 91 | } |
| 92 | |
| 93 | // Other methods of OpBuilder |
| 94 | |
| 95 | MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, |
| 96 | MlirOperation op) { |
| 97 | return wrap(cpp: unwrap(c: rewriter)->clone(op&: *unwrap(c: op))); |
| 98 | } |
| 99 | |
| 100 | MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, |
| 101 | MlirOperation op) { |
| 102 | return wrap(cpp: unwrap(c: rewriter)->cloneWithoutRegions(op&: *unwrap(c: op))); |
| 103 | } |
| 104 | |
| 105 | void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, |
| 106 | MlirRegion region, MlirBlock before) { |
| 107 | |
| 108 | unwrap(c: rewriter)->cloneRegionBefore(region&: *unwrap(c: region), before: unwrap(c: before)); |
| 109 | } |
| 110 | |
| 111 | //===----------------------------------------------------------------------===// |
| 112 | /// RewriterBase API |
| 113 | //===----------------------------------------------------------------------===// |
| 114 | |
| 115 | void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, |
| 116 | MlirRegion region, MlirBlock before) { |
| 117 | unwrap(c: rewriter)->inlineRegionBefore(region&: *unwrap(c: region), before: unwrap(c: before)); |
| 118 | } |
| 119 | |
| 120 | void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, |
| 121 | MlirOperation op, intptr_t nValues, |
| 122 | MlirValue const *values) { |
| 123 | SmallVector<Value, 4> vals; |
| 124 | ArrayRef<Value> unwrappedVals = unwrapList(size: nValues, first: values, storage&: vals); |
| 125 | unwrap(c: rewriter)->replaceOp(op: unwrap(c: op), newValues: unwrappedVals); |
| 126 | } |
| 127 | |
| 128 | void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, |
| 129 | MlirOperation op, |
| 130 | MlirOperation newOp) { |
| 131 | unwrap(c: rewriter)->replaceOp(op: unwrap(c: op), newOp: unwrap(c: newOp)); |
| 132 | } |
| 133 | |
| 134 | void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { |
| 135 | unwrap(c: rewriter)->eraseOp(op: unwrap(c: op)); |
| 136 | } |
| 137 | |
| 138 | void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { |
| 139 | unwrap(c: rewriter)->eraseBlock(block: unwrap(c: block)); |
| 140 | } |
| 141 | |
| 142 | void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, |
| 143 | MlirBlock source, MlirOperation op, |
| 144 | intptr_t nArgValues, |
| 145 | MlirValue const *argValues) { |
| 146 | SmallVector<Value, 4> vals; |
| 147 | ArrayRef<Value> unwrappedVals = unwrapList(size: nArgValues, first: argValues, storage&: vals); |
| 148 | |
| 149 | unwrap(c: rewriter)->inlineBlockBefore(source: unwrap(c: source), op: unwrap(c: op), |
| 150 | argValues: unwrappedVals); |
| 151 | } |
| 152 | |
| 153 | void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, |
| 154 | MlirBlock dest, intptr_t nArgValues, |
| 155 | MlirValue const *argValues) { |
| 156 | SmallVector<Value, 4> args; |
| 157 | ArrayRef<Value> unwrappedArgs = unwrapList(size: nArgValues, first: argValues, storage&: args); |
| 158 | unwrap(c: rewriter)->mergeBlocks(source: unwrap(c: source), dest: unwrap(c: dest), argValues: unwrappedArgs); |
| 159 | } |
| 160 | |
| 161 | void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, |
| 162 | MlirOperation existingOp) { |
| 163 | unwrap(c: rewriter)->moveOpBefore(op: unwrap(c: op), existingOp: unwrap(c: existingOp)); |
| 164 | } |
| 165 | |
| 166 | void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, |
| 167 | MlirOperation existingOp) { |
| 168 | unwrap(c: rewriter)->moveOpAfter(op: unwrap(c: op), existingOp: unwrap(c: existingOp)); |
| 169 | } |
| 170 | |
| 171 | void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, |
| 172 | MlirBlock existingBlock) { |
| 173 | unwrap(c: rewriter)->moveBlockBefore(block: unwrap(c: block), anotherBlock: unwrap(c: existingBlock)); |
| 174 | } |
| 175 | |
| 176 | void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, |
| 177 | MlirOperation op) { |
| 178 | unwrap(c: rewriter)->startOpModification(op: unwrap(c: op)); |
| 179 | } |
| 180 | |
| 181 | void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, |
| 182 | MlirOperation op) { |
| 183 | unwrap(c: rewriter)->finalizeOpModification(op: unwrap(c: op)); |
| 184 | } |
| 185 | |
| 186 | void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, |
| 187 | MlirOperation op) { |
| 188 | unwrap(c: rewriter)->cancelOpModification(op: unwrap(c: op)); |
| 189 | } |
| 190 | |
| 191 | void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, |
| 192 | MlirValue from, MlirValue to) { |
| 193 | unwrap(c: rewriter)->replaceAllUsesWith(from: unwrap(c: from), to: unwrap(c: to)); |
| 194 | } |
| 195 | |
| 196 | void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, |
| 197 | intptr_t nValues, |
| 198 | MlirValue const *from, |
| 199 | MlirValue const *to) { |
| 200 | SmallVector<Value, 4> fromVals; |
| 201 | ArrayRef<Value> unwrappedFromVals = unwrapList(size: nValues, first: from, storage&: fromVals); |
| 202 | SmallVector<Value, 4> toVals; |
| 203 | ArrayRef<Value> unwrappedToVals = unwrapList(size: nValues, first: to, storage&: toVals); |
| 204 | unwrap(c: rewriter)->replaceAllUsesWith(from: unwrappedFromVals, to: unwrappedToVals); |
| 205 | } |
| 206 | |
| 207 | void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, |
| 208 | MlirOperation from, |
| 209 | intptr_t nTo, |
| 210 | MlirValue const *to) { |
| 211 | SmallVector<Value, 4> toVals; |
| 212 | ArrayRef<Value> unwrappedToVals = unwrapList(size: nTo, first: to, storage&: toVals); |
| 213 | unwrap(c: rewriter)->replaceAllOpUsesWith(from: unwrap(c: from), to: unwrappedToVals); |
| 214 | } |
| 215 | |
| 216 | void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, |
| 217 | MlirOperation from, |
| 218 | MlirOperation to) { |
| 219 | unwrap(c: rewriter)->replaceAllOpUsesWith(from: unwrap(c: from), to: unwrap(c: to)); |
| 220 | } |
| 221 | |
| 222 | void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, |
| 223 | MlirOperation op, |
| 224 | intptr_t nNewValues, |
| 225 | MlirValue const *newValues, |
| 226 | MlirBlock block) { |
| 227 | SmallVector<Value, 4> vals; |
| 228 | ArrayRef<Value> unwrappedVals = unwrapList(size: nNewValues, first: newValues, storage&: vals); |
| 229 | unwrap(c: rewriter)->replaceOpUsesWithinBlock(op: unwrap(c: op), newValues: unwrappedVals, |
| 230 | block: unwrap(c: block)); |
| 231 | } |
| 232 | |
| 233 | void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, |
| 234 | MlirValue from, MlirValue to, |
| 235 | MlirOperation exceptedUser) { |
| 236 | unwrap(c: rewriter)->replaceAllUsesExcept(from: unwrap(c: from), to: unwrap(c: to), |
| 237 | exceptedUser: unwrap(c: exceptedUser)); |
| 238 | } |
| 239 | |
| 240 | //===----------------------------------------------------------------------===// |
| 241 | /// IRRewriter API |
| 242 | //===----------------------------------------------------------------------===// |
| 243 | |
| 244 | MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { |
| 245 | return wrap(cpp: new IRRewriter(unwrap(c: context))); |
| 246 | } |
| 247 | |
| 248 | MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { |
| 249 | return wrap(cpp: new IRRewriter(unwrap(c: op))); |
| 250 | } |
| 251 | |
| 252 | void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { |
| 253 | delete static_cast<IRRewriter *>(unwrap(c: rewriter)); |
| 254 | } |
| 255 | |
| 256 | //===----------------------------------------------------------------------===// |
| 257 | /// RewritePatternSet and FrozenRewritePatternSet API |
| 258 | //===----------------------------------------------------------------------===// |
| 259 | |
| 260 | inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { |
| 261 | assert(module.ptr && "unexpected null module" ); |
| 262 | return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); |
| 263 | } |
| 264 | |
| 265 | inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { |
| 266 | return {.ptr: module}; |
| 267 | } |
| 268 | |
| 269 | inline mlir::FrozenRewritePatternSet * |
| 270 | unwrap(MlirFrozenRewritePatternSet module) { |
| 271 | assert(module.ptr && "unexpected null module" ); |
| 272 | return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); |
| 273 | } |
| 274 | |
| 275 | inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { |
| 276 | return {.ptr: module}; |
| 277 | } |
| 278 | |
| 279 | MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { |
| 280 | auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(module: op))); |
| 281 | op.ptr = nullptr; |
| 282 | return wrap(module: m); |
| 283 | } |
| 284 | |
| 285 | void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { |
| 286 | delete unwrap(module: op); |
| 287 | op.ptr = nullptr; |
| 288 | } |
| 289 | |
| 290 | MlirLogicalResult |
| 291 | mlirApplyPatternsAndFoldGreedily(MlirModule op, |
| 292 | MlirFrozenRewritePatternSet patterns, |
| 293 | MlirGreedyRewriteDriverConfig) { |
| 294 | return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(module: patterns))); |
| 295 | } |
| 296 | |
| 297 | //===----------------------------------------------------------------------===// |
| 298 | /// PDLPatternModule API |
| 299 | //===----------------------------------------------------------------------===// |
| 300 | |
| 301 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| 302 | inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { |
| 303 | assert(module.ptr && "unexpected null module" ); |
| 304 | return static_cast<mlir::PDLPatternModule *>(module.ptr); |
| 305 | } |
| 306 | |
| 307 | inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { |
| 308 | return {.ptr: module}; |
| 309 | } |
| 310 | |
| 311 | MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { |
| 312 | return wrap(module: new mlir::PDLPatternModule( |
| 313 | mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); |
| 314 | } |
| 315 | |
| 316 | void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { |
| 317 | delete unwrap(module: op); |
| 318 | op.ptr = nullptr; |
| 319 | } |
| 320 | |
| 321 | MlirRewritePatternSet |
| 322 | mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { |
| 323 | auto *m = new mlir::RewritePatternSet(std::move(*unwrap(module: op))); |
| 324 | op.ptr = nullptr; |
| 325 | return wrap(module: m); |
| 326 | } |
| 327 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| 328 | |