1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
14
15#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/EmitC/IR/EmitC.h"
18#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/MLIRContext.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "mlir/Transforms/Passes.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_SCFTOEMITC
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace mlir::scf;
35
36namespace {
37
38/// Implement the interface to convert SCF to EmitC.
39struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
40 using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface;
41
42 /// Hook for derived dialect interface to provide conversion patterns
43 /// and mark dialect legal for the conversion target.
44 void populateConvertToEmitCConversionPatterns(
45 ConversionTarget &target, TypeConverter &typeConverter,
46 RewritePatternSet &patterns) const final {
47 populateEmitCSizeTTypeConversions(converter&: typeConverter);
48 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
49 }
50};
51} // namespace
52
53void mlir::registerConvertSCFToEmitCInterface(DialectRegistry &registry) {
54 registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) {
55 dialect->addInterfaces<SCFToEmitCDialectInterface>();
56 });
57}
58
59namespace {
60
61struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
62 void runOnOperation() override;
63};
64
65// Lower scf::for to emitc::for, implementing result values using
66// emitc::variable's updated within the loop body.
67struct ForLowering : public OpConversionPattern<ForOp> {
68 using OpConversionPattern<ForOp>::OpConversionPattern;
69
70 LogicalResult
71 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override;
73};
74
75// Create an uninitialized emitc::variable op for each result of the given op.
76template <typename T>
77static LogicalResult
78createVariablesForResults(T op, const TypeConverter *typeConverter,
79 ConversionPatternRewriter &rewriter,
80 SmallVector<Value> &resultVariables) {
81 if (!op.getNumResults())
82 return success();
83
84 Location loc = op->getLoc();
85 MLIRContext *context = op.getContext();
86
87 OpBuilder::InsertionGuard guard(rewriter);
88 rewriter.setInsertionPoint(op);
89
90 for (OpResult result : op.getResults()) {
91 Type resultType = typeConverter->convertType(t: result.getType());
92 if (!resultType)
93 return rewriter.notifyMatchFailure(op, "result type conversion failed");
94 Type varType = emitc::LValueType::get(resultType);
95 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
96 emitc::VariableOp var =
97 rewriter.create<emitc::VariableOp>(loc, varType, noInit);
98 resultVariables.push_back(Elt: var);
99 }
100
101 return success();
102}
103
104// Create a series of assign ops assigning given values to given variables at
105// the current insertion point of given rewriter.
106static void assignValues(ValueRange values, ValueRange variables,
107 ConversionPatternRewriter &rewriter, Location loc) {
108 for (auto [value, var] : llvm::zip(values, variables))
109 rewriter.create<emitc::AssignOp>(loc, var, value);
110}
111
112SmallVector<Value> loadValues(const SmallVector<Value> &variables,
113 PatternRewriter &rewriter, Location loc) {
114 return llvm::map_to_vector<>(variables, [&](Value var) {
115 Type type = cast<emitc::LValueType>(var.getType()).getValueType();
116 return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
117 });
118}
119
120static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
121 ConversionPatternRewriter &rewriter,
122 scf::YieldOp yield) {
123 Location loc = yield.getLoc();
124
125 OpBuilder::InsertionGuard guard(rewriter);
126 rewriter.setInsertionPoint(yield);
127
128 SmallVector<Value> yieldOperands;
129 if (failed(rewriter.getRemappedValues(keys: yield.getOperands(), results&: yieldOperands))) {
130 return rewriter.notifyMatchFailure(arg&: op, msg: "failed to lower yield operands");
131 }
132
133 assignValues(values: yieldOperands, variables: resultVariables, rewriter, loc);
134
135 rewriter.create<emitc::YieldOp>(loc);
136 rewriter.eraseOp(op: yield);
137
138 return success();
139}
140
141// Lower the contents of an scf::if/scf::index_switch regions to an
142// emitc::if/emitc::switch region. The contents of the lowering region is
143// moved into the respective lowered region, but the scf::yield is replaced not
144// only with an emitc::yield, but also with a sequence of emitc::assign ops that
145// set the yielded values into the result variables.
146static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
147 ConversionPatternRewriter &rewriter,
148 Region &region, Region &loweredRegion) {
149 rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end());
150 Operation *terminator = loweredRegion.back().getTerminator();
151 return lowerYield(op, resultVariables, rewriter,
152 cast<scf::YieldOp>(terminator));
153}
154
155LogicalResult
156ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const {
158 Location loc = forOp.getLoc();
159
160 // Create an emitc::variable op for each result. These variables will be
161 // assigned to by emitc::assign ops within the loop body.
162 SmallVector<Value> resultVariables;
163 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
164 resultVariables)))
165 return rewriter.notifyMatchFailure(forOp,
166 "create variables for results failed");
167
168 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
169
170 emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
171 loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
172
173 Block *loweredBody = loweredFor.getBody();
174
175 // Erase the auto-generated terminator for the lowered for op.
176 rewriter.eraseOp(op: loweredBody->getTerminator());
177
178 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
179 rewriter.setInsertionPointToEnd(loweredBody);
180
181 SmallVector<Value> iterArgsValues =
182 loadValues(variables: resultVariables, rewriter, loc);
183
184 rewriter.restoreInsertionPoint(ip);
185
186 // Convert the original region types into the new types by adding unrealized
187 // casts in the beginning of the loop. This performs the conversion in place.
188 if (failed(rewriter.convertRegionTypes(region: &forOp.getRegion(),
189 converter: *getTypeConverter(), entryConversion: nullptr))) {
190 return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
191 }
192
193 // Register the replacements for the block arguments and inline the body of
194 // the scf.for loop into the body of the emitc::for loop.
195 Block *scfBody = &(forOp.getRegion().front());
196 SmallVector<Value> replacingValues;
197 replacingValues.push_back(Elt: loweredFor.getInductionVar());
198 replacingValues.append(in_start: iterArgsValues.begin(), in_end: iterArgsValues.end());
199 rewriter.mergeBlocks(source: scfBody, dest: loweredBody, argValues: replacingValues);
200
201 auto result = lowerYield(forOp, resultVariables, rewriter,
202 cast<scf::YieldOp>(loweredBody->getTerminator()));
203
204 if (failed(result)) {
205 return result;
206 }
207
208 // Load variables into SSA values after the for loop.
209 SmallVector<Value> resultValues = loadValues(variables: resultVariables, rewriter, loc);
210
211 rewriter.replaceOp(forOp, resultValues);
212 return success();
213}
214
215// Lower scf::if to emitc::if, implementing result values as emitc::variable's
216// updated within the then and else regions.
217struct IfLowering : public OpConversionPattern<IfOp> {
218 using OpConversionPattern<IfOp>::OpConversionPattern;
219
220 LogicalResult
221 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override;
223};
224
225} // namespace
226
227LogicalResult
228IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const {
230 Location loc = ifOp.getLoc();
231
232 // Create an emitc::variable op for each result. These variables will be
233 // assigned to by emitc::assign ops within the then & else regions.
234 SmallVector<Value> resultVariables;
235 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
236 resultVariables)))
237 return rewriter.notifyMatchFailure(ifOp,
238 "create variables for results failed");
239
240 // Utility function to lower the contents of an scf::if region to an emitc::if
241 // region. The contents of the scf::if regions is moved into the respective
242 // emitc::if regions, but the scf::yield is replaced not only with an
243 // emitc::yield, but also with a sequence of emitc::assign ops that set the
244 // yielded values into the result variables.
245 auto lowerRegion = [&resultVariables, &rewriter,
246 &ifOp](Region &region, Region &loweredRegion) {
247 rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end());
248 Operation *terminator = loweredRegion.back().getTerminator();
249 auto result = lowerYield(ifOp, resultVariables, rewriter,
250 cast<scf::YieldOp>(terminator));
251 if (failed(result)) {
252 return result;
253 }
254 return success();
255 };
256
257 Region &thenRegion = adaptor.getThenRegion();
258 Region &elseRegion = adaptor.getElseRegion();
259
260 bool hasElseBlock = !elseRegion.empty();
261
262 auto loweredIf =
263 rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
264
265 Region &loweredThenRegion = loweredIf.getThenRegion();
266 auto result = lowerRegion(thenRegion, loweredThenRegion);
267 if (failed(result)) {
268 return result;
269 }
270
271 if (hasElseBlock) {
272 Region &loweredElseRegion = loweredIf.getElseRegion();
273 auto result = lowerRegion(elseRegion, loweredElseRegion);
274 if (failed(result)) {
275 return result;
276 }
277 }
278
279 rewriter.setInsertionPointAfter(ifOp);
280 SmallVector<Value> results = loadValues(variables: resultVariables, rewriter, loc);
281
282 rewriter.replaceOp(ifOp, results);
283 return success();
284}
285
286// Lower scf::index_switch to emitc::switch, implementing result values as
287// emitc::variable's updated within the case and default regions.
288struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
289 using OpConversionPattern::OpConversionPattern;
290
291 LogicalResult
292 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
293 ConversionPatternRewriter &rewriter) const override;
294};
295
296LogicalResult IndexSwitchOpLowering::matchAndRewrite(
297 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter) const {
299 Location loc = indexSwitchOp.getLoc();
300
301 // Create an emitc::variable op for each result. These variables will be
302 // assigned to by emitc::assign ops within the case and default regions.
303 SmallVector<Value> resultVariables;
304 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
305 rewriter, resultVariables))) {
306 return rewriter.notifyMatchFailure(indexSwitchOp,
307 "create variables for results failed");
308 }
309
310 auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
311 loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
312
313 // Lowering all case regions.
314 for (auto pair :
315 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
316 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
317 *std::get<0>(pair), std::get<1>(pair)))) {
318 return failure();
319 }
320 }
321
322 // Lowering default region.
323 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
324 adaptor.getDefaultRegion(),
325 loweredSwitch.getDefaultRegion()))) {
326 return failure();
327 }
328
329 rewriter.setInsertionPointAfter(indexSwitchOp);
330 SmallVector<Value> results = loadValues(variables: resultVariables, rewriter, loc);
331
332 rewriter.replaceOp(indexSwitchOp, results);
333 return success();
334}
335
336void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
337 TypeConverter &typeConverter) {
338 patterns.add<ForLowering>(arg&: typeConverter, args: patterns.getContext());
339 patterns.add<IfLowering>(arg&: typeConverter, args: patterns.getContext());
340 patterns.add<IndexSwitchOpLowering>(arg&: typeConverter, args: patterns.getContext());
341}
342
343void SCFToEmitCPass::runOnOperation() {
344 RewritePatternSet patterns(&getContext());
345 TypeConverter typeConverter;
346 // Fallback for other types.
347 typeConverter.addConversion(callback: [](Type type) -> std::optional<Type> {
348 if (!emitc::isSupportedEmitCType(type))
349 return {};
350 return type;
351 });
352 populateEmitCSizeTTypeConversions(converter&: typeConverter);
353 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
354
355 // Configure conversion to lower out SCF operations.
356 ConversionTarget target(getContext());
357 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
358 target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; });
359 if (failed(
360 applyPartialConversion(getOperation(), target, std::move(patterns))))
361 signalPassFailure();
362}
363

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp