1 | //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" |
14 | |
15 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
18 | #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/BuiltinOps.h" |
22 | #include "mlir/IR/IRMapping.h" |
23 | #include "mlir/IR/MLIRContext.h" |
24 | #include "mlir/IR/PatternMatch.h" |
25 | #include "mlir/Transforms/DialectConversion.h" |
26 | #include "mlir/Transforms/Passes.h" |
27 | |
28 | namespace mlir { |
29 | #define GEN_PASS_DEF_SCFTOEMITC |
30 | #include "mlir/Conversion/Passes.h.inc" |
31 | } // namespace mlir |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::scf; |
35 | |
36 | namespace { |
37 | |
38 | /// Implement the interface to convert SCF to EmitC. |
39 | struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface { |
40 | using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; |
41 | |
42 | /// Hook for derived dialect interface to provide conversion patterns |
43 | /// and mark dialect legal for the conversion target. |
44 | void populateConvertToEmitCConversionPatterns( |
45 | ConversionTarget &target, TypeConverter &typeConverter, |
46 | RewritePatternSet &patterns) const final { |
47 | populateEmitCSizeTTypeConversions(converter&: typeConverter); |
48 | populateSCFToEmitCConversionPatterns(patterns, typeConverter); |
49 | } |
50 | }; |
51 | } // namespace |
52 | |
53 | void mlir::registerConvertSCFToEmitCInterface(DialectRegistry ®istry) { |
54 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) { |
55 | dialect->addInterfaces<SCFToEmitCDialectInterface>(); |
56 | }); |
57 | } |
58 | |
59 | namespace { |
60 | |
61 | struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> { |
62 | void runOnOperation() override; |
63 | }; |
64 | |
65 | // Lower scf::for to emitc::for, implementing result values using |
66 | // emitc::variable's updated within the loop body. |
67 | struct ForLowering : public OpConversionPattern<ForOp> { |
68 | using OpConversionPattern<ForOp>::OpConversionPattern; |
69 | |
70 | LogicalResult |
71 | matchAndRewrite(ForOp forOp, OpAdaptor adaptor, |
72 | ConversionPatternRewriter &rewriter) const override; |
73 | }; |
74 | |
75 | // Create an uninitialized emitc::variable op for each result of the given op. |
76 | template <typename T> |
77 | static LogicalResult |
78 | createVariablesForResults(T op, const TypeConverter *typeConverter, |
79 | ConversionPatternRewriter &rewriter, |
80 | SmallVector<Value> &resultVariables) { |
81 | if (!op.getNumResults()) |
82 | return success(); |
83 | |
84 | Location loc = op->getLoc(); |
85 | MLIRContext *context = op.getContext(); |
86 | |
87 | OpBuilder::InsertionGuard guard(rewriter); |
88 | rewriter.setInsertionPoint(op); |
89 | |
90 | for (OpResult result : op.getResults()) { |
91 | Type resultType = typeConverter->convertType(t: result.getType()); |
92 | if (!resultType) |
93 | return rewriter.notifyMatchFailure(op, "result type conversion failed"); |
94 | Type varType = emitc::LValueType::get(resultType); |
95 | emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); |
96 | emitc::VariableOp var = |
97 | rewriter.create<emitc::VariableOp>(loc, varType, noInit); |
98 | resultVariables.push_back(Elt: var); |
99 | } |
100 | |
101 | return success(); |
102 | } |
103 | |
104 | // Create a series of assign ops assigning given values to given variables at |
105 | // the current insertion point of given rewriter. |
106 | static void assignValues(ValueRange values, ValueRange variables, |
107 | ConversionPatternRewriter &rewriter, Location loc) { |
108 | for (auto [value, var] : llvm::zip(values, variables)) |
109 | rewriter.create<emitc::AssignOp>(loc, var, value); |
110 | } |
111 | |
112 | SmallVector<Value> loadValues(const SmallVector<Value> &variables, |
113 | PatternRewriter &rewriter, Location loc) { |
114 | return llvm::map_to_vector<>(variables, [&](Value var) { |
115 | Type type = cast<emitc::LValueType>(var.getType()).getValueType(); |
116 | return rewriter.create<emitc::LoadOp>(loc, type, var).getResult(); |
117 | }); |
118 | } |
119 | |
120 | static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, |
121 | ConversionPatternRewriter &rewriter, |
122 | scf::YieldOp yield) { |
123 | Location loc = yield.getLoc(); |
124 | |
125 | OpBuilder::InsertionGuard guard(rewriter); |
126 | rewriter.setInsertionPoint(yield); |
127 | |
128 | SmallVector<Value> yieldOperands; |
129 | if (failed(rewriter.getRemappedValues(keys: yield.getOperands(), results&: yieldOperands))) { |
130 | return rewriter.notifyMatchFailure(arg&: op, msg: "failed to lower yield operands"); |
131 | } |
132 | |
133 | assignValues(values: yieldOperands, variables: resultVariables, rewriter, loc); |
134 | |
135 | rewriter.create<emitc::YieldOp>(loc); |
136 | rewriter.eraseOp(op: yield); |
137 | |
138 | return success(); |
139 | } |
140 | |
141 | // Lower the contents of an scf::if/scf::index_switch regions to an |
142 | // emitc::if/emitc::switch region. The contents of the lowering region is |
143 | // moved into the respective lowered region, but the scf::yield is replaced not |
144 | // only with an emitc::yield, but also with a sequence of emitc::assign ops that |
145 | // set the yielded values into the result variables. |
146 | static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables, |
147 | ConversionPatternRewriter &rewriter, |
148 | Region ®ion, Region &loweredRegion) { |
149 | rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end()); |
150 | Operation *terminator = loweredRegion.back().getTerminator(); |
151 | return lowerYield(op, resultVariables, rewriter, |
152 | cast<scf::YieldOp>(terminator)); |
153 | } |
154 | |
155 | LogicalResult |
156 | ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, |
157 | ConversionPatternRewriter &rewriter) const { |
158 | Location loc = forOp.getLoc(); |
159 | |
160 | // Create an emitc::variable op for each result. These variables will be |
161 | // assigned to by emitc::assign ops within the loop body. |
162 | SmallVector<Value> resultVariables; |
163 | if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, |
164 | resultVariables))) |
165 | return rewriter.notifyMatchFailure(forOp, |
166 | "create variables for results failed"); |
167 | |
168 | assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); |
169 | |
170 | emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>( |
171 | loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); |
172 | |
173 | Block *loweredBody = loweredFor.getBody(); |
174 | |
175 | // Erase the auto-generated terminator for the lowered for op. |
176 | rewriter.eraseOp(op: loweredBody->getTerminator()); |
177 | |
178 | IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint(); |
179 | rewriter.setInsertionPointToEnd(loweredBody); |
180 | |
181 | SmallVector<Value> iterArgsValues = |
182 | loadValues(variables: resultVariables, rewriter, loc); |
183 | |
184 | rewriter.restoreInsertionPoint(ip); |
185 | |
186 | // Convert the original region types into the new types by adding unrealized |
187 | // casts in the beginning of the loop. This performs the conversion in place. |
188 | if (failed(rewriter.convertRegionTypes(region: &forOp.getRegion(), |
189 | converter: *getTypeConverter(), entryConversion: nullptr))) { |
190 | return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); |
191 | } |
192 | |
193 | // Register the replacements for the block arguments and inline the body of |
194 | // the scf.for loop into the body of the emitc::for loop. |
195 | Block *scfBody = &(forOp.getRegion().front()); |
196 | SmallVector<Value> replacingValues; |
197 | replacingValues.push_back(Elt: loweredFor.getInductionVar()); |
198 | replacingValues.append(in_start: iterArgsValues.begin(), in_end: iterArgsValues.end()); |
199 | rewriter.mergeBlocks(source: scfBody, dest: loweredBody, argValues: replacingValues); |
200 | |
201 | auto result = lowerYield(forOp, resultVariables, rewriter, |
202 | cast<scf::YieldOp>(loweredBody->getTerminator())); |
203 | |
204 | if (failed(result)) { |
205 | return result; |
206 | } |
207 | |
208 | // Load variables into SSA values after the for loop. |
209 | SmallVector<Value> resultValues = loadValues(variables: resultVariables, rewriter, loc); |
210 | |
211 | rewriter.replaceOp(forOp, resultValues); |
212 | return success(); |
213 | } |
214 | |
215 | // Lower scf::if to emitc::if, implementing result values as emitc::variable's |
216 | // updated within the then and else regions. |
217 | struct IfLowering : public OpConversionPattern<IfOp> { |
218 | using OpConversionPattern<IfOp>::OpConversionPattern; |
219 | |
220 | LogicalResult |
221 | matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, |
222 | ConversionPatternRewriter &rewriter) const override; |
223 | }; |
224 | |
225 | } // namespace |
226 | |
227 | LogicalResult |
228 | IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, |
229 | ConversionPatternRewriter &rewriter) const { |
230 | Location loc = ifOp.getLoc(); |
231 | |
232 | // Create an emitc::variable op for each result. These variables will be |
233 | // assigned to by emitc::assign ops within the then & else regions. |
234 | SmallVector<Value> resultVariables; |
235 | if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter, |
236 | resultVariables))) |
237 | return rewriter.notifyMatchFailure(ifOp, |
238 | "create variables for results failed"); |
239 | |
240 | // Utility function to lower the contents of an scf::if region to an emitc::if |
241 | // region. The contents of the scf::if regions is moved into the respective |
242 | // emitc::if regions, but the scf::yield is replaced not only with an |
243 | // emitc::yield, but also with a sequence of emitc::assign ops that set the |
244 | // yielded values into the result variables. |
245 | auto lowerRegion = [&resultVariables, &rewriter, |
246 | &ifOp](Region ®ion, Region &loweredRegion) { |
247 | rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end()); |
248 | Operation *terminator = loweredRegion.back().getTerminator(); |
249 | auto result = lowerYield(ifOp, resultVariables, rewriter, |
250 | cast<scf::YieldOp>(terminator)); |
251 | if (failed(result)) { |
252 | return result; |
253 | } |
254 | return success(); |
255 | }; |
256 | |
257 | Region &thenRegion = adaptor.getThenRegion(); |
258 | Region &elseRegion = adaptor.getElseRegion(); |
259 | |
260 | bool hasElseBlock = !elseRegion.empty(); |
261 | |
262 | auto loweredIf = |
263 | rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false); |
264 | |
265 | Region &loweredThenRegion = loweredIf.getThenRegion(); |
266 | auto result = lowerRegion(thenRegion, loweredThenRegion); |
267 | if (failed(result)) { |
268 | return result; |
269 | } |
270 | |
271 | if (hasElseBlock) { |
272 | Region &loweredElseRegion = loweredIf.getElseRegion(); |
273 | auto result = lowerRegion(elseRegion, loweredElseRegion); |
274 | if (failed(result)) { |
275 | return result; |
276 | } |
277 | } |
278 | |
279 | rewriter.setInsertionPointAfter(ifOp); |
280 | SmallVector<Value> results = loadValues(variables: resultVariables, rewriter, loc); |
281 | |
282 | rewriter.replaceOp(ifOp, results); |
283 | return success(); |
284 | } |
285 | |
286 | // Lower scf::index_switch to emitc::switch, implementing result values as |
287 | // emitc::variable's updated within the case and default regions. |
288 | struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> { |
289 | using OpConversionPattern::OpConversionPattern; |
290 | |
291 | LogicalResult |
292 | matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, |
293 | ConversionPatternRewriter &rewriter) const override; |
294 | }; |
295 | |
296 | LogicalResult IndexSwitchOpLowering::matchAndRewrite( |
297 | IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, |
298 | ConversionPatternRewriter &rewriter) const { |
299 | Location loc = indexSwitchOp.getLoc(); |
300 | |
301 | // Create an emitc::variable op for each result. These variables will be |
302 | // assigned to by emitc::assign ops within the case and default regions. |
303 | SmallVector<Value> resultVariables; |
304 | if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(), |
305 | rewriter, resultVariables))) { |
306 | return rewriter.notifyMatchFailure(indexSwitchOp, |
307 | "create variables for results failed"); |
308 | } |
309 | |
310 | auto loweredSwitch = rewriter.create<emitc::SwitchOp>( |
311 | loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); |
312 | |
313 | // Lowering all case regions. |
314 | for (auto pair : |
315 | llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) { |
316 | if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, |
317 | *std::get<0>(pair), std::get<1>(pair)))) { |
318 | return failure(); |
319 | } |
320 | } |
321 | |
322 | // Lowering default region. |
323 | if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, |
324 | adaptor.getDefaultRegion(), |
325 | loweredSwitch.getDefaultRegion()))) { |
326 | return failure(); |
327 | } |
328 | |
329 | rewriter.setInsertionPointAfter(indexSwitchOp); |
330 | SmallVector<Value> results = loadValues(variables: resultVariables, rewriter, loc); |
331 | |
332 | rewriter.replaceOp(indexSwitchOp, results); |
333 | return success(); |
334 | } |
335 | |
336 | void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, |
337 | TypeConverter &typeConverter) { |
338 | patterns.add<ForLowering>(arg&: typeConverter, args: patterns.getContext()); |
339 | patterns.add<IfLowering>(arg&: typeConverter, args: patterns.getContext()); |
340 | patterns.add<IndexSwitchOpLowering>(arg&: typeConverter, args: patterns.getContext()); |
341 | } |
342 | |
343 | void SCFToEmitCPass::runOnOperation() { |
344 | RewritePatternSet patterns(&getContext()); |
345 | TypeConverter typeConverter; |
346 | // Fallback for other types. |
347 | typeConverter.addConversion(callback: [](Type type) -> std::optional<Type> { |
348 | if (!emitc::isSupportedEmitCType(type)) |
349 | return {}; |
350 | return type; |
351 | }); |
352 | populateEmitCSizeTTypeConversions(converter&: typeConverter); |
353 | populateSCFToEmitCConversionPatterns(patterns, typeConverter); |
354 | |
355 | // Configure conversion to lower out SCF operations. |
356 | ConversionTarget target(getContext()); |
357 | target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(); |
358 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; }); |
359 | if (failed( |
360 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
361 | signalPassFailure(); |
362 | } |
363 |
Definitions
- SCFToEmitCDialectInterface
- populateConvertToEmitCConversionPatterns
- registerConvertSCFToEmitCInterface
- SCFToEmitCPass
- ForLowering
- createVariablesForResults
- assignValues
- loadValues
- lowerYield
- lowerRegion
- matchAndRewrite
- IfLowering
- matchAndRewrite
- IndexSwitchOpLowering
- matchAndRewrite
- populateSCFToEmitCConversionPatterns
Learn to use CMake with our Intro Training
Find out more