1 | //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// |
2 | // |
3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file provides utils for implementing (poor-man's) dialect conversion |
10 | // passes with 1:N type conversions. |
11 | // |
12 | // The main function, `applyPartialOneToNConversion`, first applies a set of |
13 | // `RewritePattern`s, which produce unrealized casts to convert the operands and |
14 | // results from and to the source types, and then replaces all newly added |
15 | // unrealized casts by user-provided materializations. For this to work, the |
// main function requires a special `TypeConverter`, a special
// `PatternRewriter`, and special `RewritePattern`s, which extend their
// respective base classes for 1:N type conversions.
19 | // |
20 | // Note that this is much more simple-minded than the "real" dialect conversion, |
21 | // which checks for legality before applying patterns and does probably many |
22 | // other additional things. Ideally, some of the extensions here could be |
23 | // integrated there. |
24 | // |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | #ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H |
28 | #define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H |
29 | |
30 | #include "mlir/IR/PatternMatch.h" |
31 | #include "mlir/Transforms/DialectConversion.h" |
32 | #include "llvm/ADT/SmallVector.h" |
33 | |
34 | namespace mlir { |
35 | |
36 | /// Extends `TypeConverter` with 1:N target materializations. Such |
37 | /// materializations have to provide the "reverse" of 1:N type conversions, |
38 | /// i.e., they need to materialize N values with target types into one value |
39 | /// with a source type (which isn't possible in the base class currently). |
40 | class OneToNTypeConverter : public TypeConverter { |
41 | public: |
42 | /// Callback that expresses user-provided materialization logic from the given |
43 | /// value to N values of the given types. This is useful for expressing target |
44 | /// materializations for 1:N type conversions, which materialize one value in |
45 | /// a source type as N values in target types. |
46 | using OneToNMaterializationCallbackFn = |
47 | std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange, |
48 | Value, Location)>; |
49 | |
50 | /// Creates the mapping of the given range of original types to target types |
51 | /// of the conversion and stores that mapping in the given (signature) |
52 | /// conversion. This function simply calls |
53 | /// `TypeConverter::convertSignatureArgs` and exists here with a different |
54 | /// name to reflect the broader semantic. |
55 | LogicalResult computeTypeMapping(TypeRange types, |
56 | SignatureConversion &result) { |
57 | return convertSignatureArgs(types, result); |
58 | } |
59 | |
60 | /// Applies one of the user-provided 1:N target materializations. If several |
61 | /// exists, they are tried out in the reverse order in which they have been |
62 | /// added until the first one succeeds. If none succeeds, the functions |
63 | /// returns `std::nullopt`. |
64 | std::optional<SmallVector<Value>> |
65 | materializeTargetConversion(OpBuilder &builder, Location loc, |
66 | TypeRange resultTypes, Value input) const; |
67 | |
68 | /// Adds a 1:N target materialization to the converter. Such materializations |
69 | /// build IR that converts N values with target types into 1 value of the |
70 | /// source type. |
71 | void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { |
72 | oneToNTargetMaterializations.emplace_back(Args: std::move(callback)); |
73 | } |
74 | |
75 | private: |
76 | SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations; |
77 | }; |
78 | |
79 | /// Stores a 1:N mapping of types and provides several useful accessors. This |
80 | /// class extends `SignatureConversion`, which already supports 1:N type |
81 | /// mappings but lacks some accessors into the mapping as well as access to the |
82 | /// original types. |
83 | class OneToNTypeMapping : public TypeConverter::SignatureConversion { |
84 | public: |
85 | OneToNTypeMapping(TypeRange originalTypes) |
86 | : TypeConverter::SignatureConversion(originalTypes.size()), |
87 | originalTypes(originalTypes) {} |
88 | |
89 | using TypeConverter::SignatureConversion::getConvertedTypes; |
90 | |
91 | /// Returns the list of types that corresponds to the original type at the |
92 | /// given index. |
93 | TypeRange getConvertedTypes(unsigned originalTypeNo) const; |
94 | |
95 | /// Returns the list of original types. |
96 | TypeRange getOriginalTypes() const { return originalTypes; } |
97 | |
98 | /// Returns the slice of converted values that corresponds the original value |
99 | /// at the given index. |
100 | ValueRange getConvertedValues(ValueRange convertedValues, |
101 | unsigned originalValueNo) const; |
102 | |
103 | /// Fills the given result vector with as many copies of the location of the |
104 | /// original value as the number of values it is converted to. |
105 | void convertLocation(Value originalValue, unsigned originalValueNo, |
106 | llvm::SmallVectorImpl<Location> &result) const; |
107 | |
108 | /// Fills the given result vector with as many copies of the lociation of each |
109 | /// original value as the number of values they are respectively converted to. |
110 | void convertLocations(ValueRange originalValues, |
111 | llvm::SmallVectorImpl<Location> &result) const; |
112 | |
113 | /// Returns true iff at least one type conversion maps an input type to a type |
114 | /// that is different from itself. |
115 | bool hasNonIdentityConversion() const; |
116 | |
117 | private: |
118 | llvm::SmallVector<Type> originalTypes; |
119 | }; |
120 | |
121 | /// Extends the basic `RewritePattern` class with a type converter member and |
122 | /// some accessors to it. This is useful for patterns that are not |
123 | /// `ConversionPattern`s but still require access to a type converter. |
124 | class RewritePatternWithConverter : public mlir::RewritePattern { |
125 | public: |
126 | /// Construct a conversion pattern with the given converter, and forward the |
127 | /// remaining arguments to RewritePattern. |
128 | template <typename... Args> |
129 | RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args) |
130 | : RewritePattern(std::forward<Args>(args)...), |
131 | typeConverter(&typeConverter) {} |
132 | |
133 | /// Return the type converter held by this pattern, or nullptr if the pattern |
134 | /// does not require type conversion. |
135 | TypeConverter *getTypeConverter() const { return typeConverter; } |
136 | |
137 | template <typename ConverterTy> |
138 | std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value, |
139 | ConverterTy *> |
140 | getTypeConverter() const { |
141 | return static_cast<ConverterTy *>(typeConverter); |
142 | } |
143 | |
144 | protected: |
145 | /// A type converter for use by this pattern. |
146 | TypeConverter *const typeConverter; |
147 | }; |
148 | |
149 | /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The |
150 | /// class provides additional rewrite methods that are specific to 1:N type |
151 | /// conversions. |
152 | class OneToNPatternRewriter : public PatternRewriter { |
153 | public: |
154 | OneToNPatternRewriter(MLIRContext *context, |
155 | OpBuilder::Listener *listener = nullptr) |
156 | : PatternRewriter(context, listener) {} |
157 | |
158 | /// Replaces the results of the operation with the specified list of values |
159 | /// mapped back to the original types as specified in the provided type |
160 | /// mapping. That type mapping must match the replaced op (i.e., the original |
161 | /// types must be the same as the result types of the op) and the new values |
162 | /// (i.e., the converted types must be the same as the types of the new |
163 | /// values). |
164 | void replaceOp(Operation *op, ValueRange newValues, |
165 | const OneToNTypeMapping &resultMapping); |
166 | using PatternRewriter::replaceOp; |
167 | |
168 | /// Applies the given argument conversion to the given block. This consists of |
169 | /// replacing each original argument with N arguments as specified in the |
170 | /// argument conversion and inserting unrealized casts from the converted |
171 | /// values to the original types, which are then used in lieu of the original |
172 | /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts |
173 | /// with a user-provided argument materialization if necessary.) This is |
174 | /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N |
175 | /// type conversion properly and probably (2) doesn't handle many other edge |
176 | /// cases. |
177 | Block *applySignatureConversion(Block *block, |
178 | OneToNTypeMapping &argumentConversion); |
179 | }; |
180 | |
181 | /// Base class for patterns with 1:N type conversions. Derived classes have to |
182 | /// overwrite the `matchAndRewrite` overlaod that provides additional |
183 | /// information for 1:N type conversions. |
184 | class OneToNConversionPattern : public RewritePatternWithConverter { |
185 | public: |
186 | using RewritePatternWithConverter::RewritePatternWithConverter; |
187 | |
188 | /// This function has to be implemented by derived classes and is called from |
189 | /// the usual overloads. Like in "normal" `DialectConversion`, the function is |
190 | /// provided with the converted operands (which thus have target types). Since |
191 | /// 1:N conversions are supported, there is usually no 1:1 relationship |
192 | /// between the original and the converted operands. Instead, the provided |
193 | /// `operandMapping` can be used to access the converted operands that |
194 | /// correspond to a particular original operand. Similarly, `resultMapping` |
195 | /// is provided to help with assembling the result values, which may have 1:N |
196 | /// correspondences as well. In that case, the original op should be replaced |
197 | /// with the overload of `replaceOp` that takes the provided `resultMapping` |
198 | /// in order to deal with the mapping of converted result values to their |
199 | /// usages in the original types correctly. |
200 | virtual LogicalResult matchAndRewrite(Operation *op, |
201 | OneToNPatternRewriter &rewriter, |
202 | const OneToNTypeMapping &operandMapping, |
203 | const OneToNTypeMapping &resultMapping, |
204 | ValueRange convertedOperands) const = 0; |
205 | |
206 | LogicalResult matchAndRewrite(Operation *op, |
207 | PatternRewriter &rewriter) const final; |
208 | }; |
209 | |
210 | /// This class is a wrapper around `OneToNConversionPattern` for matching |
211 | /// against instances of a particular op class. |
212 | template <typename SourceOp> |
213 | class OneToNOpConversionPattern : public OneToNConversionPattern { |
214 | public: |
215 | OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, |
216 | PatternBenefit benefit = 1, |
217 | ArrayRef<StringRef> generatedNames = {}) |
218 | : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), |
219 | benefit, context, generatedNames) {} |
220 | /// Generic adaptor around the root op of this pattern using the converted |
221 | /// operands. Importantly, each operand is represented as a *range* of values, |
222 | /// namely the N values each original operand gets converted to. Concretely, |
223 | /// this makes the result type of the accessor functions of the adaptor class |
224 | /// be a `ValueRange`. |
225 | class OpAdaptor |
226 | : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> { |
227 | public: |
228 | using RangeT = ArrayRef<ValueRange>; |
229 | using BaseT = typename SourceOp::template GenericAdaptor<RangeT>; |
230 | using Properties = typename SourceOp::template InferredProperties<SourceOp>; |
231 | |
232 | OpAdaptor(const OneToNTypeMapping *operandMapping, |
233 | const OneToNTypeMapping *resultMapping, |
234 | const ValueRange *convertedOperands, RangeT values, SourceOp op) |
235 | : BaseT(values, op), operandMapping(operandMapping), |
236 | resultMapping(resultMapping), convertedOperands(convertedOperands) {} |
237 | |
238 | /// Get the type mapping of the original operands to the converted operands. |
239 | const OneToNTypeMapping &getOperandMapping() const { |
240 | return *operandMapping; |
241 | } |
242 | |
243 | /// Get the type mapping of the original results to the converted results. |
244 | const OneToNTypeMapping &getResultMapping() const { return *resultMapping; } |
245 | |
246 | /// Get a flat range of all converted operands. Unlike `getOperands`, which |
247 | /// returns an `ArrayRef` with one `ValueRange` for each original operand, |
248 | /// this function returns a `ValueRange` that contains all converted |
249 | /// operands irrespectively of which operand they originated from. |
250 | ValueRange getFlatOperands() const { return *convertedOperands; } |
251 | |
252 | private: |
253 | const OneToNTypeMapping *operandMapping; |
254 | const OneToNTypeMapping *resultMapping; |
255 | const ValueRange *convertedOperands; |
256 | }; |
257 | |
258 | using OneToNConversionPattern::matchAndRewrite; |
259 | |
260 | /// Overload that derived classes have to override for their op type. |
261 | virtual LogicalResult |
262 | matchAndRewrite(SourceOp op, OpAdaptor adaptor, |
263 | OneToNPatternRewriter &rewriter) const = 0; |
264 | |
265 | LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, |
266 | const OneToNTypeMapping &operandMapping, |
267 | const OneToNTypeMapping &resultMapping, |
268 | ValueRange convertedOperands) const final { |
269 | // Wrap converted operands and type mappings into an adaptor. |
270 | SmallVector<ValueRange> valueRanges; |
271 | for (int64_t i = 0; i < op->getNumOperands(); i++) { |
272 | auto values = operandMapping.getConvertedValues(convertedValues: convertedOperands, originalValueNo: i); |
273 | valueRanges.push_back(Elt: values); |
274 | } |
275 | OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands, |
276 | valueRanges, cast<SourceOp>(op)); |
277 | |
278 | // Call overload implemented by the derived class. |
279 | return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter); |
280 | } |
281 | }; |
282 | |
283 | /// Applies the given set of patterns recursively on the given op and adds user |
284 | /// materializations where necessary. The patterns are expected to be |
285 | /// `OneToNConversionPattern`, which help converting the types of the operands |
286 | /// and results of the matched ops. The provided type converter is used to |
287 | /// convert the operands of matched ops from their original types to operands |
288 | /// with different types. Unlike in `DialectConversion`, this supports 1:N type |
289 | /// conversions. Those conversions at the "boundary" of the pattern application, |
290 | /// where converted results are not consumed by replaced ops that expect the |
291 | /// converted operands or vice versa, the function inserts user materializations |
292 | /// from the type converter. Also unlike `DialectConversion`, there are no legal |
293 | /// or illegal types; the function simply applies the given patterns and does |
294 | /// not fail if some ops or types remain unconverted (i.e., the conversion is |
295 | /// only "partial"). |
296 | LogicalResult |
297 | applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, |
298 | const FrozenRewritePatternSet &patterns); |
299 | |
300 | } // namespace mlir |
301 | |
302 | #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H |
303 | |