1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
14
15#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
16#include "mlir/Dialect/EmitC/IR/EmitC.h"
17#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/MLIRContext.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "mlir/Transforms/Passes.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_SCFTOEMITC
27#include "mlir/Conversion/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::scf;
32
33namespace {
34
35/// Implement the interface to convert SCF to EmitC.
36struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
37 using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface;
38
39 /// Hook for derived dialect interface to provide conversion patterns
40 /// and mark dialect legal for the conversion target.
41 void populateConvertToEmitCConversionPatterns(
42 ConversionTarget &target, TypeConverter &typeConverter,
43 RewritePatternSet &patterns) const final {
44 populateEmitCSizeTTypeConversions(converter&: typeConverter);
45 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
46 }
47};
48} // namespace
49
50void mlir::registerConvertSCFToEmitCInterface(DialectRegistry &registry) {
51 registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) {
52 dialect->addInterfaces<SCFToEmitCDialectInterface>();
53 });
54}
55
56namespace {
57
58struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
59 void runOnOperation() override;
60};
61
62// Lower scf::for to emitc::for, implementing result values using
63// emitc::variable's updated within the loop body.
64struct ForLowering : public OpConversionPattern<ForOp> {
65 using OpConversionPattern<ForOp>::OpConversionPattern;
66
67 LogicalResult
68 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
69 ConversionPatternRewriter &rewriter) const override;
70};
71
72// Create an uninitialized emitc::variable op for each result of the given op.
73template <typename T>
74static LogicalResult
75createVariablesForResults(T op, const TypeConverter *typeConverter,
76 ConversionPatternRewriter &rewriter,
77 SmallVector<Value> &resultVariables) {
78 if (!op.getNumResults())
79 return success();
80
81 Location loc = op->getLoc();
82 MLIRContext *context = op.getContext();
83
84 OpBuilder::InsertionGuard guard(rewriter);
85 rewriter.setInsertionPoint(op);
86
87 for (OpResult result : op.getResults()) {
88 Type resultType = typeConverter->convertType(t: result.getType());
89 if (!resultType)
90 return rewriter.notifyMatchFailure(op, "result type conversion failed");
91 Type varType = emitc::LValueType::get(valueType: resultType);
92 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, value: "");
93 emitc::VariableOp var =
94 rewriter.create<emitc::VariableOp>(location: loc, args&: varType, args&: noInit);
95 resultVariables.push_back(Elt: var);
96 }
97
98 return success();
99}
100
101// Create a series of assign ops assigning given values to given variables at
102// the current insertion point of given rewriter.
103static void assignValues(ValueRange values, ValueRange variables,
104 ConversionPatternRewriter &rewriter, Location loc) {
105 for (auto [value, var] : llvm::zip(t&: values, u&: variables))
106 rewriter.create<emitc::AssignOp>(location: loc, args&: var, args&: value);
107}
108
109SmallVector<Value> loadValues(const SmallVector<Value> &variables,
110 PatternRewriter &rewriter, Location loc) {
111 return llvm::map_to_vector<>(C: variables, F: [&](Value var) {
112 Type type = cast<emitc::LValueType>(Val: var.getType()).getValueType();
113 return rewriter.create<emitc::LoadOp>(location: loc, args&: type, args&: var).getResult();
114 });
115}
116
117static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
118 ConversionPatternRewriter &rewriter,
119 scf::YieldOp yield) {
120 Location loc = yield.getLoc();
121
122 OpBuilder::InsertionGuard guard(rewriter);
123 rewriter.setInsertionPoint(yield);
124
125 SmallVector<Value> yieldOperands;
126 if (failed(Result: rewriter.getRemappedValues(keys: yield.getOperands(), results&: yieldOperands))) {
127 return rewriter.notifyMatchFailure(arg&: op, msg: "failed to lower yield operands");
128 }
129
130 assignValues(values: yieldOperands, variables: resultVariables, rewriter, loc);
131
132 rewriter.create<emitc::YieldOp>(location: loc);
133 rewriter.eraseOp(op: yield);
134
135 return success();
136}
137
138// Lower the contents of an scf::if/scf::index_switch regions to an
139// emitc::if/emitc::switch region. The contents of the lowering region is
140// moved into the respective lowered region, but the scf::yield is replaced not
141// only with an emitc::yield, but also with a sequence of emitc::assign ops that
142// set the yielded values into the result variables.
143static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
144 ConversionPatternRewriter &rewriter,
145 Region &region, Region &loweredRegion) {
146 rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end());
147 Operation *terminator = loweredRegion.back().getTerminator();
148 return lowerYield(op, resultVariables, rewriter,
149 yield: cast<scf::YieldOp>(Val: terminator));
150}
151
152LogicalResult
153ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const {
155 Location loc = forOp.getLoc();
156
157 // Create an emitc::variable op for each result. These variables will be
158 // assigned to by emitc::assign ops within the loop body.
159 SmallVector<Value> resultVariables;
160 if (failed(Result: createVariablesForResults(op: forOp, typeConverter: getTypeConverter(), rewriter,
161 resultVariables)))
162 return rewriter.notifyMatchFailure(arg&: forOp,
163 msg: "create variables for results failed");
164
165 assignValues(values: adaptor.getInitArgs(), variables: resultVariables, rewriter, loc);
166
167 emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
168 location: loc, args: adaptor.getLowerBound(), args: adaptor.getUpperBound(), args: adaptor.getStep());
169
170 Block *loweredBody = loweredFor.getBody();
171
172 // Erase the auto-generated terminator for the lowered for op.
173 rewriter.eraseOp(op: loweredBody->getTerminator());
174
175 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
176 rewriter.setInsertionPointToEnd(loweredBody);
177
178 SmallVector<Value> iterArgsValues =
179 loadValues(variables: resultVariables, rewriter, loc);
180
181 rewriter.restoreInsertionPoint(ip);
182
183 // Convert the original region types into the new types by adding unrealized
184 // casts in the beginning of the loop. This performs the conversion in place.
185 if (failed(Result: rewriter.convertRegionTypes(region: &forOp.getRegion(),
186 converter: *getTypeConverter(), entryConversion: nullptr))) {
187 return rewriter.notifyMatchFailure(arg&: forOp, msg: "region types conversion failed");
188 }
189
190 // Register the replacements for the block arguments and inline the body of
191 // the scf.for loop into the body of the emitc::for loop.
192 Block *scfBody = &(forOp.getRegion().front());
193 SmallVector<Value> replacingValues;
194 replacingValues.push_back(Elt: loweredFor.getInductionVar());
195 replacingValues.append(in_start: iterArgsValues.begin(), in_end: iterArgsValues.end());
196 rewriter.mergeBlocks(source: scfBody, dest: loweredBody, argValues: replacingValues);
197
198 auto result = lowerYield(op: forOp, resultVariables, rewriter,
199 yield: cast<scf::YieldOp>(Val: loweredBody->getTerminator()));
200
201 if (failed(Result: result)) {
202 return result;
203 }
204
205 // Load variables into SSA values after the for loop.
206 SmallVector<Value> resultValues = loadValues(variables: resultVariables, rewriter, loc);
207
208 rewriter.replaceOp(op: forOp, newValues: resultValues);
209 return success();
210}
211
212// Lower scf::if to emitc::if, implementing result values as emitc::variable's
213// updated within the then and else regions.
214struct IfLowering : public OpConversionPattern<IfOp> {
215 using OpConversionPattern<IfOp>::OpConversionPattern;
216
217 LogicalResult
218 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter) const override;
220};
221
222} // namespace
223
224LogicalResult
225IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
226 ConversionPatternRewriter &rewriter) const {
227 Location loc = ifOp.getLoc();
228
229 // Create an emitc::variable op for each result. These variables will be
230 // assigned to by emitc::assign ops within the then & else regions.
231 SmallVector<Value> resultVariables;
232 if (failed(Result: createVariablesForResults(op: ifOp, typeConverter: getTypeConverter(), rewriter,
233 resultVariables)))
234 return rewriter.notifyMatchFailure(arg&: ifOp,
235 msg: "create variables for results failed");
236
237 // Utility function to lower the contents of an scf::if region to an emitc::if
238 // region. The contents of the scf::if regions is moved into the respective
239 // emitc::if regions, but the scf::yield is replaced not only with an
240 // emitc::yield, but also with a sequence of emitc::assign ops that set the
241 // yielded values into the result variables.
242 auto lowerRegion = [&resultVariables, &rewriter,
243 &ifOp](Region &region, Region &loweredRegion) {
244 rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end());
245 Operation *terminator = loweredRegion.back().getTerminator();
246 auto result = lowerYield(op: ifOp, resultVariables, rewriter,
247 yield: cast<scf::YieldOp>(Val: terminator));
248 if (failed(Result: result)) {
249 return result;
250 }
251 return success();
252 };
253
254 Region &thenRegion = adaptor.getThenRegion();
255 Region &elseRegion = adaptor.getElseRegion();
256
257 bool hasElseBlock = !elseRegion.empty();
258
259 auto loweredIf =
260 rewriter.create<emitc::IfOp>(location: loc, args: adaptor.getCondition(), args: false, args: false);
261
262 Region &loweredThenRegion = loweredIf.getThenRegion();
263 auto result = lowerRegion(thenRegion, loweredThenRegion);
264 if (failed(Result: result)) {
265 return result;
266 }
267
268 if (hasElseBlock) {
269 Region &loweredElseRegion = loweredIf.getElseRegion();
270 auto result = lowerRegion(elseRegion, loweredElseRegion);
271 if (failed(Result: result)) {
272 return result;
273 }
274 }
275
276 rewriter.setInsertionPointAfter(ifOp);
277 SmallVector<Value> results = loadValues(variables: resultVariables, rewriter, loc);
278
279 rewriter.replaceOp(op: ifOp, newValues: results);
280 return success();
281}
282
283// Lower scf::index_switch to emitc::switch, implementing result values as
284// emitc::variable's updated within the case and default regions.
285struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
286 using OpConversionPattern::OpConversionPattern;
287
288 LogicalResult
289 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
290 ConversionPatternRewriter &rewriter) const override;
291};
292
293LogicalResult IndexSwitchOpLowering::matchAndRewrite(
294 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const {
296 Location loc = indexSwitchOp.getLoc();
297
298 // Create an emitc::variable op for each result. These variables will be
299 // assigned to by emitc::assign ops within the case and default regions.
300 SmallVector<Value> resultVariables;
301 if (failed(Result: createVariablesForResults(op: indexSwitchOp, typeConverter: getTypeConverter(),
302 rewriter, resultVariables))) {
303 return rewriter.notifyMatchFailure(arg&: indexSwitchOp,
304 msg: "create variables for results failed");
305 }
306
307 auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
308 location: loc, args: adaptor.getArg(), args: adaptor.getCases(), args: indexSwitchOp.getNumCases());
309
310 // Lowering all case regions.
311 for (auto pair :
312 llvm::zip(t: adaptor.getCaseRegions(), u: loweredSwitch.getCaseRegions())) {
313 if (failed(Result: lowerRegion(op: indexSwitchOp, resultVariables, rewriter,
314 region&: *std::get<0>(t&: pair), loweredRegion&: std::get<1>(t&: pair)))) {
315 return failure();
316 }
317 }
318
319 // Lowering default region.
320 if (failed(Result: lowerRegion(op: indexSwitchOp, resultVariables, rewriter,
321 region&: adaptor.getDefaultRegion(),
322 loweredRegion&: loweredSwitch.getDefaultRegion()))) {
323 return failure();
324 }
325
326 rewriter.setInsertionPointAfter(indexSwitchOp);
327 SmallVector<Value> results = loadValues(variables: resultVariables, rewriter, loc);
328
329 rewriter.replaceOp(op: indexSwitchOp, newValues: results);
330 return success();
331}
332
333void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
334 TypeConverter &typeConverter) {
335 patterns.add<ForLowering>(arg&: typeConverter, args: patterns.getContext());
336 patterns.add<IfLowering>(arg&: typeConverter, args: patterns.getContext());
337 patterns.add<IndexSwitchOpLowering>(arg&: typeConverter, args: patterns.getContext());
338}
339
340void SCFToEmitCPass::runOnOperation() {
341 RewritePatternSet patterns(&getContext());
342 TypeConverter typeConverter;
343 // Fallback for other types.
344 typeConverter.addConversion(callback: [](Type type) -> std::optional<Type> {
345 if (!emitc::isSupportedEmitCType(type))
346 return {};
347 return type;
348 });
349 populateEmitCSizeTTypeConversions(converter&: typeConverter);
350 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
351
352 // Configure conversion to lower out SCF operations.
353 ConversionTarget target(getContext());
354 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
355 target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; });
356 if (failed(
357 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
358 signalPassFailure();
359}
360

source code of mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp