1 | //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a pass to convert scf.for and scf.if ops into emitc ops.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/BuiltinOps.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/MLIRContext.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Transforms/DialectConversion.h" |
24 | #include "mlir/Transforms/Passes.h" |
25 | |
26 | namespace mlir { |
27 | #define GEN_PASS_DEF_SCFTOEMITC |
28 | #include "mlir/Conversion/Passes.h.inc" |
29 | } // namespace mlir |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::scf; |
33 | |
34 | namespace { |
35 | |
36 | struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> { |
37 | void runOnOperation() override; |
38 | }; |
39 | |
40 | // Lower scf::for to emitc::for, implementing result values using |
41 | // emitc::variable's updated within the loop body. |
42 | struct ForLowering : public OpRewritePattern<ForOp> { |
43 | using OpRewritePattern<ForOp>::OpRewritePattern; |
44 | |
45 | LogicalResult matchAndRewrite(ForOp forOp, |
46 | PatternRewriter &rewriter) const override; |
47 | }; |
48 | |
49 | // Create an uninitialized emitc::variable op for each result of the given op. |
50 | template <typename T> |
51 | static SmallVector<Value> createVariablesForResults(T op, |
52 | PatternRewriter &rewriter) { |
53 | SmallVector<Value> resultVariables; |
54 | |
55 | if (!op.getNumResults()) |
56 | return resultVariables; |
57 | |
58 | Location loc = op->getLoc(); |
59 | MLIRContext *context = op.getContext(); |
60 | |
61 | OpBuilder::InsertionGuard guard(rewriter); |
62 | rewriter.setInsertionPoint(op); |
63 | |
64 | for (OpResult result : op.getResults()) { |
65 | Type resultType = result.getType(); |
66 | emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "" ); |
67 | emitc::VariableOp var = |
68 | rewriter.create<emitc::VariableOp>(loc, resultType, noInit); |
69 | resultVariables.push_back(var); |
70 | } |
71 | |
72 | return resultVariables; |
73 | } |
74 | |
75 | // Create a series of assign ops assigning given values to given variables at |
76 | // the current insertion point of given rewriter. |
77 | static void assignValues(ValueRange values, SmallVector<Value> &variables, |
78 | PatternRewriter &rewriter, Location loc) { |
79 | for (auto [value, var] : llvm::zip(values, variables)) |
80 | rewriter.create<emitc::AssignOp>(loc, var, value); |
81 | } |
82 | |
83 | static void lowerYield(SmallVector<Value> &resultVariables, |
84 | PatternRewriter &rewriter, scf::YieldOp yield) { |
85 | Location loc = yield.getLoc(); |
86 | ValueRange operands = yield.getOperands(); |
87 | |
88 | OpBuilder::InsertionGuard guard(rewriter); |
89 | rewriter.setInsertionPoint(yield); |
90 | |
91 | assignValues(operands, resultVariables, rewriter, loc); |
92 | |
93 | rewriter.create<emitc::YieldOp>(loc); |
94 | rewriter.eraseOp(op: yield); |
95 | } |
96 | |
97 | LogicalResult ForLowering::matchAndRewrite(ForOp forOp, |
98 | PatternRewriter &rewriter) const { |
99 | Location loc = forOp.getLoc(); |
100 | |
101 | // Create an emitc::variable op for each result. These variables will be |
102 | // assigned to by emitc::assign ops within the loop body. |
103 | SmallVector<Value> resultVariables = |
104 | createVariablesForResults(forOp, rewriter); |
105 | SmallVector<Value> iterArgsVariables = |
106 | createVariablesForResults(forOp, rewriter); |
107 | |
108 | assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc); |
109 | |
110 | emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>( |
111 | loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); |
112 | |
113 | Block *loweredBody = loweredFor.getBody(); |
114 | |
115 | // Erase the auto-generated terminator for the lowered for op. |
116 | rewriter.eraseOp(op: loweredBody->getTerminator()); |
117 | |
118 | SmallVector<Value> replacingValues; |
119 | replacingValues.push_back(loweredFor.getInductionVar()); |
120 | replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end()); |
121 | |
122 | rewriter.mergeBlocks(source: forOp.getBody(), dest: loweredBody, argValues: replacingValues); |
123 | lowerYield(iterArgsVariables, rewriter, |
124 | cast<scf::YieldOp>(loweredBody->getTerminator())); |
125 | |
126 | // Copy iterArgs into results after the for loop. |
127 | assignValues(iterArgsVariables, resultVariables, rewriter, loc); |
128 | |
129 | rewriter.replaceOp(forOp, resultVariables); |
130 | return success(); |
131 | } |
132 | |
133 | // Lower scf::if to emitc::if, implementing result values as emitc::variable's |
134 | // updated within the then and else regions. |
135 | struct IfLowering : public OpRewritePattern<IfOp> { |
136 | using OpRewritePattern<IfOp>::OpRewritePattern; |
137 | |
138 | LogicalResult matchAndRewrite(IfOp ifOp, |
139 | PatternRewriter &rewriter) const override; |
140 | }; |
141 | |
142 | } // namespace |
143 | |
144 | LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, |
145 | PatternRewriter &rewriter) const { |
146 | Location loc = ifOp.getLoc(); |
147 | |
148 | // Create an emitc::variable op for each result. These variables will be |
149 | // assigned to by emitc::assign ops within the then & else regions. |
150 | SmallVector<Value> resultVariables = |
151 | createVariablesForResults(ifOp, rewriter); |
152 | |
153 | // Utility function to lower the contents of an scf::if region to an emitc::if |
154 | // region. The contents of the scf::if regions is moved into the respective |
155 | // emitc::if regions, but the scf::yield is replaced not only with an |
156 | // emitc::yield, but also with a sequence of emitc::assign ops that set the |
157 | // yielded values into the result variables. |
158 | auto lowerRegion = [&resultVariables, &rewriter](Region ®ion, |
159 | Region &loweredRegion) { |
160 | rewriter.inlineRegionBefore(region, parent&: loweredRegion, before: loweredRegion.end()); |
161 | Operation *terminator = loweredRegion.back().getTerminator(); |
162 | lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator)); |
163 | }; |
164 | |
165 | Region &thenRegion = ifOp.getThenRegion(); |
166 | Region &elseRegion = ifOp.getElseRegion(); |
167 | |
168 | bool hasElseBlock = !elseRegion.empty(); |
169 | |
170 | auto loweredIf = |
171 | rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false); |
172 | |
173 | Region &loweredThenRegion = loweredIf.getThenRegion(); |
174 | lowerRegion(thenRegion, loweredThenRegion); |
175 | |
176 | if (hasElseBlock) { |
177 | Region &loweredElseRegion = loweredIf.getElseRegion(); |
178 | lowerRegion(elseRegion, loweredElseRegion); |
179 | } |
180 | |
181 | rewriter.replaceOp(ifOp, resultVariables); |
182 | return success(); |
183 | } |
184 | |
185 | void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) { |
186 | patterns.add<ForLowering>(arg: patterns.getContext()); |
187 | patterns.add<IfLowering>(arg: patterns.getContext()); |
188 | } |
189 | |
190 | void SCFToEmitCPass::runOnOperation() { |
191 | RewritePatternSet patterns(&getContext()); |
192 | populateSCFToEmitCConversionPatterns(patterns); |
193 | |
194 | // Configure conversion to lower out SCF operations. |
195 | ConversionTarget target(getContext()); |
196 | target.addIllegalOp<scf::ForOp, scf::IfOp>(); |
197 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; }); |
198 | if (failed( |
199 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
200 | signalPassFailure(); |
201 | } |
202 | |