1//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides utils for implementing (poor-man's) dialect conversion
10// passes with 1:N type conversions.
11//
12// The main function, `applyPartialOneToNConversion`, first applies a set of
13// `RewritePattern`s, which produce unrealized casts to convert the operands and
14// results from and to the source types, and then replaces all newly added
15// unrealized casts by user-provided materializations. For this to work, the
16// main function requires a special `TypeConverter`, a special
// `PatternRewriter`, and special `RewritePattern`s, which extend their
// respective base classes for 1:N type conversions.
19//
20// Note that this is much more simple-minded than the "real" dialect conversion,
21// which checks for legality before applying patterns and does probably many
22// other additional things. Ideally, some of the extensions here could be
23// integrated there.
24//
25//===----------------------------------------------------------------------===//
26
27#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
28#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
29
30#include "mlir/IR/PatternMatch.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/SmallVector.h"
33
34namespace mlir {
35
36/// Extends `TypeConverter` with 1:N target materializations. Such
37/// materializations have to provide the "reverse" of 1:N type conversions,
38/// i.e., they need to materialize N values with target types into one value
39/// with a source type (which isn't possible in the base class currently).
40class OneToNTypeConverter : public TypeConverter {
41public:
42 /// Callback that expresses user-provided materialization logic from the given
43 /// value to N values of the given types. This is useful for expressing target
44 /// materializations for 1:N type conversions, which materialize one value in
45 /// a source type as N values in target types.
46 using OneToNMaterializationCallbackFn =
47 std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48 Value, Location)>;
49
50 /// Creates the mapping of the given range of original types to target types
51 /// of the conversion and stores that mapping in the given (signature)
52 /// conversion. This function simply calls
53 /// `TypeConverter::convertSignatureArgs` and exists here with a different
54 /// name to reflect the broader semantic.
55 LogicalResult computeTypeMapping(TypeRange types,
56 SignatureConversion &result) {
57 return convertSignatureArgs(types, result);
58 }
59
60 /// Applies one of the user-provided 1:N target materializations. If several
61 /// exists, they are tried out in the reverse order in which they have been
62 /// added until the first one succeeds. If none succeeds, the functions
63 /// returns `std::nullopt`.
64 std::optional<SmallVector<Value>>
65 materializeTargetConversion(OpBuilder &builder, Location loc,
66 TypeRange resultTypes, Value input) const;
67
68 /// Adds a 1:N target materialization to the converter. Such materializations
69 /// build IR that converts N values with target types into 1 value of the
70 /// source type.
71 void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
72 oneToNTargetMaterializations.emplace_back(Args: std::move(callback));
73 }
74
75private:
76 SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
77};
78
79/// Stores a 1:N mapping of types and provides several useful accessors. This
80/// class extends `SignatureConversion`, which already supports 1:N type
81/// mappings but lacks some accessors into the mapping as well as access to the
82/// original types.
83class OneToNTypeMapping : public TypeConverter::SignatureConversion {
84public:
85 OneToNTypeMapping(TypeRange originalTypes)
86 : TypeConverter::SignatureConversion(originalTypes.size()),
87 originalTypes(originalTypes) {}
88
89 using TypeConverter::SignatureConversion::getConvertedTypes;
90
91 /// Returns the list of types that corresponds to the original type at the
92 /// given index.
93 TypeRange getConvertedTypes(unsigned originalTypeNo) const;
94
95 /// Returns the list of original types.
96 TypeRange getOriginalTypes() const { return originalTypes; }
97
98 /// Returns the slice of converted values that corresponds the original value
99 /// at the given index.
100 ValueRange getConvertedValues(ValueRange convertedValues,
101 unsigned originalValueNo) const;
102
103 /// Fills the given result vector with as many copies of the location of the
104 /// original value as the number of values it is converted to.
105 void convertLocation(Value originalValue, unsigned originalValueNo,
106 llvm::SmallVectorImpl<Location> &result) const;
107
108 /// Fills the given result vector with as many copies of the lociation of each
109 /// original value as the number of values they are respectively converted to.
110 void convertLocations(ValueRange originalValues,
111 llvm::SmallVectorImpl<Location> &result) const;
112
113 /// Returns true iff at least one type conversion maps an input type to a type
114 /// that is different from itself.
115 bool hasNonIdentityConversion() const;
116
117private:
118 llvm::SmallVector<Type> originalTypes;
119};
120
121/// Extends the basic `RewritePattern` class with a type converter member and
122/// some accessors to it. This is useful for patterns that are not
123/// `ConversionPattern`s but still require access to a type converter.
124class RewritePatternWithConverter : public mlir::RewritePattern {
125public:
126 /// Construct a conversion pattern with the given converter, and forward the
127 /// remaining arguments to RewritePattern.
128 template <typename... Args>
129 RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
130 : RewritePattern(std::forward<Args>(args)...),
131 typeConverter(&typeConverter) {}
132
133 /// Return the type converter held by this pattern, or nullptr if the pattern
134 /// does not require type conversion.
135 TypeConverter *getTypeConverter() const { return typeConverter; }
136
137 template <typename ConverterTy>
138 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
139 ConverterTy *>
140 getTypeConverter() const {
141 return static_cast<ConverterTy *>(typeConverter);
142 }
143
144protected:
145 /// A type converter for use by this pattern.
146 TypeConverter *const typeConverter;
147};
148
149/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
150/// class provides additional rewrite methods that are specific to 1:N type
151/// conversions.
152class OneToNPatternRewriter : public PatternRewriter {
153public:
154 OneToNPatternRewriter(MLIRContext *context,
155 OpBuilder::Listener *listener = nullptr)
156 : PatternRewriter(context, listener) {}
157
158 /// Replaces the results of the operation with the specified list of values
159 /// mapped back to the original types as specified in the provided type
160 /// mapping. That type mapping must match the replaced op (i.e., the original
161 /// types must be the same as the result types of the op) and the new values
162 /// (i.e., the converted types must be the same as the types of the new
163 /// values).
164 void replaceOp(Operation *op, ValueRange newValues,
165 const OneToNTypeMapping &resultMapping);
166 using PatternRewriter::replaceOp;
167
168 /// Applies the given argument conversion to the given block. This consists of
169 /// replacing each original argument with N arguments as specified in the
170 /// argument conversion and inserting unrealized casts from the converted
171 /// values to the original types, which are then used in lieu of the original
172 /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
173 /// with a user-provided argument materialization if necessary.) This is
174 /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
175 /// type conversion properly and probably (2) doesn't handle many other edge
176 /// cases.
177 Block *applySignatureConversion(Block *block,
178 OneToNTypeMapping &argumentConversion);
179};
180
181/// Base class for patterns with 1:N type conversions. Derived classes have to
182/// overwrite the `matchAndRewrite` overlaod that provides additional
183/// information for 1:N type conversions.
184class OneToNConversionPattern : public RewritePatternWithConverter {
185public:
186 using RewritePatternWithConverter::RewritePatternWithConverter;
187
188 /// This function has to be implemented by derived classes and is called from
189 /// the usual overloads. Like in "normal" `DialectConversion`, the function is
190 /// provided with the converted operands (which thus have target types). Since
191 /// 1:N conversions are supported, there is usually no 1:1 relationship
192 /// between the original and the converted operands. Instead, the provided
193 /// `operandMapping` can be used to access the converted operands that
194 /// correspond to a particular original operand. Similarly, `resultMapping`
195 /// is provided to help with assembling the result values, which may have 1:N
196 /// correspondences as well. In that case, the original op should be replaced
197 /// with the overload of `replaceOp` that takes the provided `resultMapping`
198 /// in order to deal with the mapping of converted result values to their
199 /// usages in the original types correctly.
200 virtual LogicalResult matchAndRewrite(Operation *op,
201 OneToNPatternRewriter &rewriter,
202 const OneToNTypeMapping &operandMapping,
203 const OneToNTypeMapping &resultMapping,
204 ValueRange convertedOperands) const = 0;
205
206 LogicalResult matchAndRewrite(Operation *op,
207 PatternRewriter &rewriter) const final;
208};
209
210/// This class is a wrapper around `OneToNConversionPattern` for matching
211/// against instances of a particular op class.
212template <typename SourceOp>
213class OneToNOpConversionPattern : public OneToNConversionPattern {
214public:
215 OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
216 PatternBenefit benefit = 1,
217 ArrayRef<StringRef> generatedNames = {})
218 : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
219 benefit, context, generatedNames) {}
220 /// Generic adaptor around the root op of this pattern using the converted
221 /// operands. Importantly, each operand is represented as a *range* of values,
222 /// namely the N values each original operand gets converted to. Concretely,
223 /// this makes the result type of the accessor functions of the adaptor class
224 /// be a `ValueRange`.
225 class OpAdaptor
226 : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
227 public:
228 using RangeT = ArrayRef<ValueRange>;
229 using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
230 using Properties = typename SourceOp::template InferredProperties<SourceOp>;
231
232 OpAdaptor(const OneToNTypeMapping *operandMapping,
233 const OneToNTypeMapping *resultMapping,
234 const ValueRange *convertedOperands, RangeT values, SourceOp op)
235 : BaseT(values, op), operandMapping(operandMapping),
236 resultMapping(resultMapping), convertedOperands(convertedOperands) {}
237
238 /// Get the type mapping of the original operands to the converted operands.
239 const OneToNTypeMapping &getOperandMapping() const {
240 return *operandMapping;
241 }
242
243 /// Get the type mapping of the original results to the converted results.
244 const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
245
246 /// Get a flat range of all converted operands. Unlike `getOperands`, which
247 /// returns an `ArrayRef` with one `ValueRange` for each original operand,
248 /// this function returns a `ValueRange` that contains all converted
249 /// operands irrespectively of which operand they originated from.
250 ValueRange getFlatOperands() const { return *convertedOperands; }
251
252 private:
253 const OneToNTypeMapping *operandMapping;
254 const OneToNTypeMapping *resultMapping;
255 const ValueRange *convertedOperands;
256 };
257
258 using OneToNConversionPattern::matchAndRewrite;
259
260 /// Overload that derived classes have to override for their op type.
261 virtual LogicalResult
262 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
263 OneToNPatternRewriter &rewriter) const = 0;
264
265 LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
266 const OneToNTypeMapping &operandMapping,
267 const OneToNTypeMapping &resultMapping,
268 ValueRange convertedOperands) const final {
269 // Wrap converted operands and type mappings into an adaptor.
270 SmallVector<ValueRange> valueRanges;
271 for (int64_t i = 0; i < op->getNumOperands(); i++) {
272 auto values = operandMapping.getConvertedValues(convertedValues: convertedOperands, originalValueNo: i);
273 valueRanges.push_back(Elt: values);
274 }
275 OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
276 valueRanges, cast<SourceOp>(op));
277
278 // Call overload implemented by the derived class.
279 return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
280 }
281};
282
283/// Applies the given set of patterns recursively on the given op and adds user
284/// materializations where necessary. The patterns are expected to be
285/// `OneToNConversionPattern`, which help converting the types of the operands
286/// and results of the matched ops. The provided type converter is used to
287/// convert the operands of matched ops from their original types to operands
288/// with different types. Unlike in `DialectConversion`, this supports 1:N type
289/// conversions. Those conversions at the "boundary" of the pattern application,
290/// where converted results are not consumed by replaced ops that expect the
291/// converted operands or vice versa, the function inserts user materializations
292/// from the type converter. Also unlike `DialectConversion`, there are no legal
293/// or illegal types; the function simply applies the given patterns and does
294/// not fail if some ops or types remain unconverted (i.e., the conversion is
295/// only "partial").
296LogicalResult
297applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
298 const FrozenRewritePatternSet &patterns);
299
300} // namespace mlir
301
302#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
303

// Source: mlir/include/mlir/Transforms/OneToNTypeConversion.h