1 | //===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===// |
2 | // |
3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// The patterns in this file are heavily inspired by (and copied from)
// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
// type conversions.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
16 | |
17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
18 | #include "mlir/Transforms/OneToNTypeConversion.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::scf; |
22 | |
23 | class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> { |
24 | public: |
25 | using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern; |
26 | |
27 | LogicalResult |
28 | matchAndRewrite(IfOp op, OpAdaptor adaptor, |
29 | OneToNPatternRewriter &rewriter) const override { |
30 | Location loc = op->getLoc(); |
31 | const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); |
32 | |
33 | // Nothing to do if there is no non-identity conversion. |
34 | if (!resultMapping.hasNonIdentityConversion()) |
35 | return failure(); |
36 | |
37 | // Create new IfOp. |
38 | TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); |
39 | auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes, |
40 | op.getCondition(), true); |
41 | newOp->setAttrs(op->getAttrs()); |
42 | |
43 | // We do not need the empty blocks created by rewriter. |
44 | rewriter.eraseBlock(block: newOp.elseBlock()); |
45 | rewriter.eraseBlock(block: newOp.thenBlock()); |
46 | |
47 | // Inlines block from the original operation. |
48 | rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), |
49 | newOp.getThenRegion().end()); |
50 | rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), |
51 | newOp.getElseRegion().end()); |
52 | |
53 | rewriter.replaceOp(op, newOp->getResults(), resultMapping); |
54 | return success(); |
55 | } |
56 | }; |
57 | |
58 | class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> { |
59 | public: |
60 | using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern; |
61 | |
62 | LogicalResult |
63 | matchAndRewrite(WhileOp op, OpAdaptor adaptor, |
64 | OneToNPatternRewriter &rewriter) const override { |
65 | Location loc = op->getLoc(); |
66 | |
67 | const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); |
68 | const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); |
69 | |
70 | // Nothing to do if the op doesn't have any non-identity conversions for its |
71 | // operands or results. |
72 | if (!operandMapping.hasNonIdentityConversion() && |
73 | !resultMapping.hasNonIdentityConversion()) |
74 | return failure(); |
75 | |
76 | // Create new WhileOp. |
77 | TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); |
78 | |
79 | auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes, |
80 | adaptor.getFlatOperands()); |
81 | newOp->setAttrs(op->getAttrs()); |
82 | |
83 | // Update block signatures. |
84 | std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping, |
85 | resultMapping}; |
86 | for (unsigned int i : {0u, 1u}) { |
87 | Region *region = &op.getRegion(i); |
88 | Block *block = ®ion->front(); |
89 | |
90 | rewriter.applySignatureConversion(block, argumentConversion&: blockMappings[i]); |
91 | |
92 | // Move updated region to new WhileOp. |
93 | Region &dstRegion = newOp.getRegion(i); |
94 | rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); |
95 | } |
96 | |
97 | rewriter.replaceOp(op, newOp->getResults(), resultMapping); |
98 | return success(); |
99 | } |
100 | }; |
101 | |
102 | class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> { |
103 | public: |
104 | using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern; |
105 | |
106 | LogicalResult |
107 | matchAndRewrite(YieldOp op, OpAdaptor adaptor, |
108 | OneToNPatternRewriter &rewriter) const override { |
109 | // Nothing to do if there is no non-identity conversion. |
110 | if (!adaptor.getOperandMapping().hasNonIdentityConversion()) |
111 | return failure(); |
112 | |
113 | // Convert operands. |
114 | rewriter.modifyOpInPlace( |
115 | op, [&] { op->setOperands(adaptor.getFlatOperands()); }); |
116 | |
117 | return success(); |
118 | } |
119 | }; |
120 | |
121 | class ConvertTypesInSCFConditionOp |
122 | : public OneToNOpConversionPattern<ConditionOp> { |
123 | public: |
124 | using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern; |
125 | |
126 | LogicalResult |
127 | matchAndRewrite(ConditionOp op, OpAdaptor adaptor, |
128 | OneToNPatternRewriter &rewriter) const override { |
129 | // Nothing to do if there is no non-identity conversion. |
130 | if (!adaptor.getOperandMapping().hasNonIdentityConversion()) |
131 | return failure(); |
132 | |
133 | // Convert operands. |
134 | rewriter.modifyOpInPlace( |
135 | op, [&] { op->setOperands(adaptor.getFlatOperands()); }); |
136 | |
137 | return success(); |
138 | } |
139 | }; |
140 | |
141 | class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> { |
142 | public: |
143 | using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern; |
144 | |
145 | LogicalResult |
146 | matchAndRewrite(ForOp forOp, OpAdaptor adaptor, |
147 | OneToNPatternRewriter &rewriter) const override { |
148 | const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); |
149 | const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); |
150 | |
151 | // Nothing to do if there is no non-identity conversion. |
152 | if (!operandMapping.hasNonIdentityConversion() && |
153 | !resultMapping.hasNonIdentityConversion()) |
154 | return failure(); |
155 | |
156 | // If the lower-bound, upper-bound, or step were expanded, abort the |
157 | // conversion. This conversion does not know what to do in such cases. |
158 | ValueRange lbs = adaptor.getLowerBound(); |
159 | ValueRange ubs = adaptor.getUpperBound(); |
160 | ValueRange steps = adaptor.getStep(); |
161 | if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1) |
162 | return rewriter.notifyMatchFailure( |
163 | forOp, "index operands converted to multiple values" ); |
164 | |
165 | Location loc = forOp.getLoc(); |
166 | |
167 | Region *region = &forOp.getRegion(); |
168 | Block *block = ®ion->front(); |
169 | |
170 | // Construct the new for-op with an empty body. |
171 | ValueRange newInits = adaptor.getFlatOperands().drop_front(3); |
172 | auto newOp = |
173 | rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits); |
174 | newOp->setAttrs(forOp->getAttrs()); |
175 | |
176 | // We do not need the empty blocks created by rewriter. |
177 | rewriter.eraseBlock(block: newOp.getBody()); |
178 | |
179 | // Convert the signature of the body region. |
180 | OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes()); |
181 | if (failed(result: typeConverter->convertSignatureArgs(types: block->getArgumentTypes(), |
182 | result&: bodyTypeMapping))) |
183 | return failure(); |
184 | |
185 | // Perform signature conversion on the body block. |
186 | rewriter.applySignatureConversion(block, argumentConversion&: bodyTypeMapping); |
187 | |
188 | // Splice the old body region into the new for-op. |
189 | Region &dstRegion = newOp.getBodyRegion(); |
190 | rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end()); |
191 | |
192 | rewriter.replaceOp(forOp, newOp.getResults(), resultMapping); |
193 | |
194 | return success(); |
195 | } |
196 | }; |
197 | |
198 | namespace mlir { |
199 | namespace scf { |
200 | |
201 | void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter, |
202 | RewritePatternSet &patterns) { |
203 | patterns.add< |
204 | // clang-format off |
205 | ConvertTypesInSCFConditionOp, |
206 | ConvertTypesInSCFForOp, |
207 | ConvertTypesInSCFIfOp, |
208 | ConvertTypesInSCFWhileOp, |
209 | ConvertTypesInSCFYieldOp |
210 | // clang-format on |
211 | >(arg&: typeConverter, args: patterns.getContext()); |
212 | } |
213 | |
214 | } // namespace scf |
215 | } // namespace mlir |
216 | |