| 1 | //===- rewrite.c - Test of the rewriting C API ----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
| 4 | // Exceptions. |
| 5 | // See https://llvm.org/LICENSE.txt for license information. |
| 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | // |
| 8 | //===----------------------------------------------------------------------===// |
| 9 | |
| 10 | // RUN: mlir-capi-rewrite-test 2>&1 | FileCheck %s |
| 11 | |
| 12 | #include "mlir-c/Rewrite.h" |
| 13 | #include "mlir-c/BuiltinTypes.h" |
| 14 | #include "mlir-c/IR.h" |
| 15 | |
| 16 | #include <assert.h> |
| 17 | #include <stdio.h> |
| 18 | |
| 19 | MlirOperation createOperationWithName(MlirContext ctx, const char *name) { |
| 20 | MlirStringRef nameRef = mlirStringRefCreateFromCString(str: name); |
| 21 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
| 22 | MlirOperationState state = mlirOperationStateGet(name: nameRef, loc); |
| 23 | MlirType indexType = mlirIndexTypeGet(ctx); |
| 24 | mlirOperationStateAddResults(state: &state, n: 1, results: &indexType); |
| 25 | return mlirOperationCreate(state: &state); |
| 26 | } |
| 27 | |
// Exercises each insertion-point setter of an IRRewriter. Fresh single-result
// ops are inserted around the one pre-existing op ("dialect.op1"). The test
// then verifies the resulting op order in the module body, and that both
// insertion-block getters report the module body.
void testInsertionPoint(MlirContext ctx) {
  // CHECK-LABEL: @testInsertionPoint
  fprintf(stderr, "@testInsertionPoint\n");

  const char *moduleString = "\"dialect.op1\"() : () -> ()\n";
  MlirModule module =
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
  MlirOperation op = mlirModuleGetOperation(module);
  MlirBlock body = mlirModuleGetBody(module);
  MlirOperation op1 = mlirBlockGetFirstOperation(body);

  // IRRewriter create
  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);

  // Insert before op: op2 lands immediately before op1.
  mlirRewriterBaseSetInsertionPointBefore(rewriter, op1);
  MlirOperation op2 = createOperationWithName(ctx, "dialect.op2");
  mlirRewriterBaseInsert(rewriter, op2);

  // Insert after op: op3 lands immediately after op2.
  mlirRewriterBaseSetInsertionPointAfter(rewriter, op2);
  MlirOperation op3 = createOperationWithName(ctx, "dialect.op3");
  mlirRewriterBaseInsert(rewriter, op3);
  MlirValue op3Res = mlirOperationGetResult(op3, 0);

  // Insert after value: op3Res is a result of op3, so the expected output
  // below places op4 directly after op3.
  mlirRewriterBaseSetInsertionPointAfterValue(rewriter, op3Res);
  MlirOperation op4 = createOperationWithName(ctx, "dialect.op4");
  mlirRewriterBaseInsert(rewriter, op4);

  // Insert at beginning of block
  mlirRewriterBaseSetInsertionPointToStart(rewriter, body);
  MlirOperation op5 = createOperationWithName(ctx, "dialect.op5");
  mlirRewriterBaseInsert(rewriter, op5);

  // Insert at end of block
  mlirRewriterBaseSetInsertionPointToEnd(rewriter, body);
  MlirOperation op6 = createOperationWithName(ctx, "dialect.op6");
  mlirRewriterBaseInsert(rewriter, op6);

  // Get insertion blocks. Both getters should point at the module body, which
  // is where the last insertion point was set.
  MlirBlock block1 = mlirRewriterBaseGetBlock(rewriter);
  MlirBlock block2 = mlirRewriterBaseGetInsertionBlock(rewriter);
  // The (void) casts presumably keep NDEBUG builds, where the asserts vanish,
  // free of unused-variable warnings.
  (void)block1;
  (void)block2;
  assert(body.ptr == block1.ptr);
  assert(body.ptr == block2.ptr);

  // clang-format off
  // CHECK-NEXT: module {
  // CHECK-NEXT: %{{.*}} = "dialect.op5"() : () -> index
  // CHECK-NEXT: %{{.*}} = "dialect.op2"() : () -> index
  // CHECK-NEXT: %{{.*}} = "dialect.op3"() : () -> index
  // CHECK-NEXT: %{{.*}} = "dialect.op4"() : () -> index
  // CHECK-NEXT: "dialect.op1"() : () -> ()
  // CHECK-NEXT: %{{.*}} = "dialect.op6"() : () -> index
  // CHECK-NEXT: }
  // clang-format on
  mlirOperationDump(op);

  mlirIRRewriterDestroy(rewriter);
  mlirModuleDestroy(module);
}
| 91 | |
// Exercises block creation and the three cloning entry points of the rewriter:
//  - a block with one `index` argument is created in front of op1's only block;
//  - op1 is cloned with and without its regions at the end of the module body;
//  - op1's (now two-block) region is cloned into op2's region, ahead of op2's
//    original block.
void testCreateBlock(MlirContext ctx) {
  // CHECK-LABEL: @testCreateBlock
  fprintf(stderr, "@testCreateBlock\n");

  const char *moduleString = "\"dialect.op1\"() ({^bb0:}) : () -> ()\n"
                             "\"dialect.op2\"() ({^bb0:}) : () -> ()\n";
  MlirModule module =
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
  MlirOperation op = mlirModuleGetOperation(module);
  MlirBlock body = mlirModuleGetBody(module);

  // Handles to op1 / its region / its single (empty) block.
  MlirOperation op1 = mlirBlockGetFirstOperation(body);
  MlirRegion region1 = mlirOperationGetRegion(op1, 0);
  MlirBlock block1 = mlirRegionGetFirstBlock(region1);

  // Handles to op2 / its region / its single (empty) block.
  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
  MlirRegion region2 = mlirOperationGetRegion(op2, 0);
  MlirBlock block2 = mlirRegionGetFirstBlock(region2);

  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);

  // Create block before: adds a block taking one `index` argument (at an
  // unknown location) in front of block1 inside op1's region.
  MlirType indexType = mlirIndexTypeGet(ctx);
  MlirLocation unknown = mlirLocationUnknownGet(ctx);
  mlirRewriterBaseCreateBlockBefore(rewriter, block1, 1, &indexType, &unknown);

  // Explicitly point the rewriter at the end of the module body so the clones
  // below are appended there. NOTE(review): presumably required because block
  // creation moves the insertion point into the new block; confirm.
  mlirRewriterBaseSetInsertionPointToEnd(rewriter, body);

  // Clone operation (including its regions).
  mlirRewriterBaseClone(rewriter, op1);

  // Clone without regions: the copy's region is left empty.
  mlirRewriterBaseCloneWithoutRegions(rewriter, op1);

  // Clone region before: copies region1's blocks into op2's region ahead of
  // block2.
  mlirRewriterBaseCloneRegionBefore(rewriter, region1, block2);

  mlirOperationDump(op);
  // clang-format off
  // CHECK-NEXT: "builtin.module"() ({
  // CHECK-NEXT: "dialect.op1"() ({
  // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
  // CHECK-NEXT: ^{{.*}}:
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op2"() ({
  // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
  // CHECK-NEXT: ^{{.*}}:
  // CHECK-NEXT: ^{{.*}}:
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op1"() ({
  // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
  // CHECK-NEXT: ^{{.*}}:
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op1"() ({
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: }) : () -> ()
  // clang-format on

  mlirIRRewriterDestroy(rewriter);
  mlirModuleDestroy(module);
}
| 153 | |
// Exercises three block/region splicing entry points of the rewriter, each on
// its own op of the parsed module:
//  - op1/op2: inline all of op1's region into op2's region before op2's block;
//  - op3:     inline op3's first block before op3_in3 (in the second block),
//             feeding op3_in2's result in for the block argument;
//  - op4:     merge op4's second block into its first, feeding op4_in1's
//             result in for the second block's argument.
void testInlineRegionBlock(MlirContext ctx) {
  // CHECK-LABEL: @testInlineRegionBlock
  fprintf(stderr, "@testInlineRegionBlock\n");

  const char *moduleString =
      "\"dialect.op1\"() ({\n"
      "  ^bb0(%arg0: index):\n"
      "    \"dialect.op1_in1\"(%arg0) [^bb1] : (index) -> ()\n"
      "  ^bb1():\n"
      "    \"dialect.op1_in2\"() : () -> ()\n"
      "}) : () -> ()\n"
      "\"dialect.op2\"() ({^bb0:}) : () -> ()\n"
      "\"dialect.op3\"() ({\n"
      "  ^bb0(%arg0: index):\n"
      "    \"dialect.op3_in1\"(%arg0) : (index) -> ()\n"
      "  ^bb1():\n"
      "    %x = \"dialect.op3_in2\"() : () -> index\n"
      "    %y = \"dialect.op3_in3\"() : () -> index\n"
      "}) : () -> ()\n"
      "\"dialect.op4\"() ({\n"
      "  ^bb0():\n"
      "    \"dialect.op4_in1\"() : () -> index\n"
      "  ^bb1(%arg0: index):\n"
      "    \"dialect.op4_in2\"(%arg0) : (index) -> ()\n"
      "}) : () -> ()\n";
  MlirModule module =
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
  MlirOperation op = mlirModuleGetOperation(module);
  MlirBlock body = mlirModuleGetBody(module);

  // op1: its whole region is the source of the region-inlining step.
  MlirOperation op1 = mlirBlockGetFirstOperation(body);
  MlirRegion region1 = mlirOperationGetRegion(op1, 0);

  // op2: its single empty block is the inlining destination anchor.
  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
  MlirRegion region2 = mlirOperationGetRegion(op2, 0);
  MlirBlock block2 = mlirRegionGetFirstBlock(region2);

  // op3: first block (block3_1) gets inlined before op3_in3, which lives in the
  // second block (block3_2); op3_in2's result replaces block3_1's argument.
  MlirOperation op3 = mlirOperationGetNextInBlock(op2);
  MlirRegion region3 = mlirOperationGetRegion(op3, 0);
  MlirBlock block3_1 = mlirRegionGetFirstBlock(region3);
  MlirBlock block3_2 = mlirBlockGetNextInRegion(block3_1);
  MlirOperation op3_in2 = mlirBlockGetFirstOperation(block3_2);
  MlirValue op3_in2_res = mlirOperationGetResult(op3_in2, 0);
  MlirOperation op3_in3 = mlirOperationGetNextInBlock(op3_in2);

  // op4: second block (block4_2) is merged into the first (block4_1);
  // op4_in1's result replaces block4_2's argument.
  MlirOperation op4 = mlirOperationGetNextInBlock(op3);
  MlirRegion region4 = mlirOperationGetRegion(op4, 0);
  MlirBlock block4_1 = mlirRegionGetFirstBlock(region4);
  MlirOperation op4_in1 = mlirBlockGetFirstOperation(block4_1);
  MlirValue op4_in1_res = mlirOperationGetResult(op4_in1, 0);
  MlirBlock block4_2 = mlirBlockGetNextInRegion(block4_1);

  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);

  // Test these three functions
  mlirRewriterBaseInlineRegionBefore(rewriter, region1, block2);
  mlirRewriterBaseInlineBlockBefore(rewriter, block3_1, op3_in3, 1,
                                    &op3_in2_res);
  mlirRewriterBaseMergeBlocks(rewriter, block4_2, block4_1, 1, &op4_in1_res);

  mlirOperationDump(op);
  // clang-format off
  // CHECK-NEXT: "builtin.module"() ({
  // CHECK-NEXT: "dialect.op1"() ({
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op2"() ({
  // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
  // CHECK-NEXT: "dialect.op1_in1"(%{{.*}})[^[[bb:.*]]] : (index) -> ()
  // CHECK-NEXT: ^[[bb]]:
  // CHECK-NEXT: "dialect.op1_in2"() : () -> ()
  // CHECK-NEXT: ^{{.*}}: // no predecessors
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op3"() ({
  // CHECK-NEXT: %{{.*}} = "dialect.op3_in2"() : () -> index
  // CHECK-NEXT: "dialect.op3_in1"(%{{.*}}) : (index) -> ()
  // CHECK-NEXT: %{{.*}} = "dialect.op3_in3"() : () -> index
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op4"() ({
  // CHECK-NEXT: %{{.*}} = "dialect.op4_in1"() : () -> index
  // CHECK-NEXT: "dialect.op4_in2"(%{{.*}}) : (index) -> ()
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: }) : () -> ()
  // clang-format on

  mlirIRRewriterDestroy(rewriter);
  mlirModuleDestroy(module);
}
| 241 | |
| 242 | void testReplaceOp(MlirContext ctx) { |
| 243 | // CHECK-LABEL: @testReplaceOp |
| 244 | fprintf(stderr, format: "@testReplaceOp\n" ); |
| 245 | |
| 246 | const char *moduleString = |
| 247 | "%x, %y, %z = \"dialect.create_values\"() : () -> (index, index, index)\n" |
| 248 | "%x_1, %y_1 = \"dialect.op1\"() : () -> (index, index)\n" |
| 249 | "\"dialect.use_op1\"(%x_1, %y_1) : (index, index) -> ()\n" |
| 250 | "%x_2, %y_2 = \"dialect.op2\"() : () -> (index, index)\n" |
| 251 | "%x_3, %y_3 = \"dialect.op3\"() : () -> (index, index)\n" |
| 252 | "\"dialect.use_op2\"(%x_2, %y_2) : (index, index) -> ()\n" ; |
| 253 | MlirModule module = |
| 254 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
| 255 | MlirOperation op = mlirModuleGetOperation(module); |
| 256 | MlirBlock body = mlirModuleGetBody(module); |
| 257 | |
| 258 | // get a handle to all operations/values |
| 259 | MlirOperation createValues = mlirBlockGetFirstOperation(block: body); |
| 260 | MlirValue x = mlirOperationGetResult(op: createValues, pos: 0); |
| 261 | MlirValue z = mlirOperationGetResult(op: createValues, pos: 2); |
| 262 | MlirOperation op1 = mlirOperationGetNextInBlock(op: createValues); |
| 263 | MlirOperation useOp1 = mlirOperationGetNextInBlock(op: op1); |
| 264 | MlirOperation op2 = mlirOperationGetNextInBlock(op: useOp1); |
| 265 | MlirOperation op3 = mlirOperationGetNextInBlock(op: op2); |
| 266 | |
| 267 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
| 268 | |
| 269 | // Test replace op with values |
| 270 | MlirValue xz[2] = {x, z}; |
| 271 | mlirRewriterBaseReplaceOpWithValues(rewriter, op: op1, nValues: 2, values: xz); |
| 272 | |
| 273 | // Test replace op with op |
| 274 | mlirRewriterBaseReplaceOpWithOperation(rewriter, op: op2, newOp: op3); |
| 275 | |
| 276 | mlirOperationDump(op); |
| 277 | // clang-format off |
| 278 | // CHECK-NEXT: module { |
| 279 | // CHECK-NEXT: %[[res:.*]]:3 = "dialect.create_values"() : () -> (index, index, index) |
| 280 | // CHECK-NEXT: "dialect.use_op1"(%[[res]]#0, %[[res]]#2) : (index, index) -> () |
| 281 | // CHECK-NEXT: %[[res2:.*]]:2 = "dialect.op3"() : () -> (index, index) |
| 282 | // CHECK-NEXT: "dialect.use_op2"(%[[res2]]#0, %[[res2]]#1) : (index, index) -> () |
| 283 | // CHECK-NEXT: } |
| 284 | // clang-format on |
| 285 | |
| 286 | mlirIRRewriterDestroy(rewriter); |
| 287 | mlirModuleDestroy(module); |
| 288 | } |
| 289 | |
| 290 | void testErase(MlirContext ctx) { |
| 291 | // CHECK-LABEL: @testErase |
| 292 | fprintf(stderr, format: "@testErase\n" ); |
| 293 | |
| 294 | const char *moduleString = "\"dialect.op_to_erase\"() : () -> ()\n" |
| 295 | "\"dialect.op2\"() ({\n" |
| 296 | "^bb0():\n" |
| 297 | " \"dialect.op2_nested\"() : () -> ()" |
| 298 | "^block_to_erase():\n" |
| 299 | " \"dialect.op2_nested\"() : () -> ()" |
| 300 | "^bb1():\n" |
| 301 | " \"dialect.op2_nested\"() : () -> ()" |
| 302 | "}) : () -> ()\n" ; |
| 303 | MlirModule module = |
| 304 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
| 305 | MlirOperation op = mlirModuleGetOperation(module); |
| 306 | MlirBlock body = mlirModuleGetBody(module); |
| 307 | |
| 308 | // get a handle to all operations/values |
| 309 | MlirOperation opToErase = mlirBlockGetFirstOperation(block: body); |
| 310 | MlirOperation op2 = mlirOperationGetNextInBlock(op: opToErase); |
| 311 | MlirRegion op2Region = mlirOperationGetRegion(op: op2, pos: 0); |
| 312 | MlirBlock bb0 = mlirRegionGetFirstBlock(region: op2Region); |
| 313 | MlirBlock blockToErase = mlirBlockGetNextInRegion(block: bb0); |
| 314 | |
| 315 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
| 316 | mlirRewriterBaseEraseOp(rewriter, op: opToErase); |
| 317 | mlirRewriterBaseEraseBlock(rewriter, block: blockToErase); |
| 318 | |
| 319 | mlirOperationDump(op); |
| 320 | // CHECK-NEXT: module { |
| 321 | // CHECK-NEXT: "dialect.op2"() ({ |
| 322 | // CHECK-NEXT: "dialect.op2_nested"() : () -> () |
| 323 | // CHECK-NEXT: ^{{.*}}: |
| 324 | // CHECK-NEXT: "dialect.op2_nested"() : () -> () |
| 325 | // CHECK-NEXT: }) : () -> () |
| 326 | // CHECK-NEXT: } |
| 327 | |
| 328 | mlirIRRewriterDestroy(rewriter); |
| 329 | mlirModuleDestroy(module); |
| 330 | } |
| 331 | |
| 332 | void testMove(MlirContext ctx) { |
| 333 | // CHECK-LABEL: @testMove |
| 334 | fprintf(stderr, format: "@testMove\n" ); |
| 335 | |
| 336 | const char *moduleString = "\"dialect.op1\"() : () -> ()\n" |
| 337 | "\"dialect.op2\"() ({\n" |
| 338 | "^bb0(%arg0: index):\n" |
| 339 | " \"dialect.op2_1\"(%arg0) : (index) -> ()" |
| 340 | "^bb1(%arg1: index):\n" |
| 341 | " \"dialect.op2_2\"(%arg1) : (index) -> ()" |
| 342 | "}) : () -> ()\n" |
| 343 | "\"dialect.op3\"() : () -> ()\n" |
| 344 | "\"dialect.op4\"() : () -> ()\n" ; |
| 345 | |
| 346 | MlirModule module = |
| 347 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
| 348 | MlirOperation op = mlirModuleGetOperation(module); |
| 349 | MlirBlock body = mlirModuleGetBody(module); |
| 350 | |
| 351 | // get a handle to all operations/values |
| 352 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
| 353 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
| 354 | MlirOperation op3 = mlirOperationGetNextInBlock(op: op2); |
| 355 | MlirOperation op4 = mlirOperationGetNextInBlock(op: op3); |
| 356 | |
| 357 | MlirRegion region2 = mlirOperationGetRegion(op: op2, pos: 0); |
| 358 | MlirBlock block0 = mlirRegionGetFirstBlock(region: region2); |
| 359 | MlirBlock block1 = mlirBlockGetNextInRegion(block: block0); |
| 360 | |
| 361 | // Test move operations. |
| 362 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
| 363 | mlirRewriterBaseMoveOpBefore(rewriter, op: op3, existingOp: op1); |
| 364 | mlirRewriterBaseMoveOpAfter(rewriter, op: op4, existingOp: op1); |
| 365 | mlirRewriterBaseMoveBlockBefore(rewriter, block: block1, existingBlock: block0); |
| 366 | |
| 367 | mlirOperationDump(op); |
| 368 | // CHECK-NEXT: module { |
| 369 | // CHECK-NEXT: "dialect.op3"() : () -> () |
| 370 | // CHECK-NEXT: "dialect.op1"() : () -> () |
| 371 | // CHECK-NEXT: "dialect.op4"() : () -> () |
| 372 | // CHECK-NEXT: "dialect.op2"() ({ |
| 373 | // CHECK-NEXT: ^{{.*}}(%[[arg0:.*]]: index): |
| 374 | // CHECK-NEXT: "dialect.op2_2"(%[[arg0]]) : (index) -> () |
| 375 | // CHECK-NEXT: ^{{.*}}(%[[arg1:.*]]: index): // no predecessors |
| 376 | // CHECK-NEXT: "dialect.op2_1"(%[[arg1]]) : (index) -> () |
| 377 | // CHECK-NEXT: }) : () -> () |
| 378 | // CHECK-NEXT: } |
| 379 | |
| 380 | mlirIRRewriterDestroy(rewriter); |
| 381 | mlirModuleDestroy(module); |
| 382 | } |
| 383 | |
| 384 | void testOpModification(MlirContext ctx) { |
| 385 | // CHECK-LABEL: @testOpModification |
| 386 | fprintf(stderr, format: "@testOpModification\n" ); |
| 387 | |
| 388 | const char *moduleString = |
| 389 | "%x, %y = \"dialect.op1\"() : () -> (index, index)\n" |
| 390 | "\"dialect.op2\"(%x) : (index) -> ()\n" ; |
| 391 | |
| 392 | MlirModule module = |
| 393 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
| 394 | MlirOperation op = mlirModuleGetOperation(module); |
| 395 | MlirBlock body = mlirModuleGetBody(module); |
| 396 | |
| 397 | // get a handle to all operations/values |
| 398 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
| 399 | MlirValue y = mlirOperationGetResult(op: op1, pos: 1); |
| 400 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
| 401 | |
| 402 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
| 403 | mlirRewriterBaseStartOpModification(rewriter, op: op1); |
| 404 | mlirRewriterBaseCancelOpModification(rewriter, op: op1); |
| 405 | |
| 406 | mlirRewriterBaseStartOpModification(rewriter, op: op2); |
| 407 | mlirOperationSetOperand(op: op2, pos: 0, newValue: y); |
| 408 | mlirRewriterBaseFinalizeOpModification(rewriter, op: op2); |
| 409 | |
| 410 | mlirOperationDump(op); |
| 411 | // CHECK-NEXT: module { |
| 412 | // CHECK-NEXT: %[[xy:.*]]:2 = "dialect.op1"() : () -> (index, index) |
| 413 | // CHECK-NEXT: "dialect.op2"(%[[xy]]#1) : (index) -> () |
| 414 | // CHECK-NEXT: } |
| 415 | |
| 416 | mlirIRRewriterDestroy(rewriter); |
| 417 | mlirModuleDestroy(module); |
| 418 | } |
| 419 | |
// Exercises the use-replacement entry points of the rewriter. Each scenario has
// its own group of ops in the parsed module:
//  - op1 -> op2:   replace single value uses, then a range of value uses;
//  - op3 -> %x4:   replace an op with a value;
//  - op5 -> op6:   replace an op with another op;
//  - op7 -> %x8:   replace op7's uses only inside op9's region (block9);
//  - %x10 -> %x11: replace all uses except those in the first op10_uses.
void testReplaceUses(MlirContext ctx) {
  // CHECK-LABEL: @testReplaceUses
  fprintf(stderr, "@testReplaceUses\n");

  const char *moduleString =
      // Replace values with values
      "%x1, %y1, %z1 = \"dialect.op1\"() : () -> (index, index, index)\n"
      "%x2, %y2, %z2 = \"dialect.op2\"() : () -> (index, index, index)\n"
      "\"dialect.op1_uses\"(%x1, %y1, %z1) : (index, index, index) -> ()\n"
      // Replace op with values
      "%x3 = \"dialect.op3\"() : () -> index\n"
      "%x4 = \"dialect.op4\"() : () -> index\n"
      "\"dialect.op3_uses\"(%x3) : (index) -> ()\n"
      // Replace op with op
      "%x5 = \"dialect.op5\"() : () -> index\n"
      "%x6 = \"dialect.op6\"() : () -> index\n"
      "\"dialect.op5_uses\"(%x5) : (index) -> ()\n"
      // Replace op in block;
      "%x7 = \"dialect.op7\"() : () -> index\n"
      "%x8 = \"dialect.op8\"() : () -> index\n"
      "\"dialect.op9\"() ({\n"
      "^bb0:\n"
      "  \"dialect.op7_uses\"(%x7) : (index) -> ()\n"
      "}): () -> ()\n"
      "\"dialect.op7_uses\"(%x7) : (index) -> ()\n"
      // Replace value with value except in op
      "%x10 = \"dialect.op10\"() : () -> index\n"
      "%x11 = \"dialect.op11\"() : () -> index\n"
      "\"dialect.op10_uses\"(%x10) : (index) -> ()\n"
      "\"dialect.op10_uses\"(%x10) : (index) -> ()\n";

  MlirModule module =
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
  MlirOperation op = mlirModuleGetOperation(module);
  MlirBlock body = mlirModuleGetBody(module);

  // get a handle to all operations/values. The walk below follows the textual
  // order of moduleString exactly; reordering the IR requires updating it.
  MlirOperation op1 = mlirBlockGetFirstOperation(body);
  MlirValue x1 = mlirOperationGetResult(op1, 0);
  MlirValue y1 = mlirOperationGetResult(op1, 1);
  MlirValue z1 = mlirOperationGetResult(op1, 2);
  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
  MlirValue x2 = mlirOperationGetResult(op2, 0);
  MlirValue y2 = mlirOperationGetResult(op2, 1);
  MlirValue z2 = mlirOperationGetResult(op2, 2);
  MlirOperation op1Uses = mlirOperationGetNextInBlock(op2);

  MlirOperation op3 = mlirOperationGetNextInBlock(op1Uses);
  MlirOperation op4 = mlirOperationGetNextInBlock(op3);
  MlirValue x4 = mlirOperationGetResult(op4, 0);
  MlirOperation op3Uses = mlirOperationGetNextInBlock(op4);

  MlirOperation op5 = mlirOperationGetNextInBlock(op3Uses);
  MlirOperation op6 = mlirOperationGetNextInBlock(op5);
  MlirOperation op5Uses = mlirOperationGetNextInBlock(op6);

  MlirOperation op7 = mlirOperationGetNextInBlock(op5Uses);
  MlirOperation op8 = mlirOperationGetNextInBlock(op7);
  MlirValue x8 = mlirOperationGetResult(op8, 0);
  MlirOperation op9 = mlirOperationGetNextInBlock(op8);
  MlirRegion region9 = mlirOperationGetRegion(op9, 0);
  MlirBlock block9 = mlirRegionGetFirstBlock(region9);
  // The op7_uses that sits outside op9's region; it is not rewired.
  MlirOperation op7Uses = mlirOperationGetNextInBlock(op9);

  MlirOperation op10 = mlirOperationGetNextInBlock(op7Uses);
  MlirValue x10 = mlirOperationGetResult(op10, 0);
  MlirOperation op11 = mlirOperationGetNextInBlock(op10);
  MlirValue x11 = mlirOperationGetResult(op11, 0);
  // First of the two op10_uses ops: the excepted user that keeps using %x10.
  MlirOperation op10Uses1 = mlirOperationGetNextInBlock(op11);

  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);

  // Replace values: %x1 individually, then %y1/%z1 as a pairwise range.
  mlirRewriterBaseReplaceAllUsesWith(rewriter, x1, x2);
  MlirValue y1z1[2] = {y1, z1};
  MlirValue y2z2[2] = {y2, z2};
  mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter, 2, y1z1, y2z2);

  // Replace op with values
  mlirRewriterBaseReplaceOpWithValues(rewriter, op3, 1, &x4);

  // Replace op with op
  mlirRewriterBaseReplaceOpWithOperation(rewriter, op5, op6);

  // Replace op with op in block: only uses of op7 inside block9 are rewired.
  mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter, op7, 1, &x8, block9);

  // Replace value with value except in op
  mlirRewriterBaseReplaceAllUsesExcept(rewriter, x10, x11, op10Uses1);

  mlirOperationDump(op);
  // clang-format off
  // CHECK-NEXT: module {
  // CHECK-NEXT: %{{.*}}:3 = "dialect.op1"() : () -> (index, index, index)
  // CHECK-NEXT: %[[res2:.*]]:3 = "dialect.op2"() : () -> (index, index, index)
  // CHECK-NEXT: "dialect.op1_uses"(%[[res2]]#0, %[[res2]]#1, %[[res2]]#2) : (index, index, index) -> ()
  // CHECK-NEXT: %[[res4:.*]] = "dialect.op4"() : () -> index
  // CHECK-NEXT: "dialect.op3_uses"(%[[res4]]) : (index) -> ()
  // CHECK-NEXT: %[[res6:.*]] = "dialect.op6"() : () -> index
  // CHECK-NEXT: "dialect.op5_uses"(%[[res6]]) : (index) -> ()
  // CHECK-NEXT: %[[res7:.*]] = "dialect.op7"() : () -> index
  // CHECK-NEXT: %[[res8:.*]] = "dialect.op8"() : () -> index
  // CHECK-NEXT: "dialect.op9"() ({
  // CHECK-NEXT: "dialect.op7_uses"(%[[res8]]) : (index) -> ()
  // CHECK-NEXT: }) : () -> ()
  // CHECK-NEXT: "dialect.op7_uses"(%[[res7]]) : (index) -> ()
  // CHECK-NEXT: %[[res10:.*]] = "dialect.op10"() : () -> index
  // CHECK-NEXT: %[[res11:.*]] = "dialect.op11"() : () -> index
  // CHECK-NEXT: "dialect.op10_uses"(%[[res10]]) : (index) -> ()
  // CHECK-NEXT: "dialect.op10_uses"(%[[res11]]) : (index) -> ()
  // CHECK-NEXT: }
  // clang-format on

  mlirIRRewriterDestroy(rewriter);
  mlirModuleDestroy(module);
}
| 536 | |
| 537 | int main(void) { |
| 538 | MlirContext ctx = mlirContextCreate(); |
| 539 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
| 540 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "builtin" )); |
| 541 | |
| 542 | testInsertionPoint(ctx); |
| 543 | testCreateBlock(ctx); |
| 544 | testInlineRegionBlock(ctx); |
| 545 | testReplaceOp(ctx); |
| 546 | testErase(ctx); |
| 547 | testMove(ctx); |
| 548 | testOpModification(ctx); |
| 549 | testReplaceUses(ctx); |
| 550 | |
| 551 | mlirContextDestroy(context: ctx); |
| 552 | return 0; |
| 553 | } |
| 554 | |