//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/IR/SCF.h"
10#include "mlir/Dialect/SCF/Transforms/Patterns.h"
11#include "mlir/Transforms/DialectConversion.h"
12#include <optional>
13
14using namespace mlir;
15using namespace mlir::scf;
16
17namespace {
18
19/// Flatten the given value ranges into a single vector of values.
20static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
21 SmallVector<Value> result;
22 for (const auto &vals : values)
23 llvm::append_range(C&: result, R: vals);
24 return result;
25}
26
27// CRTP
28// A base class that takes care of 1:N type conversion, which maps the converted
29// op results (computed by the derived class) and materializes 1:N conversion.
30template <typename SourceOp, typename ConcretePattern>
31class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
32public:
33 using OpConversionPattern<SourceOp>::typeConverter;
34 using OpConversionPattern<SourceOp>::OpConversionPattern;
35 using OneToNOpAdaptor =
36 typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
37
38 //
39 // Derived classes should provide the following method which performs the
40 // actual conversion. It should return std::nullopt upon conversion failure
41 // and return the converted operation upon success.
42 //
43 // std::optional<SourceOp> convertSourceOp(
44 // SourceOp op, OneToNOpAdaptor adaptor,
45 // ConversionPatternRewriter &rewriter,
46 // TypeRange dstTypes) const;
47
48 LogicalResult
49 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
50 ConversionPatternRewriter &rewriter) const override {
51 SmallVector<Type> dstTypes;
52 SmallVector<unsigned> offsets;
53 offsets.push_back(Elt: 0);
54 // Do the type conversion and record the offsets.
55 for (Type type : op.getResultTypes()) {
56 if (failed(typeConverter->convertTypes(type, dstTypes)))
57 return rewriter.notifyMatchFailure(op, "could not convert result type");
58 offsets.push_back(Elt: dstTypes.size());
59 }
60
61 // Calls the actual converter implementation to convert the operation.
62 std::optional<SourceOp> newOp =
63 static_cast<const ConcretePattern *>(this)->convertSourceOp(
64 op, adaptor, rewriter, dstTypes);
65
66 if (!newOp)
67 return rewriter.notifyMatchFailure(op, "could not convert operation");
68
69 // Packs the return value.
70 SmallVector<ValueRange> packedRets;
71 for (unsigned i = 1, e = offsets.size(); i < e; i++) {
72 unsigned start = offsets[i - 1], end = offsets[i];
73 unsigned len = end - start;
74 ValueRange mappedValue = newOp->getResults().slice(start, len);
75 packedRets.push_back(Elt: mappedValue);
76 }
77
78 rewriter.replaceOpWithMultiple(op, packedRets);
79 return success();
80 }
81};
82
83class ConvertForOpTypes
84 : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
85public:
86 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
87
88 // The callback required by CRTP.
89 std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
90 ConversionPatternRewriter &rewriter,
91 TypeRange dstTypes) const {
92 // Create a empty new op and inline the regions from the old op.
93 //
94 // This is a little bit tricky. We have two concerns here:
95 //
96 // 1. We cannot update the op in place because the dialect conversion
97 // framework does not track type changes for ops updated in place, so it
98 // won't insert appropriate materializations on the changed result types.
99 // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100 // to clone the op.
101 //
102 // 2. We need to reuse the original region instead of cloning it, otherwise
103 // the dialect conversion framework thinks that we just inserted all the
104 // cloned child ops. But what we want is to "take" the child regions and let
105 // the dialect conversion framework continue recursively into ops inside
106 // those regions (which are already in its worklist; inlining them into the
107 // new op's regions doesn't remove the child ops from the worklist).
108
109 // convertRegionTypes already takes care of 1:N conversion.
110 if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
111 return std::nullopt;
112
113 // We can not do clone as the number of result types after conversion
114 // might be different.
115 ForOp newOp = rewriter.create<ForOp>(
116 op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
117 llvm::getSingleElement(adaptor.getUpperBound()),
118 llvm::getSingleElement(adaptor.getStep()),
119 flattenValues(adaptor.getInitArgs()));
120
121 // Reserve whatever attributes in the original op.
122 newOp->setAttrs(op->getAttrs());
123
124 // We do not need the empty block created by rewriter.
125 rewriter.eraseBlock(block: newOp.getBody(0));
126 // Inline the type converted region from the original operation.
127 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
128 newOp.getRegion().end());
129
130 return newOp;
131 }
132};
133} // namespace
134
135namespace {
136class ConvertIfOpTypes
137 : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
138public:
139 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
140
141 std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter,
143 TypeRange dstTypes) const {
144
145 IfOp newOp = rewriter.create<IfOp>(
146 op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
147 true);
148 newOp->setAttrs(op->getAttrs());
149
150 // We do not need the empty blocks created by rewriter.
151 rewriter.eraseBlock(block: newOp.elseBlock());
152 rewriter.eraseBlock(block: newOp.thenBlock());
153
154 // Inlines block from the original operation.
155 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
156 newOp.getThenRegion().end());
157 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
158 newOp.getElseRegion().end());
159
160 return newOp;
161 }
162};
163} // namespace
164
165namespace {
166class ConvertWhileOpTypes
167 : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
168public:
169 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
170
171 std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter,
173 TypeRange dstTypes) const {
174 auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
175 flattenValues(adaptor.getOperands()));
176
177 for (auto i : {0u, 1u}) {
178 if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
179 return std::nullopt;
180 auto &dstRegion = newOp.getRegion(i);
181 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
182 }
183 return newOp;
184 }
185};
186} // namespace
187
188namespace {
189// When the result types of a ForOp/IfOp get changed, the operand types of the
190// corresponding yield op need to be changed. In order to trigger the
191// appropriate type conversions / materializations, we need a dummy pattern.
192class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
193public:
194 using OpConversionPattern::OpConversionPattern;
195 LogicalResult
196 matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
197 ConversionPatternRewriter &rewriter) const override {
198 rewriter.replaceOpWithNewOp<scf::YieldOp>(
199 op, flattenValues(adaptor.getOperands()));
200 return success();
201 }
202};
203} // namespace
204
205namespace {
206class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
207public:
208 using OpConversionPattern<ConditionOp>::OpConversionPattern;
209 LogicalResult
210 matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
211 ConversionPatternRewriter &rewriter) const override {
212 rewriter.modifyOpInPlace(
213 op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
214 return success();
215 }
216};
217} // namespace
218
219void mlir::scf::populateSCFStructuralTypeConversions(
220 const TypeConverter &typeConverter, RewritePatternSet &patterns) {
221 patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
222 ConvertWhileOpTypes, ConvertConditionOpTypes>(
223 arg: typeConverter, args: patterns.getContext());
224}
225
226void mlir::scf::populateSCFStructuralTypeConversionTarget(
227 const TypeConverter &typeConverter, ConversionTarget &target) {
228 target.addDynamicallyLegalOp<ForOp, IfOp>(callback: [&](Operation *op) {
229 return typeConverter.isLegal(range: op->getResultTypes());
230 });
231 target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
232 // We only have conversions for a subset of ops that use scf.yield
233 // terminators.
234 if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
235 return true;
236 return typeConverter.isLegal(op.getOperandTypes());
237 });
238 target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
239 [&](Operation *op) { return typeConverter.isLegal(op); });
240}
241
242void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
243 const TypeConverter &typeConverter, RewritePatternSet &patterns,
244 ConversionTarget &target) {
245 populateSCFStructuralTypeConversions(typeConverter, patterns);
246 populateSCFStructuralTypeConversionTarget(typeConverter, target);
247}
248

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp