1 | //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Rewrite.h" |
10 | |
11 | #include "mlir-c/Transforms.h" |
12 | #include "mlir/CAPI/IR.h" |
13 | #include "mlir/CAPI/Rewrite.h" |
14 | #include "mlir/CAPI/Support.h" |
15 | #include "mlir/CAPI/Wrap.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | |
20 | using namespace mlir; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | /// RewriterBase API inherited from OpBuilder |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { |
27 | return wrap(cpp: unwrap(c: rewriter)->getContext()); |
28 | } |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | /// Insertion points methods |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { |
35 | unwrap(c: rewriter)->clearInsertionPoint(); |
36 | } |
37 | |
38 | void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, |
39 | MlirOperation op) { |
40 | unwrap(c: rewriter)->setInsertionPoint(unwrap(c: op)); |
41 | } |
42 | |
43 | void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, |
44 | MlirOperation op) { |
45 | unwrap(c: rewriter)->setInsertionPointAfter(unwrap(c: op)); |
46 | } |
47 | |
48 | void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, |
49 | MlirValue value) { |
50 | unwrap(c: rewriter)->setInsertionPointAfterValue(unwrap(c: value)); |
51 | } |
52 | |
53 | void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, |
54 | MlirBlock block) { |
55 | unwrap(c: rewriter)->setInsertionPointToStart(unwrap(c: block)); |
56 | } |
57 | |
58 | void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, |
59 | MlirBlock block) { |
60 | unwrap(c: rewriter)->setInsertionPointToEnd(unwrap(c: block)); |
61 | } |
62 | |
63 | MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { |
64 | return wrap(cpp: unwrap(c: rewriter)->getInsertionBlock()); |
65 | } |
66 | |
67 | MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { |
68 | return wrap(cpp: unwrap(c: rewriter)->getBlock()); |
69 | } |
70 | |
71 | //===----------------------------------------------------------------------===// |
72 | /// Block and operation creation/insertion/cloning |
73 | //===----------------------------------------------------------------------===// |
74 | |
75 | MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, |
76 | MlirBlock insertBefore, |
77 | intptr_t nArgTypes, |
78 | MlirType const *argTypes, |
79 | MlirLocation const *locations) { |
80 | SmallVector<Type, 4> args; |
81 | ArrayRef<Type> unwrappedArgs = unwrapList(size: nArgTypes, first: argTypes, storage&: args); |
82 | SmallVector<Location, 4> locs; |
83 | ArrayRef<Location> unwrappedLocs = unwrapList(size: nArgTypes, first: locations, storage&: locs); |
84 | return wrap(cpp: unwrap(c: rewriter)->createBlock(insertBefore: unwrap(c: insertBefore), argTypes: unwrappedArgs, |
85 | locs: unwrappedLocs)); |
86 | } |
87 | |
88 | MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, |
89 | MlirOperation op) { |
90 | return wrap(cpp: unwrap(c: rewriter)->insert(op: unwrap(c: op))); |
91 | } |
92 | |
93 | // Other methods of OpBuilder |
94 | |
95 | MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, |
96 | MlirOperation op) { |
97 | return wrap(cpp: unwrap(c: rewriter)->clone(op&: *unwrap(c: op))); |
98 | } |
99 | |
100 | MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, |
101 | MlirOperation op) { |
102 | return wrap(cpp: unwrap(c: rewriter)->cloneWithoutRegions(op&: *unwrap(c: op))); |
103 | } |
104 | |
105 | void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, |
106 | MlirRegion region, MlirBlock before) { |
107 | |
108 | unwrap(c: rewriter)->cloneRegionBefore(region&: *unwrap(c: region), before: unwrap(c: before)); |
109 | } |
110 | |
111 | //===----------------------------------------------------------------------===// |
112 | /// RewriterBase API |
113 | //===----------------------------------------------------------------------===// |
114 | |
115 | void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, |
116 | MlirRegion region, MlirBlock before) { |
117 | unwrap(c: rewriter)->inlineRegionBefore(region&: *unwrap(c: region), before: unwrap(c: before)); |
118 | } |
119 | |
120 | void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, |
121 | MlirOperation op, intptr_t nValues, |
122 | MlirValue const *values) { |
123 | SmallVector<Value, 4> vals; |
124 | ArrayRef<Value> unwrappedVals = unwrapList(size: nValues, first: values, storage&: vals); |
125 | unwrap(c: rewriter)->replaceOp(op: unwrap(c: op), newValues: unwrappedVals); |
126 | } |
127 | |
128 | void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, |
129 | MlirOperation op, |
130 | MlirOperation newOp) { |
131 | unwrap(c: rewriter)->replaceOp(op: unwrap(c: op), newOp: unwrap(c: newOp)); |
132 | } |
133 | |
134 | void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { |
135 | unwrap(c: rewriter)->eraseOp(op: unwrap(c: op)); |
136 | } |
137 | |
138 | void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { |
139 | unwrap(c: rewriter)->eraseBlock(block: unwrap(c: block)); |
140 | } |
141 | |
142 | void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, |
143 | MlirBlock source, MlirOperation op, |
144 | intptr_t nArgValues, |
145 | MlirValue const *argValues) { |
146 | SmallVector<Value, 4> vals; |
147 | ArrayRef<Value> unwrappedVals = unwrapList(size: nArgValues, first: argValues, storage&: vals); |
148 | |
149 | unwrap(c: rewriter)->inlineBlockBefore(source: unwrap(c: source), op: unwrap(c: op), |
150 | argValues: unwrappedVals); |
151 | } |
152 | |
153 | void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, |
154 | MlirBlock dest, intptr_t nArgValues, |
155 | MlirValue const *argValues) { |
156 | SmallVector<Value, 4> args; |
157 | ArrayRef<Value> unwrappedArgs = unwrapList(size: nArgValues, first: argValues, storage&: args); |
158 | unwrap(c: rewriter)->mergeBlocks(source: unwrap(c: source), dest: unwrap(c: dest), argValues: unwrappedArgs); |
159 | } |
160 | |
161 | void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, |
162 | MlirOperation existingOp) { |
163 | unwrap(c: rewriter)->moveOpBefore(op: unwrap(c: op), existingOp: unwrap(c: existingOp)); |
164 | } |
165 | |
166 | void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, |
167 | MlirOperation existingOp) { |
168 | unwrap(c: rewriter)->moveOpAfter(op: unwrap(c: op), existingOp: unwrap(c: existingOp)); |
169 | } |
170 | |
171 | void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, |
172 | MlirBlock existingBlock) { |
173 | unwrap(c: rewriter)->moveBlockBefore(block: unwrap(c: block), anotherBlock: unwrap(c: existingBlock)); |
174 | } |
175 | |
176 | void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, |
177 | MlirOperation op) { |
178 | unwrap(c: rewriter)->startOpModification(op: unwrap(c: op)); |
179 | } |
180 | |
181 | void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, |
182 | MlirOperation op) { |
183 | unwrap(c: rewriter)->finalizeOpModification(op: unwrap(c: op)); |
184 | } |
185 | |
186 | void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, |
187 | MlirOperation op) { |
188 | unwrap(c: rewriter)->cancelOpModification(op: unwrap(c: op)); |
189 | } |
190 | |
191 | void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, |
192 | MlirValue from, MlirValue to) { |
193 | unwrap(c: rewriter)->replaceAllUsesWith(from: unwrap(c: from), to: unwrap(c: to)); |
194 | } |
195 | |
196 | void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, |
197 | intptr_t nValues, |
198 | MlirValue const *from, |
199 | MlirValue const *to) { |
200 | SmallVector<Value, 4> fromVals; |
201 | ArrayRef<Value> unwrappedFromVals = unwrapList(size: nValues, first: from, storage&: fromVals); |
202 | SmallVector<Value, 4> toVals; |
203 | ArrayRef<Value> unwrappedToVals = unwrapList(size: nValues, first: to, storage&: toVals); |
204 | unwrap(c: rewriter)->replaceAllUsesWith(from: unwrappedFromVals, to: unwrappedToVals); |
205 | } |
206 | |
207 | void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, |
208 | MlirOperation from, |
209 | intptr_t nTo, |
210 | MlirValue const *to) { |
211 | SmallVector<Value, 4> toVals; |
212 | ArrayRef<Value> unwrappedToVals = unwrapList(size: nTo, first: to, storage&: toVals); |
213 | unwrap(c: rewriter)->replaceAllOpUsesWith(from: unwrap(c: from), to: unwrappedToVals); |
214 | } |
215 | |
216 | void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, |
217 | MlirOperation from, |
218 | MlirOperation to) { |
219 | unwrap(c: rewriter)->replaceAllOpUsesWith(from: unwrap(c: from), to: unwrap(c: to)); |
220 | } |
221 | |
222 | void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, |
223 | MlirOperation op, |
224 | intptr_t nNewValues, |
225 | MlirValue const *newValues, |
226 | MlirBlock block) { |
227 | SmallVector<Value, 4> vals; |
228 | ArrayRef<Value> unwrappedVals = unwrapList(size: nNewValues, first: newValues, storage&: vals); |
229 | unwrap(c: rewriter)->replaceOpUsesWithinBlock(op: unwrap(c: op), newValues: unwrappedVals, |
230 | block: unwrap(c: block)); |
231 | } |
232 | |
233 | void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, |
234 | MlirValue from, MlirValue to, |
235 | MlirOperation exceptedUser) { |
236 | unwrap(c: rewriter)->replaceAllUsesExcept(from: unwrap(c: from), to: unwrap(c: to), |
237 | exceptedUser: unwrap(c: exceptedUser)); |
238 | } |
239 | |
240 | //===----------------------------------------------------------------------===// |
241 | /// IRRewriter API |
242 | //===----------------------------------------------------------------------===// |
243 | |
244 | MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { |
245 | return wrap(cpp: new IRRewriter(unwrap(c: context))); |
246 | } |
247 | |
248 | MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { |
249 | return wrap(cpp: new IRRewriter(unwrap(c: op))); |
250 | } |
251 | |
252 | void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { |
253 | delete static_cast<IRRewriter *>(unwrap(c: rewriter)); |
254 | } |
255 | |
256 | //===----------------------------------------------------------------------===// |
257 | /// RewritePatternSet and FrozenRewritePatternSet API |
258 | //===----------------------------------------------------------------------===// |
259 | |
260 | inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { |
261 | assert(module.ptr && "unexpected null module"); |
262 | return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); |
263 | } |
264 | |
265 | inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { |
266 | return {.ptr: module}; |
267 | } |
268 | |
269 | inline mlir::FrozenRewritePatternSet * |
270 | unwrap(MlirFrozenRewritePatternSet module) { |
271 | assert(module.ptr && "unexpected null module"); |
272 | return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); |
273 | } |
274 | |
275 | inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { |
276 | return {.ptr: module}; |
277 | } |
278 | |
279 | MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { |
280 | auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(module: op))); |
281 | op.ptr = nullptr; |
282 | return wrap(module: m); |
283 | } |
284 | |
285 | void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { |
286 | delete unwrap(module: op); |
287 | op.ptr = nullptr; |
288 | } |
289 | |
290 | MlirLogicalResult |
291 | mlirApplyPatternsAndFoldGreedily(MlirModule op, |
292 | MlirFrozenRewritePatternSet patterns, |
293 | MlirGreedyRewriteDriverConfig) { |
294 | return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(module: patterns))); |
295 | } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | /// PDLPatternModule API |
299 | //===----------------------------------------------------------------------===// |
300 | |
301 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
302 | inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { |
303 | assert(module.ptr && "unexpected null module"); |
304 | return static_cast<mlir::PDLPatternModule *>(module.ptr); |
305 | } |
306 | |
307 | inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { |
308 | return {.ptr: module}; |
309 | } |
310 | |
311 | MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { |
312 | return wrap(module: new mlir::PDLPatternModule( |
313 | mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); |
314 | } |
315 | |
316 | void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { |
317 | delete unwrap(module: op); |
318 | op.ptr = nullptr; |
319 | } |
320 | |
321 | MlirRewritePatternSet |
322 | mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { |
323 | auto *m = new mlir::RewritePatternSet(std::move(*unwrap(module: op))); |
324 | op.ptr = nullptr; |
325 | return wrap(module: m); |
326 | } |
327 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
328 |
Definitions
- mlirRewriterBaseGetContext
- mlirRewriterBaseClearInsertionPoint
- mlirRewriterBaseSetInsertionPointBefore
- mlirRewriterBaseSetInsertionPointAfter
- mlirRewriterBaseSetInsertionPointAfterValue
- mlirRewriterBaseSetInsertionPointToStart
- mlirRewriterBaseSetInsertionPointToEnd
- mlirRewriterBaseGetInsertionBlock
- mlirRewriterBaseGetBlock
- mlirRewriterBaseCreateBlockBefore
- mlirRewriterBaseInsert
- mlirRewriterBaseClone
- mlirRewriterBaseCloneWithoutRegions
- mlirRewriterBaseCloneRegionBefore
- mlirRewriterBaseInlineRegionBefore
- mlirRewriterBaseReplaceOpWithValues
- mlirRewriterBaseReplaceOpWithOperation
- mlirRewriterBaseEraseOp
- mlirRewriterBaseEraseBlock
- mlirRewriterBaseInlineBlockBefore
- mlirRewriterBaseMergeBlocks
- mlirRewriterBaseMoveOpBefore
- mlirRewriterBaseMoveOpAfter
- mlirRewriterBaseMoveBlockBefore
- mlirRewriterBaseStartOpModification
- mlirRewriterBaseFinalizeOpModification
- mlirRewriterBaseCancelOpModification
- mlirRewriterBaseReplaceAllUsesWith
- mlirRewriterBaseReplaceAllValueRangeUsesWith
- mlirRewriterBaseReplaceAllOpUsesWithValueRange
- mlirRewriterBaseReplaceAllOpUsesWithOperation
- mlirRewriterBaseReplaceOpUsesWithinBlock
- mlirRewriterBaseReplaceAllUsesExcept
- mlirIRRewriterCreate
- mlirIRRewriterCreateFromOp
- mlirIRRewriterDestroy
- unwrap
- wrap
- unwrap
- wrap
- mlirFreezeRewritePattern
- mlirFrozenRewritePatternSetDestroy
- mlirApplyPatternsAndFoldGreedily
- unwrap
- wrap
- mlirPDLPatternModuleFromModule
- mlirPDLPatternModuleDestroy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more