1//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/Rewrite.h"
10
11#include "mlir-c/Transforms.h"
12#include "mlir/CAPI/IR.h"
13#include "mlir/CAPI/Rewrite.h"
14#include "mlir/CAPI/Support.h"
15#include "mlir/CAPI/Wrap.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Rewrite/FrozenRewritePatternSet.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23/// RewriterBase API inherited from OpBuilder
24//===----------------------------------------------------------------------===//
25
26MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
27 return wrap(cpp: unwrap(c: rewriter)->getContext());
28}
29
30//===----------------------------------------------------------------------===//
31/// Insertion points methods
32//===----------------------------------------------------------------------===//
33
34void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
35 unwrap(c: rewriter)->clearInsertionPoint();
36}
37
38void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
39 MlirOperation op) {
40 unwrap(c: rewriter)->setInsertionPoint(unwrap(c: op));
41}
42
43void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
44 MlirOperation op) {
45 unwrap(c: rewriter)->setInsertionPointAfter(unwrap(c: op));
46}
47
48void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
49 MlirValue value) {
50 unwrap(c: rewriter)->setInsertionPointAfterValue(unwrap(c: value));
51}
52
53void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
54 MlirBlock block) {
55 unwrap(c: rewriter)->setInsertionPointToStart(unwrap(c: block));
56}
57
58void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
59 MlirBlock block) {
60 unwrap(c: rewriter)->setInsertionPointToEnd(unwrap(c: block));
61}
62
63MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
64 return wrap(cpp: unwrap(c: rewriter)->getInsertionBlock());
65}
66
67MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
68 return wrap(cpp: unwrap(c: rewriter)->getBlock());
69}
70
71//===----------------------------------------------------------------------===//
72/// Block and operation creation/insertion/cloning
73//===----------------------------------------------------------------------===//
74
75MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
76 MlirBlock insertBefore,
77 intptr_t nArgTypes,
78 MlirType const *argTypes,
79 MlirLocation const *locations) {
80 SmallVector<Type, 4> args;
81 ArrayRef<Type> unwrappedArgs = unwrapList(size: nArgTypes, first: argTypes, storage&: args);
82 SmallVector<Location, 4> locs;
83 ArrayRef<Location> unwrappedLocs = unwrapList(size: nArgTypes, first: locations, storage&: locs);
84 return wrap(cpp: unwrap(c: rewriter)->createBlock(insertBefore: unwrap(c: insertBefore), argTypes: unwrappedArgs,
85 locs: unwrappedLocs));
86}
87
88MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
89 MlirOperation op) {
90 return wrap(cpp: unwrap(c: rewriter)->insert(op: unwrap(c: op)));
91}
92
93// Other methods of OpBuilder
94
95MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
96 MlirOperation op) {
97 return wrap(cpp: unwrap(c: rewriter)->clone(op&: *unwrap(c: op)));
98}
99
100MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
101 MlirOperation op) {
102 return wrap(cpp: unwrap(c: rewriter)->cloneWithoutRegions(op&: *unwrap(c: op)));
103}
104
105void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
106 MlirRegion region, MlirBlock before) {
107
108 unwrap(c: rewriter)->cloneRegionBefore(region&: *unwrap(c: region), before: unwrap(c: before));
109}
110
111//===----------------------------------------------------------------------===//
112/// RewriterBase API
113//===----------------------------------------------------------------------===//
114
115void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
116 MlirRegion region, MlirBlock before) {
117 unwrap(c: rewriter)->inlineRegionBefore(region&: *unwrap(c: region), before: unwrap(c: before));
118}
119
120void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
121 MlirOperation op, intptr_t nValues,
122 MlirValue const *values) {
123 SmallVector<Value, 4> vals;
124 ArrayRef<Value> unwrappedVals = unwrapList(size: nValues, first: values, storage&: vals);
125 unwrap(c: rewriter)->replaceOp(op: unwrap(c: op), newValues: unwrappedVals);
126}
127
128void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
129 MlirOperation op,
130 MlirOperation newOp) {
131 unwrap(c: rewriter)->replaceOp(op: unwrap(c: op), newOp: unwrap(c: newOp));
132}
133
134void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
135 unwrap(c: rewriter)->eraseOp(op: unwrap(c: op));
136}
137
138void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
139 unwrap(c: rewriter)->eraseBlock(block: unwrap(c: block));
140}
141
142void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
143 MlirBlock source, MlirOperation op,
144 intptr_t nArgValues,
145 MlirValue const *argValues) {
146 SmallVector<Value, 4> vals;
147 ArrayRef<Value> unwrappedVals = unwrapList(size: nArgValues, first: argValues, storage&: vals);
148
149 unwrap(c: rewriter)->inlineBlockBefore(source: unwrap(c: source), op: unwrap(c: op),
150 argValues: unwrappedVals);
151}
152
153void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
154 MlirBlock dest, intptr_t nArgValues,
155 MlirValue const *argValues) {
156 SmallVector<Value, 4> args;
157 ArrayRef<Value> unwrappedArgs = unwrapList(size: nArgValues, first: argValues, storage&: args);
158 unwrap(c: rewriter)->mergeBlocks(source: unwrap(c: source), dest: unwrap(c: dest), argValues: unwrappedArgs);
159}
160
161void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
162 MlirOperation existingOp) {
163 unwrap(c: rewriter)->moveOpBefore(op: unwrap(c: op), existingOp: unwrap(c: existingOp));
164}
165
166void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
167 MlirOperation existingOp) {
168 unwrap(c: rewriter)->moveOpAfter(op: unwrap(c: op), existingOp: unwrap(c: existingOp));
169}
170
171void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
172 MlirBlock existingBlock) {
173 unwrap(c: rewriter)->moveBlockBefore(block: unwrap(c: block), anotherBlock: unwrap(c: existingBlock));
174}
175
176void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
177 MlirOperation op) {
178 unwrap(c: rewriter)->startOpModification(op: unwrap(c: op));
179}
180
181void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
182 MlirOperation op) {
183 unwrap(c: rewriter)->finalizeOpModification(op: unwrap(c: op));
184}
185
186void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
187 MlirOperation op) {
188 unwrap(c: rewriter)->cancelOpModification(op: unwrap(c: op));
189}
190
191void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
192 MlirValue from, MlirValue to) {
193 unwrap(c: rewriter)->replaceAllUsesWith(from: unwrap(c: from), to: unwrap(c: to));
194}
195
196void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
197 intptr_t nValues,
198 MlirValue const *from,
199 MlirValue const *to) {
200 SmallVector<Value, 4> fromVals;
201 ArrayRef<Value> unwrappedFromVals = unwrapList(size: nValues, first: from, storage&: fromVals);
202 SmallVector<Value, 4> toVals;
203 ArrayRef<Value> unwrappedToVals = unwrapList(size: nValues, first: to, storage&: toVals);
204 unwrap(c: rewriter)->replaceAllUsesWith(from: unwrappedFromVals, to: unwrappedToVals);
205}
206
207void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
208 MlirOperation from,
209 intptr_t nTo,
210 MlirValue const *to) {
211 SmallVector<Value, 4> toVals;
212 ArrayRef<Value> unwrappedToVals = unwrapList(size: nTo, first: to, storage&: toVals);
213 unwrap(c: rewriter)->replaceAllOpUsesWith(from: unwrap(c: from), to: unwrappedToVals);
214}
215
216void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
217 MlirOperation from,
218 MlirOperation to) {
219 unwrap(c: rewriter)->replaceAllOpUsesWith(from: unwrap(c: from), to: unwrap(c: to));
220}
221
222void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
223 MlirOperation op,
224 intptr_t nNewValues,
225 MlirValue const *newValues,
226 MlirBlock block) {
227 SmallVector<Value, 4> vals;
228 ArrayRef<Value> unwrappedVals = unwrapList(size: nNewValues, first: newValues, storage&: vals);
229 unwrap(c: rewriter)->replaceOpUsesWithinBlock(op: unwrap(c: op), newValues: unwrappedVals,
230 block: unwrap(c: block));
231}
232
233void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
234 MlirValue from, MlirValue to,
235 MlirOperation exceptedUser) {
236 unwrap(c: rewriter)->replaceAllUsesExcept(from: unwrap(c: from), to: unwrap(c: to),
237 exceptedUser: unwrap(c: exceptedUser));
238}
239
240//===----------------------------------------------------------------------===//
241/// IRRewriter API
242//===----------------------------------------------------------------------===//
243
244MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
245 return wrap(cpp: new IRRewriter(unwrap(c: context)));
246}
247
248MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
249 return wrap(cpp: new IRRewriter(unwrap(c: op)));
250}
251
252void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
253 delete static_cast<IRRewriter *>(unwrap(c: rewriter));
254}
255
256//===----------------------------------------------------------------------===//
257/// RewritePatternSet and FrozenRewritePatternSet API
258//===----------------------------------------------------------------------===//
259
260inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
261 assert(module.ptr && "unexpected null module");
262 return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
263}
264
265inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
266 return {.ptr: module};
267}
268
269inline mlir::FrozenRewritePatternSet *
270unwrap(MlirFrozenRewritePatternSet module) {
271 assert(module.ptr && "unexpected null module");
272 return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
273}
274
275inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
276 return {.ptr: module};
277}
278
279MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
280 auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(module: op)));
281 op.ptr = nullptr;
282 return wrap(module: m);
283}
284
285void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
286 delete unwrap(module: op);
287 op.ptr = nullptr;
288}
289
290MlirLogicalResult
291mlirApplyPatternsAndFoldGreedily(MlirModule op,
292 MlirFrozenRewritePatternSet patterns,
293 MlirGreedyRewriteDriverConfig) {
294 return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(module: patterns)));
295}
296
297//===----------------------------------------------------------------------===//
298/// PDLPatternModule API
299//===----------------------------------------------------------------------===//
300
301#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
302inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
303 assert(module.ptr && "unexpected null module");
304 return static_cast<mlir::PDLPatternModule *>(module.ptr);
305}
306
307inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
308 return {.ptr: module};
309}
310
311MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
312 return wrap(module: new mlir::PDLPatternModule(
313 mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
314}
315
316void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
317 delete unwrap(module: op);
318 op.ptr = nullptr;
319}
320
321MlirRewritePatternSet
322mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
323 auto *m = new mlir::RewritePatternSet(std::move(*unwrap(module: op)));
324 op.ptr = nullptr;
325 return wrap(module: m);
326}
327#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
328

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/CAPI/Transforms/Rewrite.cpp