1 | //===- StructuralTypeConversions.cpp - scf structural type conversions ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SCF/IR/SCF.h" |
10 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
11 | #include "mlir/Transforms/DialectConversion.h" |
12 | #include <optional> |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::scf; |
16 | |
17 | namespace { |
18 | |
19 | // Unpacks the single unrealized_conversion_cast using the list of inputs |
20 | // e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d) |
21 | static void unpackUnrealizedConversionCast(Value v, |
22 | SmallVectorImpl<Value> &unpacked) { |
23 | if (auto cast = |
24 | dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) { |
25 | if (cast.getInputs().size() != 1) { |
26 | // 1 : N type conversion. |
27 | unpacked.append(cast.getInputs().begin(), cast.getInputs().end()); |
28 | return; |
29 | } |
30 | } |
31 | // 1 : 1 type conversion. |
32 | unpacked.push_back(Elt: v); |
33 | } |
34 | |
35 | // CRTP |
36 | // A base class that takes care of 1:N type conversion, which maps the converted |
37 | // op results (computed by the derived class) and materializes 1:N conversion. |
38 | template <typename SourceOp, typename ConcretePattern> |
39 | class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> { |
40 | public: |
41 | using OpConversionPattern<SourceOp>::typeConverter; |
42 | using OpConversionPattern<SourceOp>::OpConversionPattern; |
43 | using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor; |
44 | |
45 | // |
46 | // Derived classes should provide the following method which performs the |
47 | // actual conversion. It should return std::nullopt upon conversion failure |
48 | // and return the converted operation upon success. |
49 | // |
50 | // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor, |
51 | // ConversionPatternRewriter &rewriter, |
52 | // TypeRange dstTypes) const; |
53 | |
54 | LogicalResult |
55 | matchAndRewrite(SourceOp op, OpAdaptor adaptor, |
56 | ConversionPatternRewriter &rewriter) const override { |
57 | SmallVector<Type> dstTypes; |
58 | SmallVector<unsigned> offsets; |
59 | offsets.push_back(Elt: 0); |
60 | // Do the type conversion and record the offsets. |
61 | for (Type type : op.getResultTypes()) { |
62 | if (failed(typeConverter->convertTypes(type, dstTypes))) |
63 | return rewriter.notifyMatchFailure(op, "could not convert result type" ); |
64 | offsets.push_back(Elt: dstTypes.size()); |
65 | } |
66 | |
67 | // Calls the actual converter implementation to convert the operation. |
68 | std::optional<SourceOp> newOp = |
69 | static_cast<const ConcretePattern *>(this)->convertSourceOp( |
70 | op, adaptor, rewriter, dstTypes); |
71 | |
72 | if (!newOp) |
73 | return rewriter.notifyMatchFailure(op, "could not convert operation" ); |
74 | |
75 | // Packs the return value. |
76 | SmallVector<Value> packedRets; |
77 | for (unsigned i = 1, e = offsets.size(); i < e; i++) { |
78 | unsigned start = offsets[i - 1], end = offsets[i]; |
79 | unsigned len = end - start; |
80 | ValueRange mappedValue = newOp->getResults().slice(start, len); |
81 | if (len != 1) { |
82 | // 1 : N type conversion. |
83 | Type origType = op.getResultTypes()[i - 1]; |
84 | Value mat = typeConverter->materializeSourceConversion( |
85 | rewriter, op.getLoc(), origType, mappedValue); |
86 | if (!mat) { |
87 | return rewriter.notifyMatchFailure( |
88 | op, "Failed to materialize 1:N type conversion" ); |
89 | } |
90 | packedRets.push_back(Elt: mat); |
91 | } else { |
92 | // 1 : 1 type conversion. |
93 | packedRets.push_back(Elt: mappedValue.front()); |
94 | } |
95 | } |
96 | |
97 | rewriter.replaceOp(op, packedRets); |
98 | return success(); |
99 | } |
100 | }; |
101 | |
102 | class ConvertForOpTypes |
103 | : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> { |
104 | public: |
105 | using Structural1ToNConversionPattern::Structural1ToNConversionPattern; |
106 | |
107 | // The callback required by CRTP. |
108 | std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor, |
109 | ConversionPatternRewriter &rewriter, |
110 | TypeRange dstTypes) const { |
111 | // Create a empty new op and inline the regions from the old op. |
112 | // |
113 | // This is a little bit tricky. We have two concerns here: |
114 | // |
115 | // 1. We cannot update the op in place because the dialect conversion |
116 | // framework does not track type changes for ops updated in place, so it |
117 | // won't insert appropriate materializations on the changed result types. |
118 | // PR47938 tracks this issue, but it seems hard to fix. Instead, we need |
119 | // to clone the op. |
120 | // |
121 | // 2. We need to resue the original region instead of cloning it, otherwise |
122 | // the dialect conversion framework thinks that we just inserted all the |
123 | // cloned child ops. But what we want is to "take" the child regions and let |
124 | // the dialect conversion framework continue recursively into ops inside |
125 | // those regions (which are already in its worklist; inlining them into the |
126 | // new op's regions doesn't remove the child ops from the worklist). |
127 | |
128 | // convertRegionTypes already takes care of 1:N conversion. |
129 | if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter))) |
130 | return std::nullopt; |
131 | |
132 | // Unpacked the iteration arguments. |
133 | SmallVector<Value> flatArgs; |
134 | for (Value arg : adaptor.getInitArgs()) |
135 | unpackUnrealizedConversionCast(arg, flatArgs); |
136 | |
137 | // We can not do clone as the number of result types after conversion |
138 | // might be different. |
139 | ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(), |
140 | adaptor.getUpperBound(), |
141 | adaptor.getStep(), flatArgs); |
142 | |
143 | // Reserve whatever attributes in the original op. |
144 | newOp->setAttrs(op->getAttrs()); |
145 | |
146 | // We do not need the empty block created by rewriter. |
147 | rewriter.eraseBlock(block: newOp.getBody(0)); |
148 | // Inline the type converted region from the original operation. |
149 | rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), |
150 | newOp.getRegion().end()); |
151 | |
152 | return newOp; |
153 | } |
154 | }; |
155 | } // namespace |
156 | |
157 | namespace { |
158 | class ConvertIfOpTypes |
159 | : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> { |
160 | public: |
161 | using Structural1ToNConversionPattern::Structural1ToNConversionPattern; |
162 | |
163 | std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor, |
164 | ConversionPatternRewriter &rewriter, |
165 | TypeRange dstTypes) const { |
166 | |
167 | IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes, |
168 | adaptor.getCondition(), true); |
169 | newOp->setAttrs(op->getAttrs()); |
170 | |
171 | // We do not need the empty blocks created by rewriter. |
172 | rewriter.eraseBlock(block: newOp.elseBlock()); |
173 | rewriter.eraseBlock(block: newOp.thenBlock()); |
174 | |
175 | // Inlines block from the original operation. |
176 | rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), |
177 | newOp.getThenRegion().end()); |
178 | rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), |
179 | newOp.getElseRegion().end()); |
180 | |
181 | return newOp; |
182 | } |
183 | }; |
184 | } // namespace |
185 | |
186 | namespace { |
187 | class ConvertWhileOpTypes |
188 | : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> { |
189 | public: |
190 | using Structural1ToNConversionPattern::Structural1ToNConversionPattern; |
191 | |
192 | std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor, |
193 | ConversionPatternRewriter &rewriter, |
194 | TypeRange dstTypes) const { |
195 | // Unpacked the iteration arguments. |
196 | SmallVector<Value> flatArgs; |
197 | for (Value arg : adaptor.getOperands()) |
198 | unpackUnrealizedConversionCast(arg, flatArgs); |
199 | |
200 | auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs); |
201 | |
202 | for (auto i : {0u, 1u}) { |
203 | if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) |
204 | return std::nullopt; |
205 | auto &dstRegion = newOp.getRegion(i); |
206 | rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); |
207 | } |
208 | return newOp; |
209 | } |
210 | }; |
211 | } // namespace |
212 | |
213 | namespace { |
214 | // When the result types of a ForOp/IfOp get changed, the operand types of the |
215 | // corresponding yield op need to be changed. In order to trigger the |
216 | // appropriate type conversions / materializations, we need a dummy pattern. |
217 | class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> { |
218 | public: |
219 | using OpConversionPattern::OpConversionPattern; |
220 | LogicalResult |
221 | matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, |
222 | ConversionPatternRewriter &rewriter) const override { |
223 | SmallVector<Value> unpackedYield; |
224 | for (Value operand : adaptor.getOperands()) |
225 | unpackUnrealizedConversionCast(operand, unpackedYield); |
226 | |
227 | rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield); |
228 | return success(); |
229 | } |
230 | }; |
231 | } // namespace |
232 | |
233 | namespace { |
234 | class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> { |
235 | public: |
236 | using OpConversionPattern<ConditionOp>::OpConversionPattern; |
237 | LogicalResult |
238 | matchAndRewrite(ConditionOp op, OpAdaptor adaptor, |
239 | ConversionPatternRewriter &rewriter) const override { |
240 | SmallVector<Value> unpackedYield; |
241 | for (Value operand : adaptor.getOperands()) |
242 | unpackUnrealizedConversionCast(operand, unpackedYield); |
243 | |
244 | rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); }); |
245 | return success(); |
246 | } |
247 | }; |
248 | } // namespace |
249 | |
250 | void mlir::scf::populateSCFStructuralTypeConversions( |
251 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
252 | patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes, |
253 | ConvertWhileOpTypes, ConvertConditionOpTypes>( |
254 | arg&: typeConverter, args: patterns.getContext()); |
255 | } |
256 | |
257 | void mlir::scf::populateSCFStructuralTypeConversionTarget( |
258 | const TypeConverter &typeConverter, ConversionTarget &target) { |
259 | target.addDynamicallyLegalOp<ForOp, IfOp>(callback: [&](Operation *op) { |
260 | return typeConverter.isLegal(range: op->getResultTypes()); |
261 | }); |
262 | target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) { |
263 | // We only have conversions for a subset of ops that use scf.yield |
264 | // terminators. |
265 | if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp())) |
266 | return true; |
267 | return typeConverter.isLegal(op.getOperandTypes()); |
268 | }); |
269 | target.addDynamicallyLegalOp<WhileOp, ConditionOp>( |
270 | [&](Operation *op) { return typeConverter.isLegal(op); }); |
271 | } |
272 | |
273 | void mlir::scf::populateSCFStructuralTypeConversionsAndLegality( |
274 | TypeConverter &typeConverter, RewritePatternSet &patterns, |
275 | ConversionTarget &target) { |
276 | populateSCFStructuralTypeConversions(typeConverter, patterns); |
277 | populateSCFStructuralTypeConversionTarget(typeConverter, target); |
278 | } |
279 | |