1//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// The patterns in this file are heavily inspired by (and copied from)
// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
// type conversions.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/SCF/Transforms/Transforms.h"
16
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Transforms/OneToNTypeConversion.h"
19
20using namespace mlir;
21using namespace mlir::scf;
22
23class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
24public:
25 using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
26
27 LogicalResult
28 matchAndRewrite(IfOp op, OpAdaptor adaptor,
29 OneToNPatternRewriter &rewriter) const override {
30 Location loc = op->getLoc();
31 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
32
33 // Nothing to do if there is no non-identity conversion.
34 if (!resultMapping.hasNonIdentityConversion())
35 return failure();
36
37 // Create new IfOp.
38 TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
39 auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
40 op.getCondition(), true);
41 newOp->setAttrs(op->getAttrs());
42
43 // We do not need the empty blocks created by rewriter.
44 rewriter.eraseBlock(block: newOp.elseBlock());
45 rewriter.eraseBlock(block: newOp.thenBlock());
46
47 // Inlines block from the original operation.
48 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
49 newOp.getThenRegion().end());
50 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
51 newOp.getElseRegion().end());
52
53 rewriter.replaceOp(op, newOp->getResults(), resultMapping);
54 return success();
55 }
56};
57
58class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
59public:
60 using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
61
62 LogicalResult
63 matchAndRewrite(WhileOp op, OpAdaptor adaptor,
64 OneToNPatternRewriter &rewriter) const override {
65 Location loc = op->getLoc();
66
67 const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
68 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
69
70 // Nothing to do if the op doesn't have any non-identity conversions for its
71 // operands or results.
72 if (!operandMapping.hasNonIdentityConversion() &&
73 !resultMapping.hasNonIdentityConversion())
74 return failure();
75
76 // Create new WhileOp.
77 TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
78
79 auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes,
80 adaptor.getFlatOperands());
81 newOp->setAttrs(op->getAttrs());
82
83 // Update block signatures.
84 std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
85 resultMapping};
86 for (unsigned int i : {0u, 1u}) {
87 Region *region = &op.getRegion(i);
88 Block *block = &region->front();
89
90 rewriter.applySignatureConversion(block, argumentConversion&: blockMappings[i]);
91
92 // Move updated region to new WhileOp.
93 Region &dstRegion = newOp.getRegion(i);
94 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
95 }
96
97 rewriter.replaceOp(op, newOp->getResults(), resultMapping);
98 return success();
99 }
100};
101
102class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
103public:
104 using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
105
106 LogicalResult
107 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
108 OneToNPatternRewriter &rewriter) const override {
109 // Nothing to do if there is no non-identity conversion.
110 if (!adaptor.getOperandMapping().hasNonIdentityConversion())
111 return failure();
112
113 // Convert operands.
114 rewriter.modifyOpInPlace(
115 op, [&] { op->setOperands(adaptor.getFlatOperands()); });
116
117 return success();
118 }
119};
120
121class ConvertTypesInSCFConditionOp
122 : public OneToNOpConversionPattern<ConditionOp> {
123public:
124 using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
125
126 LogicalResult
127 matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
128 OneToNPatternRewriter &rewriter) const override {
129 // Nothing to do if there is no non-identity conversion.
130 if (!adaptor.getOperandMapping().hasNonIdentityConversion())
131 return failure();
132
133 // Convert operands.
134 rewriter.modifyOpInPlace(
135 op, [&] { op->setOperands(adaptor.getFlatOperands()); });
136
137 return success();
138 }
139};
140
141class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> {
142public:
143 using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern;
144
145 LogicalResult
146 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
147 OneToNPatternRewriter &rewriter) const override {
148 const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
149 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
150
151 // Nothing to do if there is no non-identity conversion.
152 if (!operandMapping.hasNonIdentityConversion() &&
153 !resultMapping.hasNonIdentityConversion())
154 return failure();
155
156 // If the lower-bound, upper-bound, or step were expanded, abort the
157 // conversion. This conversion does not know what to do in such cases.
158 ValueRange lbs = adaptor.getLowerBound();
159 ValueRange ubs = adaptor.getUpperBound();
160 ValueRange steps = adaptor.getStep();
161 if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
162 return rewriter.notifyMatchFailure(
163 forOp, "index operands converted to multiple values");
164
165 Location loc = forOp.getLoc();
166
167 Region *region = &forOp.getRegion();
168 Block *block = &region->front();
169
170 // Construct the new for-op with an empty body.
171 ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
172 auto newOp =
173 rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits);
174 newOp->setAttrs(forOp->getAttrs());
175
176 // We do not need the empty blocks created by rewriter.
177 rewriter.eraseBlock(block: newOp.getBody());
178
179 // Convert the signature of the body region.
180 OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
181 if (failed(result: typeConverter->convertSignatureArgs(types: block->getArgumentTypes(),
182 result&: bodyTypeMapping)))
183 return failure();
184
185 // Perform signature conversion on the body block.
186 rewriter.applySignatureConversion(block, argumentConversion&: bodyTypeMapping);
187
188 // Splice the old body region into the new for-op.
189 Region &dstRegion = newOp.getBodyRegion();
190 rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
191
192 rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
193
194 return success();
195 }
196};
197
198namespace mlir {
199namespace scf {
200
201void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
202 RewritePatternSet &patterns) {
203 patterns.add<
204 // clang-format off
205 ConvertTypesInSCFConditionOp,
206 ConvertTypesInSCFForOp,
207 ConvertTypesInSCFIfOp,
208 ConvertTypesInSCFWhileOp,
209 ConvertTypesInSCFYieldOp
210 // clang-format on
211 >(arg&: typeConverter, args: patterns.getContext());
212}
213
214} // namespace scf
215} // namespace mlir
216

source code of mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp