1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert scf.for and scf.if ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/EmitC/IR/EmitC.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/MLIRContext.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "mlir/Transforms/Passes.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::scf;
33
34namespace {
35
36struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
37 void runOnOperation() override;
38};
39
40// Lower scf::for to emitc::for, implementing result values using
41// emitc::variable's updated within the loop body.
42struct ForLowering : public OpRewritePattern<ForOp> {
43 using OpRewritePattern<ForOp>::OpRewritePattern;
44
45 LogicalResult matchAndRewrite(ForOp forOp,
46 PatternRewriter &rewriter) const override;
47};
48
49// Create an uninitialized emitc::variable op for each result of the given op.
50template <typename T>
51static SmallVector<Value> createVariablesForResults(T op,
52 PatternRewriter &rewriter) {
53 SmallVector<Value> resultVariables;
54
55 if (!op.getNumResults())
56 return resultVariables;
57
58 Location loc = op->getLoc();
59 MLIRContext *context = op.getContext();
60
61 OpBuilder::InsertionGuard guard(rewriter);
62 rewriter.setInsertionPoint(op);
63
64 for (OpResult result : op.getResults()) {
65 Type resultType = result.getType();
66 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
67 emitc::VariableOp var =
68 rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69 resultVariables.push_back(var);
70 }
71
72 return resultVariables;
73}
74
75// Create a series of assign ops assigning given values to given variables at
76// the current insertion point of given rewriter.
77static void assignValues(ValueRange values, SmallVector<Value> &variables,
78 PatternRewriter &rewriter, Location loc) {
79 for (auto [value, var] : llvm::zip(values, variables))
80 rewriter.create<emitc::AssignOp>(loc, var, value);
81}
82
83static void lowerYield(SmallVector<Value> &resultVariables,
84 PatternRewriter &rewriter, scf::YieldOp yield) {
85 Location loc = yield.getLoc();
86 ValueRange operands = yield.getOperands();
87
88 OpBuilder::InsertionGuard guard(rewriter);
89 rewriter.setInsertionPoint(yield);
90
91 assignValues(operands, resultVariables, rewriter, loc);
92
93 rewriter.create<emitc::YieldOp>(loc);
94 rewriter.eraseOp(op: yield);
95}
96
97LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
98 PatternRewriter &rewriter) const {
99 Location loc = forOp.getLoc();
100
101 // Create an emitc::variable op for each result. These variables will be
102 // assigned to by emitc::assign ops within the loop body.
103 SmallVector<Value> resultVariables =
104 createVariablesForResults(forOp, rewriter);
105 SmallVector<Value> iterArgsVariables =
106 createVariablesForResults(forOp, rewriter);
107
108 assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
109
110 emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
111 loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
112
113 Block *loweredBody = loweredFor.getBody();
114
115 // Erase the auto-generated terminator for the lowered for op.
116 rewriter.eraseOp(op: loweredBody->getTerminator());
117
118 SmallVector<Value> replacingValues;
119 replacingValues.push_back(loweredFor.getInductionVar());
120 replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
121
122 rewriter.mergeBlocks(source: forOp.getBody(), dest: loweredBody, argValues: replacingValues);
123 lowerYield(iterArgsVariables, rewriter,
124 cast<scf::YieldOp>(loweredBody->getTerminator()));
125
126 // Copy iterArgs into results after the for loop.
127 assignValues(iterArgsVariables, resultVariables, rewriter, loc);
128
129 rewriter.replaceOp(forOp, resultVariables);
130 return success();
131}
132
133// Lower scf::if to emitc::if, implementing result values as emitc::variable's
134// updated within the then and else regions.
135struct IfLowering : public OpRewritePattern<IfOp> {
136 using OpRewritePattern<IfOp>::OpRewritePattern;
137
138 LogicalResult matchAndRewrite(IfOp ifOp,
139 PatternRewriter &rewriter) const override;
140};
141
142} // namespace
143
144LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
145 PatternRewriter &rewriter) const {
146 Location loc = ifOp.getLoc();
147
148 // Create an emitc::variable op for each result. These variables will be
149 // assigned to by emitc::assign ops within the then & else regions.
150 SmallVector<Value> resultVariables =
151 createVariablesForResults(ifOp, rewriter);
152
153 // Utility function to lower the contents of an scf::if region to an emitc::if
154 // region. The contents of the scf::if regions is moved into the respective
155 // emitc::if regions, but the scf::yield is replaced not only with an
156 // emitc::yield, but also with a sequence of emitc::assign ops that set the
157 // yielded values into the result variables.
158 auto lowerRegion = [&resultVariables, &rewriter](Region &region,
159 Region &loweredRegion) {
160 rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end());
161 Operation *terminator = loweredRegion.back().getTerminator();
162 lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
163 };
164
165 Region &thenRegion = ifOp.getThenRegion();
166 Region &elseRegion = ifOp.getElseRegion();
167
168 bool hasElseBlock = !elseRegion.empty();
169
170 auto loweredIf =
171 rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
172
173 Region &loweredThenRegion = loweredIf.getThenRegion();
174 lowerRegion(thenRegion, loweredThenRegion);
175
176 if (hasElseBlock) {
177 Region &loweredElseRegion = loweredIf.getElseRegion();
178 lowerRegion(elseRegion, loweredElseRegion);
179 }
180
181 rewriter.replaceOp(ifOp, resultVariables);
182 return success();
183}
184
185void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
186 patterns.add<ForLowering>(arg: patterns.getContext());
187 patterns.add<IfLowering>(arg: patterns.getContext());
188}
189
190void SCFToEmitCPass::runOnOperation() {
191 RewritePatternSet patterns(&getContext());
192 populateSCFToEmitCConversionPatterns(patterns);
193
194 // Configure conversion to lower out SCF operations.
195 ConversionTarget target(getContext());
196 target.addIllegalOp<scf::ForOp, scf::IfOp>();
197 target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; });
198 if (failed(
199 applyPartialConversion(getOperation(), target, std::move(patterns))))
200 signalPassFailure();
201}
202

source code of mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp