1 | //===- rewrite.c - Test of the rewriting C API ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | // RUN: mlir-capi-rewrite-test 2>&1 | FileCheck %s |
11 | |
12 | #include "mlir-c/Rewrite.h" |
13 | #include "mlir-c/BuiltinTypes.h" |
14 | #include "mlir-c/IR.h" |
15 | |
16 | #include <assert.h> |
17 | #include <stdio.h> |
18 | |
19 | MlirOperation createOperationWithName(MlirContext ctx, const char *name) { |
20 | MlirStringRef nameRef = mlirStringRefCreateFromCString(str: name); |
21 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
22 | MlirOperationState state = mlirOperationStateGet(name: nameRef, loc); |
23 | MlirType indexType = mlirIndexTypeGet(ctx); |
24 | mlirOperationStateAddResults(state: &state, n: 1, results: &indexType); |
25 | return mlirOperationCreate(state: &state); |
26 | } |
27 | |
28 | void testInsertionPoint(MlirContext ctx) { |
29 | // CHECK-LABEL: @testInsertionPoint |
30 | fprintf(stderr, format: "@testInsertionPoint\n" ); |
31 | |
32 | const char *moduleString = "\"dialect.op1\"() : () -> ()\n" ; |
33 | MlirModule module = |
34 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
35 | MlirOperation op = mlirModuleGetOperation(module); |
36 | MlirBlock body = mlirModuleGetBody(module); |
37 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
38 | |
39 | // IRRewriter create |
40 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
41 | |
42 | // Insert before op |
43 | mlirRewriterBaseSetInsertionPointBefore(rewriter, op: op1); |
44 | MlirOperation op2 = createOperationWithName(ctx, name: "dialect.op2" ); |
45 | mlirRewriterBaseInsert(rewriter, op: op2); |
46 | |
47 | // Insert after op |
48 | mlirRewriterBaseSetInsertionPointAfter(rewriter, op: op2); |
49 | MlirOperation op3 = createOperationWithName(ctx, name: "dialect.op3" ); |
50 | mlirRewriterBaseInsert(rewriter, op: op3); |
51 | MlirValue op3Res = mlirOperationGetResult(op: op3, pos: 0); |
52 | |
53 | // Insert after value |
54 | mlirRewriterBaseSetInsertionPointAfterValue(rewriter, value: op3Res); |
55 | MlirOperation op4 = createOperationWithName(ctx, name: "dialect.op4" ); |
56 | mlirRewriterBaseInsert(rewriter, op: op4); |
57 | |
58 | // Insert at beginning of block |
59 | mlirRewriterBaseSetInsertionPointToStart(rewriter, block: body); |
60 | MlirOperation op5 = createOperationWithName(ctx, name: "dialect.op5" ); |
61 | mlirRewriterBaseInsert(rewriter, op: op5); |
62 | |
63 | // Insert at end of block |
64 | mlirRewriterBaseSetInsertionPointToEnd(rewriter, block: body); |
65 | MlirOperation op6 = createOperationWithName(ctx, name: "dialect.op6" ); |
66 | mlirRewriterBaseInsert(rewriter, op: op6); |
67 | |
68 | // Get insertion blocks |
69 | MlirBlock block1 = mlirRewriterBaseGetBlock(rewriter); |
70 | MlirBlock block2 = mlirRewriterBaseGetInsertionBlock(rewriter); |
71 | (void)block1; |
72 | (void)block2; |
73 | assert(body.ptr == block1.ptr); |
74 | assert(body.ptr == block2.ptr); |
75 | |
76 | // clang-format off |
77 | // CHECK-NEXT: module { |
78 | // CHECK-NEXT: %{{.*}} = "dialect.op5"() : () -> index |
79 | // CHECK-NEXT: %{{.*}} = "dialect.op2"() : () -> index |
80 | // CHECK-NEXT: %{{.*}} = "dialect.op3"() : () -> index |
81 | // CHECK-NEXT: %{{.*}} = "dialect.op4"() : () -> index |
82 | // CHECK-NEXT: "dialect.op1"() : () -> () |
83 | // CHECK-NEXT: %{{.*}} = "dialect.op6"() : () -> index |
84 | // CHECK-NEXT: } |
85 | // clang-format on |
86 | mlirOperationDump(op); |
87 | |
88 | mlirIRRewriterDestroy(rewriter); |
89 | mlirModuleDestroy(module); |
90 | } |
91 | |
92 | void testCreateBlock(MlirContext ctx) { |
93 | // CHECK-LABEL: @testCreateBlock |
94 | fprintf(stderr, format: "@testCreateBlock\n" ); |
95 | |
96 | const char *moduleString = "\"dialect.op1\"() ({^bb0:}) : () -> ()\n" |
97 | "\"dialect.op2\"() ({^bb0:}) : () -> ()\n" ; |
98 | MlirModule module = |
99 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
100 | MlirOperation op = mlirModuleGetOperation(module); |
101 | MlirBlock body = mlirModuleGetBody(module); |
102 | |
103 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
104 | MlirRegion region1 = mlirOperationGetRegion(op: op1, pos: 0); |
105 | MlirBlock block1 = mlirRegionGetFirstBlock(region: region1); |
106 | |
107 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
108 | MlirRegion region2 = mlirOperationGetRegion(op: op2, pos: 0); |
109 | MlirBlock block2 = mlirRegionGetFirstBlock(region: region2); |
110 | |
111 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
112 | |
113 | // Create block before |
114 | MlirType indexType = mlirIndexTypeGet(ctx); |
115 | MlirLocation unknown = mlirLocationUnknownGet(context: ctx); |
116 | mlirRewriterBaseCreateBlockBefore(rewriter, insertBefore: block1, nArgTypes: 1, argTypes: &indexType, locations: &unknown); |
117 | |
118 | mlirRewriterBaseSetInsertionPointToEnd(rewriter, block: body); |
119 | |
120 | // Clone operation |
121 | mlirRewriterBaseClone(rewriter, op: op1); |
122 | |
123 | // Clone without regions |
124 | mlirRewriterBaseCloneWithoutRegions(rewriter, op: op1); |
125 | |
126 | // Clone region before |
127 | mlirRewriterBaseCloneRegionBefore(rewriter, region: region1, before: block2); |
128 | |
129 | mlirOperationDump(op); |
130 | // clang-format off |
131 | // CHECK-NEXT: "builtin.module"() ({ |
132 | // CHECK-NEXT: "dialect.op1"() ({ |
133 | // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): |
134 | // CHECK-NEXT: ^{{.*}}: |
135 | // CHECK-NEXT: }) : () -> () |
136 | // CHECK-NEXT: "dialect.op2"() ({ |
137 | // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): |
138 | // CHECK-NEXT: ^{{.*}}: |
139 | // CHECK-NEXT: ^{{.*}}: |
140 | // CHECK-NEXT: }) : () -> () |
141 | // CHECK-NEXT: "dialect.op1"() ({ |
142 | // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): |
143 | // CHECK-NEXT: ^{{.*}}: |
144 | // CHECK-NEXT: }) : () -> () |
145 | // CHECK-NEXT: "dialect.op1"() ({ |
146 | // CHECK-NEXT: }) : () -> () |
147 | // CHECK-NEXT: }) : () -> () |
148 | // clang-format on |
149 | |
150 | mlirIRRewriterDestroy(rewriter); |
151 | mlirModuleDestroy(module); |
152 | } |
153 | |
154 | void testInlineRegionBlock(MlirContext ctx) { |
155 | // CHECK-LABEL: @testInlineRegionBlock |
156 | fprintf(stderr, format: "@testInlineRegionBlock\n" ); |
157 | |
158 | const char *moduleString = |
159 | "\"dialect.op1\"() ({\n" |
160 | " ^bb0(%arg0: index):\n" |
161 | " \"dialect.op1_in1\"(%arg0) [^bb1] : (index) -> ()\n" |
162 | " ^bb1():\n" |
163 | " \"dialect.op1_in2\"() : () -> ()\n" |
164 | "}) : () -> ()\n" |
165 | "\"dialect.op2\"() ({^bb0:}) : () -> ()\n" |
166 | "\"dialect.op3\"() ({\n" |
167 | " ^bb0(%arg0: index):\n" |
168 | " \"dialect.op3_in1\"(%arg0) : (index) -> ()\n" |
169 | " ^bb1():\n" |
170 | " %x = \"dialect.op3_in2\"() : () -> index\n" |
171 | " %y = \"dialect.op3_in3\"() : () -> index\n" |
172 | "}) : () -> ()\n" |
173 | "\"dialect.op4\"() ({\n" |
174 | " ^bb0():\n" |
175 | " \"dialect.op4_in1\"() : () -> index\n" |
176 | " ^bb1(%arg0: index):\n" |
177 | " \"dialect.op4_in2\"(%arg0) : (index) -> ()\n" |
178 | "}) : () -> ()\n" ; |
179 | MlirModule module = |
180 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
181 | MlirOperation op = mlirModuleGetOperation(module); |
182 | MlirBlock body = mlirModuleGetBody(module); |
183 | |
184 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
185 | MlirRegion region1 = mlirOperationGetRegion(op: op1, pos: 0); |
186 | |
187 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
188 | MlirRegion region2 = mlirOperationGetRegion(op: op2, pos: 0); |
189 | MlirBlock block2 = mlirRegionGetFirstBlock(region: region2); |
190 | |
191 | MlirOperation op3 = mlirOperationGetNextInBlock(op: op2); |
192 | MlirRegion region3 = mlirOperationGetRegion(op: op3, pos: 0); |
193 | MlirBlock block3_1 = mlirRegionGetFirstBlock(region: region3); |
194 | MlirBlock block3_2 = mlirBlockGetNextInRegion(block: block3_1); |
195 | MlirOperation op3_in2 = mlirBlockGetFirstOperation(block: block3_2); |
196 | MlirValue op3_in2_res = mlirOperationGetResult(op: op3_in2, pos: 0); |
197 | MlirOperation op3_in3 = mlirOperationGetNextInBlock(op: op3_in2); |
198 | |
199 | MlirOperation op4 = mlirOperationGetNextInBlock(op: op3); |
200 | MlirRegion region4 = mlirOperationGetRegion(op: op4, pos: 0); |
201 | MlirBlock block4_1 = mlirRegionGetFirstBlock(region: region4); |
202 | MlirOperation op4_in1 = mlirBlockGetFirstOperation(block: block4_1); |
203 | MlirValue op4_in1_res = mlirOperationGetResult(op: op4_in1, pos: 0); |
204 | MlirBlock block4_2 = mlirBlockGetNextInRegion(block: block4_1); |
205 | |
206 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
207 | |
208 | // Test these three functions |
209 | mlirRewriterBaseInlineRegionBefore(rewriter, region: region1, before: block2); |
210 | mlirRewriterBaseInlineBlockBefore(rewriter, source: block3_1, op: op3_in3, nArgValues: 1, |
211 | argValues: &op3_in2_res); |
212 | mlirRewriterBaseMergeBlocks(rewriter, source: block4_2, dest: block4_1, nArgValues: 1, argValues: &op4_in1_res); |
213 | |
214 | mlirOperationDump(op); |
215 | // clang-format off |
216 | // CHECK-NEXT: "builtin.module"() ({ |
217 | // CHECK-NEXT: "dialect.op1"() ({ |
218 | // CHECK-NEXT: }) : () -> () |
219 | // CHECK-NEXT: "dialect.op2"() ({ |
220 | // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): |
221 | // CHECK-NEXT: "dialect.op1_in1"(%{{.*}})[^[[bb:.*]]] : (index) -> () |
222 | // CHECK-NEXT: ^[[bb]]: |
223 | // CHECK-NEXT: "dialect.op1_in2"() : () -> () |
224 | // CHECK-NEXT: ^{{.*}}: // no predecessors |
225 | // CHECK-NEXT: }) : () -> () |
226 | // CHECK-NEXT: "dialect.op3"() ({ |
227 | // CHECK-NEXT: %{{.*}} = "dialect.op3_in2"() : () -> index |
228 | // CHECK-NEXT: "dialect.op3_in1"(%{{.*}}) : (index) -> () |
229 | // CHECK-NEXT: %{{.*}} = "dialect.op3_in3"() : () -> index |
230 | // CHECK-NEXT: }) : () -> () |
231 | // CHECK-NEXT: "dialect.op4"() ({ |
232 | // CHECK-NEXT: %{{.*}} = "dialect.op4_in1"() : () -> index |
233 | // CHECK-NEXT: "dialect.op4_in2"(%{{.*}}) : (index) -> () |
234 | // CHECK-NEXT: }) : () -> () |
235 | // CHECK-NEXT: }) : () -> () |
236 | // clang-format on |
237 | |
238 | mlirIRRewriterDestroy(rewriter); |
239 | mlirModuleDestroy(module); |
240 | } |
241 | |
242 | void testReplaceOp(MlirContext ctx) { |
243 | // CHECK-LABEL: @testReplaceOp |
244 | fprintf(stderr, format: "@testReplaceOp\n" ); |
245 | |
246 | const char *moduleString = |
247 | "%x, %y, %z = \"dialect.create_values\"() : () -> (index, index, index)\n" |
248 | "%x_1, %y_1 = \"dialect.op1\"() : () -> (index, index)\n" |
249 | "\"dialect.use_op1\"(%x_1, %y_1) : (index, index) -> ()\n" |
250 | "%x_2, %y_2 = \"dialect.op2\"() : () -> (index, index)\n" |
251 | "%x_3, %y_3 = \"dialect.op3\"() : () -> (index, index)\n" |
252 | "\"dialect.use_op2\"(%x_2, %y_2) : (index, index) -> ()\n" ; |
253 | MlirModule module = |
254 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
255 | MlirOperation op = mlirModuleGetOperation(module); |
256 | MlirBlock body = mlirModuleGetBody(module); |
257 | |
258 | // get a handle to all operations/values |
259 | MlirOperation createValues = mlirBlockGetFirstOperation(block: body); |
260 | MlirValue x = mlirOperationGetResult(op: createValues, pos: 0); |
261 | MlirValue z = mlirOperationGetResult(op: createValues, pos: 2); |
262 | MlirOperation op1 = mlirOperationGetNextInBlock(op: createValues); |
263 | MlirOperation useOp1 = mlirOperationGetNextInBlock(op: op1); |
264 | MlirOperation op2 = mlirOperationGetNextInBlock(op: useOp1); |
265 | MlirOperation op3 = mlirOperationGetNextInBlock(op: op2); |
266 | |
267 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
268 | |
269 | // Test replace op with values |
270 | MlirValue xz[2] = {x, z}; |
271 | mlirRewriterBaseReplaceOpWithValues(rewriter, op: op1, nValues: 2, values: xz); |
272 | |
273 | // Test replace op with op |
274 | mlirRewriterBaseReplaceOpWithOperation(rewriter, op: op2, newOp: op3); |
275 | |
276 | mlirOperationDump(op); |
277 | // clang-format off |
278 | // CHECK-NEXT: module { |
279 | // CHECK-NEXT: %[[res:.*]]:3 = "dialect.create_values"() : () -> (index, index, index) |
280 | // CHECK-NEXT: "dialect.use_op1"(%[[res]]#0, %[[res]]#2) : (index, index) -> () |
281 | // CHECK-NEXT: %[[res2:.*]]:2 = "dialect.op3"() : () -> (index, index) |
282 | // CHECK-NEXT: "dialect.use_op2"(%[[res2]]#0, %[[res2]]#1) : (index, index) -> () |
283 | // CHECK-NEXT: } |
284 | // clang-format on |
285 | |
286 | mlirIRRewriterDestroy(rewriter); |
287 | mlirModuleDestroy(module); |
288 | } |
289 | |
290 | void testErase(MlirContext ctx) { |
291 | // CHECK-LABEL: @testErase |
292 | fprintf(stderr, format: "@testErase\n" ); |
293 | |
294 | const char *moduleString = "\"dialect.op_to_erase\"() : () -> ()\n" |
295 | "\"dialect.op2\"() ({\n" |
296 | "^bb0():\n" |
297 | " \"dialect.op2_nested\"() : () -> ()" |
298 | "^block_to_erase():\n" |
299 | " \"dialect.op2_nested\"() : () -> ()" |
300 | "^bb1():\n" |
301 | " \"dialect.op2_nested\"() : () -> ()" |
302 | "}) : () -> ()\n" ; |
303 | MlirModule module = |
304 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
305 | MlirOperation op = mlirModuleGetOperation(module); |
306 | MlirBlock body = mlirModuleGetBody(module); |
307 | |
308 | // get a handle to all operations/values |
309 | MlirOperation opToErase = mlirBlockGetFirstOperation(block: body); |
310 | MlirOperation op2 = mlirOperationGetNextInBlock(op: opToErase); |
311 | MlirRegion op2Region = mlirOperationGetRegion(op: op2, pos: 0); |
312 | MlirBlock bb0 = mlirRegionGetFirstBlock(region: op2Region); |
313 | MlirBlock blockToErase = mlirBlockGetNextInRegion(block: bb0); |
314 | |
315 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
316 | mlirRewriterBaseEraseOp(rewriter, op: opToErase); |
317 | mlirRewriterBaseEraseBlock(rewriter, block: blockToErase); |
318 | |
319 | mlirOperationDump(op); |
320 | // CHECK-NEXT: module { |
321 | // CHECK-NEXT: "dialect.op2"() ({ |
322 | // CHECK-NEXT: "dialect.op2_nested"() : () -> () |
323 | // CHECK-NEXT: ^{{.*}}: |
324 | // CHECK-NEXT: "dialect.op2_nested"() : () -> () |
325 | // CHECK-NEXT: }) : () -> () |
326 | // CHECK-NEXT: } |
327 | |
328 | mlirIRRewriterDestroy(rewriter); |
329 | mlirModuleDestroy(module); |
330 | } |
331 | |
332 | void testMove(MlirContext ctx) { |
333 | // CHECK-LABEL: @testMove |
334 | fprintf(stderr, format: "@testMove\n" ); |
335 | |
336 | const char *moduleString = "\"dialect.op1\"() : () -> ()\n" |
337 | "\"dialect.op2\"() ({\n" |
338 | "^bb0(%arg0: index):\n" |
339 | " \"dialect.op2_1\"(%arg0) : (index) -> ()" |
340 | "^bb1(%arg1: index):\n" |
341 | " \"dialect.op2_2\"(%arg1) : (index) -> ()" |
342 | "}) : () -> ()\n" |
343 | "\"dialect.op3\"() : () -> ()\n" |
344 | "\"dialect.op4\"() : () -> ()\n" ; |
345 | |
346 | MlirModule module = |
347 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
348 | MlirOperation op = mlirModuleGetOperation(module); |
349 | MlirBlock body = mlirModuleGetBody(module); |
350 | |
351 | // get a handle to all operations/values |
352 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
353 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
354 | MlirOperation op3 = mlirOperationGetNextInBlock(op: op2); |
355 | MlirOperation op4 = mlirOperationGetNextInBlock(op: op3); |
356 | |
357 | MlirRegion region2 = mlirOperationGetRegion(op: op2, pos: 0); |
358 | MlirBlock block0 = mlirRegionGetFirstBlock(region: region2); |
359 | MlirBlock block1 = mlirBlockGetNextInRegion(block: block0); |
360 | |
361 | // Test move operations. |
362 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
363 | mlirRewriterBaseMoveOpBefore(rewriter, op: op3, existingOp: op1); |
364 | mlirRewriterBaseMoveOpAfter(rewriter, op: op4, existingOp: op1); |
365 | mlirRewriterBaseMoveBlockBefore(rewriter, block: block1, existingBlock: block0); |
366 | |
367 | mlirOperationDump(op); |
368 | // CHECK-NEXT: module { |
369 | // CHECK-NEXT: "dialect.op3"() : () -> () |
370 | // CHECK-NEXT: "dialect.op1"() : () -> () |
371 | // CHECK-NEXT: "dialect.op4"() : () -> () |
372 | // CHECK-NEXT: "dialect.op2"() ({ |
373 | // CHECK-NEXT: ^{{.*}}(%[[arg0:.*]]: index): |
374 | // CHECK-NEXT: "dialect.op2_2"(%[[arg0]]) : (index) -> () |
375 | // CHECK-NEXT: ^{{.*}}(%[[arg1:.*]]: index): // no predecessors |
376 | // CHECK-NEXT: "dialect.op2_1"(%[[arg1]]) : (index) -> () |
377 | // CHECK-NEXT: }) : () -> () |
378 | // CHECK-NEXT: } |
379 | |
380 | mlirIRRewriterDestroy(rewriter); |
381 | mlirModuleDestroy(module); |
382 | } |
383 | |
384 | void testOpModification(MlirContext ctx) { |
385 | // CHECK-LABEL: @testOpModification |
386 | fprintf(stderr, format: "@testOpModification\n" ); |
387 | |
388 | const char *moduleString = |
389 | "%x, %y = \"dialect.op1\"() : () -> (index, index)\n" |
390 | "\"dialect.op2\"(%x) : (index) -> ()\n" ; |
391 | |
392 | MlirModule module = |
393 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
394 | MlirOperation op = mlirModuleGetOperation(module); |
395 | MlirBlock body = mlirModuleGetBody(module); |
396 | |
397 | // get a handle to all operations/values |
398 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
399 | MlirValue y = mlirOperationGetResult(op: op1, pos: 1); |
400 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
401 | |
402 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
403 | mlirRewriterBaseStartOpModification(rewriter, op: op1); |
404 | mlirRewriterBaseCancelOpModification(rewriter, op: op1); |
405 | |
406 | mlirRewriterBaseStartOpModification(rewriter, op: op2); |
407 | mlirOperationSetOperand(op: op2, pos: 0, newValue: y); |
408 | mlirRewriterBaseFinalizeOpModification(rewriter, op: op2); |
409 | |
410 | mlirOperationDump(op); |
411 | // CHECK-NEXT: module { |
412 | // CHECK-NEXT: %[[xy:.*]]:2 = "dialect.op1"() : () -> (index, index) |
413 | // CHECK-NEXT: "dialect.op2"(%[[xy]]#1) : (index) -> () |
414 | // CHECK-NEXT: } |
415 | |
416 | mlirIRRewriterDestroy(rewriter); |
417 | mlirModuleDestroy(module); |
418 | } |
419 | |
420 | void testReplaceUses(MlirContext ctx) { |
421 | // CHECK-LABEL: @testReplaceUses |
422 | fprintf(stderr, format: "@testReplaceUses\n" ); |
423 | |
424 | const char *moduleString = |
425 | // Replace values with values |
426 | "%x1, %y1, %z1 = \"dialect.op1\"() : () -> (index, index, index)\n" |
427 | "%x2, %y2, %z2 = \"dialect.op2\"() : () -> (index, index, index)\n" |
428 | "\"dialect.op1_uses\"(%x1, %y1, %z1) : (index, index, index) -> ()\n" |
429 | // Replace op with values |
430 | "%x3 = \"dialect.op3\"() : () -> index\n" |
431 | "%x4 = \"dialect.op4\"() : () -> index\n" |
432 | "\"dialect.op3_uses\"(%x3) : (index) -> ()\n" |
433 | // Replace op with op |
434 | "%x5 = \"dialect.op5\"() : () -> index\n" |
435 | "%x6 = \"dialect.op6\"() : () -> index\n" |
436 | "\"dialect.op5_uses\"(%x5) : (index) -> ()\n" |
437 | // Replace op in block; |
438 | "%x7 = \"dialect.op7\"() : () -> index\n" |
439 | "%x8 = \"dialect.op8\"() : () -> index\n" |
440 | "\"dialect.op9\"() ({\n" |
441 | "^bb0:\n" |
442 | " \"dialect.op7_uses\"(%x7) : (index) -> ()\n" |
443 | "}): () -> ()\n" |
444 | "\"dialect.op7_uses\"(%x7) : (index) -> ()\n" |
445 | // Replace value with value except in op |
446 | "%x10 = \"dialect.op10\"() : () -> index\n" |
447 | "%x11 = \"dialect.op11\"() : () -> index\n" |
448 | "\"dialect.op10_uses\"(%x10) : (index) -> ()\n" |
449 | "\"dialect.op10_uses\"(%x10) : (index) -> ()\n" ; |
450 | |
451 | MlirModule module = |
452 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
453 | MlirOperation op = mlirModuleGetOperation(module); |
454 | MlirBlock body = mlirModuleGetBody(module); |
455 | |
456 | // get a handle to all operations/values |
457 | MlirOperation op1 = mlirBlockGetFirstOperation(block: body); |
458 | MlirValue x1 = mlirOperationGetResult(op: op1, pos: 0); |
459 | MlirValue y1 = mlirOperationGetResult(op: op1, pos: 1); |
460 | MlirValue z1 = mlirOperationGetResult(op: op1, pos: 2); |
461 | MlirOperation op2 = mlirOperationGetNextInBlock(op: op1); |
462 | MlirValue x2 = mlirOperationGetResult(op: op2, pos: 0); |
463 | MlirValue y2 = mlirOperationGetResult(op: op2, pos: 1); |
464 | MlirValue z2 = mlirOperationGetResult(op: op2, pos: 2); |
465 | MlirOperation op1Uses = mlirOperationGetNextInBlock(op: op2); |
466 | |
467 | MlirOperation op3 = mlirOperationGetNextInBlock(op: op1Uses); |
468 | MlirOperation op4 = mlirOperationGetNextInBlock(op: op3); |
469 | MlirValue x4 = mlirOperationGetResult(op: op4, pos: 0); |
470 | MlirOperation op3Uses = mlirOperationGetNextInBlock(op: op4); |
471 | |
472 | MlirOperation op5 = mlirOperationGetNextInBlock(op: op3Uses); |
473 | MlirOperation op6 = mlirOperationGetNextInBlock(op: op5); |
474 | MlirOperation op5Uses = mlirOperationGetNextInBlock(op: op6); |
475 | |
476 | MlirOperation op7 = mlirOperationGetNextInBlock(op: op5Uses); |
477 | MlirOperation op8 = mlirOperationGetNextInBlock(op: op7); |
478 | MlirValue x8 = mlirOperationGetResult(op: op8, pos: 0); |
479 | MlirOperation op9 = mlirOperationGetNextInBlock(op: op8); |
480 | MlirRegion region9 = mlirOperationGetRegion(op: op9, pos: 0); |
481 | MlirBlock block9 = mlirRegionGetFirstBlock(region: region9); |
482 | MlirOperation op7Uses = mlirOperationGetNextInBlock(op: op9); |
483 | |
484 | MlirOperation op10 = mlirOperationGetNextInBlock(op: op7Uses); |
485 | MlirValue x10 = mlirOperationGetResult(op: op10, pos: 0); |
486 | MlirOperation op11 = mlirOperationGetNextInBlock(op: op10); |
487 | MlirValue x11 = mlirOperationGetResult(op: op11, pos: 0); |
488 | MlirOperation op10Uses1 = mlirOperationGetNextInBlock(op: op11); |
489 | |
490 | MlirRewriterBase rewriter = mlirIRRewriterCreate(context: ctx); |
491 | |
492 | // Replace values |
493 | mlirRewriterBaseReplaceAllUsesWith(rewriter, from: x1, to: x2); |
494 | MlirValue y1z1[2] = {y1, z1}; |
495 | MlirValue y2z2[2] = {y2, z2}; |
496 | mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter, nValues: 2, from: y1z1, to: y2z2); |
497 | |
498 | // Replace op with values |
499 | mlirRewriterBaseReplaceOpWithValues(rewriter, op: op3, nValues: 1, values: &x4); |
500 | |
501 | // Replace op with op |
502 | mlirRewriterBaseReplaceOpWithOperation(rewriter, op: op5, newOp: op6); |
503 | |
504 | // Replace op with op in block |
505 | mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter, op: op7, nNewValues: 1, newValues: &x8, block: block9); |
506 | |
507 | // Replace value with value except in op |
508 | mlirRewriterBaseReplaceAllUsesExcept(rewriter, from: x10, to: x11, exceptedUser: op10Uses1); |
509 | |
510 | mlirOperationDump(op); |
511 | // clang-format off |
512 | // CHECK-NEXT: module { |
513 | // CHECK-NEXT: %{{.*}}:3 = "dialect.op1"() : () -> (index, index, index) |
514 | // CHECK-NEXT: %[[res2:.*]]:3 = "dialect.op2"() : () -> (index, index, index) |
515 | // CHECK-NEXT: "dialect.op1_uses"(%[[res2]]#0, %[[res2]]#1, %[[res2]]#2) : (index, index, index) -> () |
516 | // CHECK-NEXT: %[[res4:.*]] = "dialect.op4"() : () -> index |
517 | // CHECK-NEXT: "dialect.op3_uses"(%[[res4]]) : (index) -> () |
518 | // CHECK-NEXT: %[[res6:.*]] = "dialect.op6"() : () -> index |
519 | // CHECK-NEXT: "dialect.op5_uses"(%[[res6]]) : (index) -> () |
520 | // CHECK-NEXT: %[[res7:.*]] = "dialect.op7"() : () -> index |
521 | // CHECK-NEXT: %[[res8:.*]] = "dialect.op8"() : () -> index |
522 | // CHECK-NEXT: "dialect.op9"() ({ |
523 | // CHECK-NEXT: "dialect.op7_uses"(%[[res8]]) : (index) -> () |
524 | // CHECK-NEXT: }) : () -> () |
525 | // CHECK-NEXT: "dialect.op7_uses"(%[[res7]]) : (index) -> () |
526 | // CHECK-NEXT: %[[res10:.*]] = "dialect.op10"() : () -> index |
527 | // CHECK-NEXT: %[[res11:.*]] = "dialect.op11"() : () -> index |
528 | // CHECK-NEXT: "dialect.op10_uses"(%[[res10]]) : (index) -> () |
529 | // CHECK-NEXT: "dialect.op10_uses"(%[[res11]]) : (index) -> () |
530 | // CHECK-NEXT: } |
531 | // clang-format on |
532 | |
533 | mlirIRRewriterDestroy(rewriter); |
534 | mlirModuleDestroy(module); |
535 | } |
536 | |
537 | int main(void) { |
538 | MlirContext ctx = mlirContextCreate(); |
539 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
540 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "builtin" )); |
541 | |
542 | testInsertionPoint(ctx); |
543 | testCreateBlock(ctx); |
544 | testInlineRegionBlock(ctx); |
545 | testReplaceOp(ctx); |
546 | testErase(ctx); |
547 | testMove(ctx); |
548 | testOpModification(ctx); |
549 | testReplaceUses(ctx); |
550 | |
551 | mlirContextDestroy(context: ctx); |
552 | return 0; |
553 | } |
554 | |