1//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/IR/SCF.h"
10#include "mlir/Dialect/SCF/Transforms/Patterns.h"
11#include "mlir/Transforms/DialectConversion.h"
12#include <optional>
13
14using namespace mlir;
15using namespace mlir::scf;
16
17namespace {
18
19// Unpacks the single unrealized_conversion_cast using the list of inputs
20// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
21static void unpackUnrealizedConversionCast(Value v,
22 SmallVectorImpl<Value> &unpacked) {
23 if (auto cast =
24 dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
25 if (cast.getInputs().size() != 1) {
26 // 1 : N type conversion.
27 unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
28 return;
29 }
30 }
31 // 1 : 1 type conversion.
32 unpacked.push_back(Elt: v);
33}
34
35// CRTP
36// A base class that takes care of 1:N type conversion, which maps the converted
37// op results (computed by the derived class) and materializes 1:N conversion.
38template <typename SourceOp, typename ConcretePattern>
39class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
40public:
41 using OpConversionPattern<SourceOp>::typeConverter;
42 using OpConversionPattern<SourceOp>::OpConversionPattern;
43 using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
44
45 //
46 // Derived classes should provide the following method which performs the
47 // actual conversion. It should return std::nullopt upon conversion failure
48 // and return the converted operation upon success.
49 //
50 // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
51 // ConversionPatternRewriter &rewriter,
52 // TypeRange dstTypes) const;
53
54 LogicalResult
55 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 SmallVector<Type> dstTypes;
58 SmallVector<unsigned> offsets;
59 offsets.push_back(Elt: 0);
60 // Do the type conversion and record the offsets.
61 for (Type type : op.getResultTypes()) {
62 if (failed(typeConverter->convertTypes(type, dstTypes)))
63 return rewriter.notifyMatchFailure(op, "could not convert result type");
64 offsets.push_back(Elt: dstTypes.size());
65 }
66
67 // Calls the actual converter implementation to convert the operation.
68 std::optional<SourceOp> newOp =
69 static_cast<const ConcretePattern *>(this)->convertSourceOp(
70 op, adaptor, rewriter, dstTypes);
71
72 if (!newOp)
73 return rewriter.notifyMatchFailure(op, "could not convert operation");
74
75 // Packs the return value.
76 SmallVector<Value> packedRets;
77 for (unsigned i = 1, e = offsets.size(); i < e; i++) {
78 unsigned start = offsets[i - 1], end = offsets[i];
79 unsigned len = end - start;
80 ValueRange mappedValue = newOp->getResults().slice(start, len);
81 if (len != 1) {
82 // 1 : N type conversion.
83 Type origType = op.getResultTypes()[i - 1];
84 Value mat = typeConverter->materializeSourceConversion(
85 rewriter, op.getLoc(), origType, mappedValue);
86 if (!mat) {
87 return rewriter.notifyMatchFailure(
88 op, "Failed to materialize 1:N type conversion");
89 }
90 packedRets.push_back(Elt: mat);
91 } else {
92 // 1 : 1 type conversion.
93 packedRets.push_back(Elt: mappedValue.front());
94 }
95 }
96
97 rewriter.replaceOp(op, packedRets);
98 return success();
99 }
100};
101
102class ConvertForOpTypes
103 : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
104public:
105 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
106
107 // The callback required by CRTP.
108 std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
109 ConversionPatternRewriter &rewriter,
110 TypeRange dstTypes) const {
111 // Create a empty new op and inline the regions from the old op.
112 //
113 // This is a little bit tricky. We have two concerns here:
114 //
115 // 1. We cannot update the op in place because the dialect conversion
116 // framework does not track type changes for ops updated in place, so it
117 // won't insert appropriate materializations on the changed result types.
118 // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
119 // to clone the op.
120 //
121 // 2. We need to resue the original region instead of cloning it, otherwise
122 // the dialect conversion framework thinks that we just inserted all the
123 // cloned child ops. But what we want is to "take" the child regions and let
124 // the dialect conversion framework continue recursively into ops inside
125 // those regions (which are already in its worklist; inlining them into the
126 // new op's regions doesn't remove the child ops from the worklist).
127
128 // convertRegionTypes already takes care of 1:N conversion.
129 if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
130 return std::nullopt;
131
132 // Unpacked the iteration arguments.
133 SmallVector<Value> flatArgs;
134 for (Value arg : adaptor.getInitArgs())
135 unpackUnrealizedConversionCast(arg, flatArgs);
136
137 // We can not do clone as the number of result types after conversion
138 // might be different.
139 ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
140 adaptor.getUpperBound(),
141 adaptor.getStep(), flatArgs);
142
143 // Reserve whatever attributes in the original op.
144 newOp->setAttrs(op->getAttrs());
145
146 // We do not need the empty block created by rewriter.
147 rewriter.eraseBlock(block: newOp.getBody(0));
148 // Inline the type converted region from the original operation.
149 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
150 newOp.getRegion().end());
151
152 return newOp;
153 }
154};
155} // namespace
156
157namespace {
158class ConvertIfOpTypes
159 : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
160public:
161 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
162
163 std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter,
165 TypeRange dstTypes) const {
166
167 IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
168 adaptor.getCondition(), true);
169 newOp->setAttrs(op->getAttrs());
170
171 // We do not need the empty blocks created by rewriter.
172 rewriter.eraseBlock(block: newOp.elseBlock());
173 rewriter.eraseBlock(block: newOp.thenBlock());
174
175 // Inlines block from the original operation.
176 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
177 newOp.getThenRegion().end());
178 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
179 newOp.getElseRegion().end());
180
181 return newOp;
182 }
183};
184} // namespace
185
186namespace {
187class ConvertWhileOpTypes
188 : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
189public:
190 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
191
192 std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter,
194 TypeRange dstTypes) const {
195 // Unpacked the iteration arguments.
196 SmallVector<Value> flatArgs;
197 for (Value arg : adaptor.getOperands())
198 unpackUnrealizedConversionCast(arg, flatArgs);
199
200 auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
201
202 for (auto i : {0u, 1u}) {
203 if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
204 return std::nullopt;
205 auto &dstRegion = newOp.getRegion(i);
206 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
207 }
208 return newOp;
209 }
210};
211} // namespace
212
213namespace {
214// When the result types of a ForOp/IfOp get changed, the operand types of the
215// corresponding yield op need to be changed. In order to trigger the
216// appropriate type conversions / materializations, we need a dummy pattern.
217class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
218public:
219 using OpConversionPattern::OpConversionPattern;
220 LogicalResult
221 matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223 SmallVector<Value> unpackedYield;
224 for (Value operand : adaptor.getOperands())
225 unpackUnrealizedConversionCast(operand, unpackedYield);
226
227 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
228 return success();
229 }
230};
231} // namespace
232
233namespace {
234class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
235public:
236 using OpConversionPattern<ConditionOp>::OpConversionPattern;
237 LogicalResult
238 matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter) const override {
240 SmallVector<Value> unpackedYield;
241 for (Value operand : adaptor.getOperands())
242 unpackUnrealizedConversionCast(operand, unpackedYield);
243
244 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
245 return success();
246 }
247};
248} // namespace
249
250void mlir::scf::populateSCFStructuralTypeConversions(
251 TypeConverter &typeConverter, RewritePatternSet &patterns) {
252 patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
253 ConvertWhileOpTypes, ConvertConditionOpTypes>(
254 arg&: typeConverter, args: patterns.getContext());
255}
256
257void mlir::scf::populateSCFStructuralTypeConversionTarget(
258 const TypeConverter &typeConverter, ConversionTarget &target) {
259 target.addDynamicallyLegalOp<ForOp, IfOp>(callback: [&](Operation *op) {
260 return typeConverter.isLegal(range: op->getResultTypes());
261 });
262 target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
263 // We only have conversions for a subset of ops that use scf.yield
264 // terminators.
265 if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
266 return true;
267 return typeConverter.isLegal(op.getOperandTypes());
268 });
269 target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
270 [&](Operation *op) { return typeConverter.isLegal(op); });
271}
272
273void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
274 TypeConverter &typeConverter, RewritePatternSet &patterns,
275 ConversionTarget &target) {
276 populateSCFStructuralTypeConversions(typeConverter, patterns);
277 populateSCFStructuralTypeConversionTarget(typeConverter, target);
278}
279

source code of mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp