1 | //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file declares a generic pass for converting between MLIR dialects. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ |
14 | #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ |
15 | |
16 | #include "mlir/Config/mlir-config.h" |
17 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
18 | #include "llvm/ADT/MapVector.h" |
19 | #include "llvm/ADT/StringMap.h" |
20 | #include <type_traits> |
21 | |
22 | namespace mlir { |
23 | |
24 | // Forward declarations. |
25 | class Attribute; |
26 | class Block; |
27 | struct ConversionConfig; |
28 | class ConversionPatternRewriter; |
29 | class MLIRContext; |
30 | class Operation; |
31 | struct OperationConverter; |
32 | class Type; |
33 | class Value; |
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // Type Conversion |
37 | //===----------------------------------------------------------------------===// |
38 | |
39 | /// Type conversion class. Specific conversions and materializations can be |
40 | /// registered using addConversion and addMaterialization, respectively. |
41 | class TypeConverter { |
42 | public: |
43 | virtual ~TypeConverter() = default; |
44 | TypeConverter() = default; |
45 | // Copy the registered conversions, but not the caches |
46 | TypeConverter(const TypeConverter &other) |
47 | : conversions(other.conversions), |
48 | argumentMaterializations(other.argumentMaterializations), |
49 | sourceMaterializations(other.sourceMaterializations), |
50 | targetMaterializations(other.targetMaterializations), |
51 | typeAttributeConversions(other.typeAttributeConversions) {} |
52 | TypeConverter &operator=(const TypeConverter &other) { |
53 | conversions = other.conversions; |
54 | argumentMaterializations = other.argumentMaterializations; |
55 | sourceMaterializations = other.sourceMaterializations; |
56 | targetMaterializations = other.targetMaterializations; |
57 | typeAttributeConversions = other.typeAttributeConversions; |
58 | return *this; |
59 | } |
60 | |
61 | /// This class provides all of the information necessary to convert a type |
62 | /// signature. |
63 | class SignatureConversion { |
64 | public: |
65 | SignatureConversion(unsigned numOrigInputs) |
66 | : remappedInputs(numOrigInputs) {} |
67 | |
68 | /// This struct represents a range of new types or a single value that |
69 | /// remaps an existing signature input. |
70 | struct InputMapping { |
71 | size_t inputNo, size; |
72 | Value replacementValue; |
73 | }; |
74 | |
75 | /// Return the argument types for the new signature. |
76 | ArrayRef<Type> getConvertedTypes() const { return argTypes; } |
77 | |
78 | /// Get the input mapping for the given argument. |
79 | std::optional<InputMapping> getInputMapping(unsigned input) const { |
80 | return remappedInputs[input]; |
81 | } |
82 | |
83 | //===------------------------------------------------------------------===// |
84 | // Conversion Hooks |
85 | //===------------------------------------------------------------------===// |
86 | |
87 | /// Remap an input of the original signature with a new set of types. The |
88 | /// new types are appended to the new signature conversion. |
89 | void addInputs(unsigned origInputNo, ArrayRef<Type> types); |
90 | |
91 | /// Append new input types to the signature conversion, this should only be |
92 | /// used if the new types are not intended to remap an existing input. |
93 | void addInputs(ArrayRef<Type> types); |
94 | |
95 | /// Remap an input of the original signature to another `replacement` |
96 | /// value. This drops the original argument. |
97 | void remapInput(unsigned origInputNo, Value replacement); |
98 | |
99 | private: |
100 | /// Remap an input of the original signature with a range of types in the |
101 | /// new signature. |
102 | void remapInput(unsigned origInputNo, unsigned newInputNo, |
103 | unsigned newInputCount = 1); |
104 | |
105 | /// The remapping information for each of the original arguments. |
106 | SmallVector<std::optional<InputMapping>, 4> remappedInputs; |
107 | |
108 | /// The set of new argument types. |
109 | SmallVector<Type, 4> argTypes; |
110 | }; |
111 | |
112 | /// The general result of a type attribute conversion callback, allowing |
113 | /// for early termination. The default constructor creates the na case. |
114 | class AttributeConversionResult { |
115 | public: |
116 | constexpr AttributeConversionResult() : impl() {} |
117 | AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {} |
118 | |
119 | static AttributeConversionResult result(Attribute attr); |
120 | static AttributeConversionResult na(); |
121 | static AttributeConversionResult abort(); |
122 | |
123 | bool hasResult() const; |
124 | bool isNa() const; |
125 | bool isAbort() const; |
126 | |
127 | Attribute getResult() const; |
128 | |
129 | private: |
130 | AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {} |
131 | |
132 | llvm::PointerIntPair<Attribute, 2> impl; |
133 | // Note that na is 0 so that we can use PointerIntPair's default |
134 | // constructor. |
135 | static constexpr unsigned naTag = 0; |
136 | static constexpr unsigned resultTag = 1; |
137 | static constexpr unsigned abortTag = 2; |
138 | }; |
139 | |
140 | /// Register a conversion function. A conversion function must be convertible |
141 | /// to any of the following forms(where `T` is a class derived from `Type`: |
142 | /// * std::optional<Type>(T) |
143 | /// - This form represents a 1-1 type conversion. It should return nullptr |
144 | /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, |
145 | /// the converter is allowed to try another conversion function to |
146 | /// perform the conversion. |
147 | /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &) |
148 | /// - This form represents a 1-N type conversion. It should return |
149 | /// `failure` or `std::nullopt` to signify a failed conversion. If the |
150 | /// new set of types is empty, the type is removed and any usages of the |
151 | /// existing value are expected to be removed during conversion. If |
152 | /// `std::nullopt` is returned, the converter is allowed to try another |
153 | /// conversion function to perform the conversion. |
154 | /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &, |
155 | /// ArrayRef<Type>) |
156 | /// - This form represents a 1-N type conversion supporting recursive |
157 | /// types. The first two arguments and the return value are the same as |
158 | /// for the regular 1-N form. The third argument is contains is the |
159 | /// "call stack" of the recursive conversion: it contains the list of |
160 | /// types currently being converted, with the current type being the |
161 | /// last one. If it is present more than once in the list, the |
162 | /// conversion concerns a recursive type. |
163 | /// Note: When attempting to convert a type, e.g. via 'convertType', the |
164 | /// mostly recently added conversions will be invoked first. |
165 | template <typename FnT, typename T = typename llvm::function_traits< |
166 | std::decay_t<FnT>>::template arg_t<0>> |
167 | void addConversion(FnT &&callback) { |
168 | registerConversion(callback: wrapCallback<T>(std::forward<FnT>(callback))); |
169 | } |
170 | |
171 | /// Register a materialization function, which must be convertible to the |
172 | /// following form: |
173 | /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`, |
174 | /// where `T` is any subclass of `Type`. This function is responsible for |
175 | /// creating an operation, using the OpBuilder and Location provided, that |
176 | /// "casts" a range of values into a single value of the given type `T`. It |
177 | /// must return a Value of the converted type on success, an `std::nullopt` if |
178 | /// it failed but other materialization can be attempted, and `nullptr` on |
179 | /// unrecoverable failure. It will only be called for (sub)types of `T`. |
180 | /// Materialization functions must be provided when a type conversion may |
181 | /// persist after the conversion has finished. |
182 | /// |
183 | /// This method registers a materialization that will be called when |
184 | /// converting an illegal block argument type, to a legal type. |
185 | template <typename FnT, typename T = typename llvm::function_traits< |
186 | std::decay_t<FnT>>::template arg_t<1>> |
187 | void addArgumentMaterialization(FnT &&callback) { |
188 | argumentMaterializations.emplace_back( |
189 | wrapMaterialization<T>(std::forward<FnT>(callback))); |
190 | } |
191 | /// This method registers a materialization that will be called when |
192 | /// converting a legal type to an illegal source type. This is used when |
193 | /// conversions to an illegal type must persist beyond the main conversion. |
194 | template <typename FnT, typename T = typename llvm::function_traits< |
195 | std::decay_t<FnT>>::template arg_t<1>> |
196 | void addSourceMaterialization(FnT &&callback) { |
197 | sourceMaterializations.emplace_back( |
198 | wrapMaterialization<T>(std::forward<FnT>(callback))); |
199 | } |
200 | /// This method registers a materialization that will be called when |
201 | /// converting type from an illegal, or source, type to a legal type. |
202 | template <typename FnT, typename T = typename llvm::function_traits< |
203 | std::decay_t<FnT>>::template arg_t<1>> |
204 | void addTargetMaterialization(FnT &&callback) { |
205 | targetMaterializations.emplace_back( |
206 | wrapMaterialization<T>(std::forward<FnT>(callback))); |
207 | } |
208 | |
209 | /// Register a conversion function for attributes within types. Type |
210 | /// converters may call this function in order to allow hoking into the |
211 | /// translation of attributes that exist within types. For example, a type |
212 | /// converter for the `memref` type could use these conversions to convert |
213 | /// memory spaces or layouts in an extensible way. |
214 | /// |
215 | /// The conversion functions take a non-null Type or subclass of Type and a |
216 | /// non-null Attribute (or subclass of Attribute), and returns a |
217 | /// `AttributeConversionResult`. This result can either contan an `Attribute`, |
218 | /// which may be `nullptr`, representing the conversion's success, |
219 | /// `AttributeConversionResult::na()` (the default empty value), indicating |
220 | /// that the conversion function did not apply and that further conversion |
221 | /// functions should be checked, or `AttributeConversionResult::abort()` |
222 | /// indicating that the conversion process should be aborted. |
223 | /// |
224 | /// Registered conversion functions are callled in the reverse of the order in |
225 | /// which they were registered. |
226 | template < |
227 | typename FnT, |
228 | typename T = |
229 | typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>, |
230 | typename A = |
231 | typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>> |
232 | void addTypeAttributeConversion(FnT &&callback) { |
233 | registerTypeAttributeConversion( |
234 | callback: wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback))); |
235 | } |
236 | |
237 | /// Convert the given type. This function should return failure if no valid |
238 | /// conversion exists, success otherwise. If the new set of types is empty, |
239 | /// the type is removed and any usages of the existing value are expected to |
240 | /// be removed during conversion. |
241 | LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const; |
242 | |
243 | /// This hook simplifies defining 1-1 type conversions. This function returns |
244 | /// the type to convert to on success, and a null type on failure. |
245 | Type convertType(Type t) const; |
246 | |
247 | /// Attempts a 1-1 type conversion, expecting the result type to be |
248 | /// `TargetType`. Returns the converted type cast to `TargetType` on success, |
249 | /// and a null type on conversion or cast failure. |
250 | template <typename TargetType> TargetType convertType(Type t) const { |
251 | return dyn_cast_or_null<TargetType>(convertType(t)); |
252 | } |
253 | |
254 | /// Convert the given set of types, filling 'results' as necessary. This |
255 | /// returns failure if the conversion of any of the types fails, success |
256 | /// otherwise. |
257 | LogicalResult convertTypes(TypeRange types, |
258 | SmallVectorImpl<Type> &results) const; |
259 | |
260 | /// Return true if the given type is legal for this type converter, i.e. the |
261 | /// type converts to itself. |
262 | bool isLegal(Type type) const; |
263 | |
264 | /// Return true if all of the given types are legal for this type converter. |
265 | template <typename RangeT> |
266 | std::enable_if_t<!std::is_convertible<RangeT, Type>::value && |
267 | !std::is_convertible<RangeT, Operation *>::value, |
268 | bool> |
269 | isLegal(RangeT &&range) const { |
270 | return llvm::all_of(range, [this](Type type) { return isLegal(type); }); |
271 | } |
272 | /// Return true if the given operation has legal operand and result types. |
273 | bool isLegal(Operation *op) const; |
274 | |
275 | /// Return true if the types of block arguments within the region are legal. |
276 | bool isLegal(Region *region) const; |
277 | |
278 | /// Return true if the inputs and outputs of the given function type are |
279 | /// legal. |
280 | bool isSignatureLegal(FunctionType ty) const; |
281 | |
282 | /// This method allows for converting a specific argument of a signature. It |
283 | /// takes as inputs the original argument input number, type. |
284 | /// On success, it populates 'result' with any new mappings. |
285 | LogicalResult convertSignatureArg(unsigned inputNo, Type type, |
286 | SignatureConversion &result) const; |
287 | LogicalResult convertSignatureArgs(TypeRange types, |
288 | SignatureConversion &result, |
289 | unsigned origInputOffset = 0) const; |
290 | |
291 | /// This function converts the type signature of the given block, by invoking |
292 | /// 'convertSignatureArg' for each argument. This function should return a |
293 | /// valid conversion for the signature on success, std::nullopt otherwise. |
294 | std::optional<SignatureConversion> convertBlockSignature(Block *block) const; |
295 | |
296 | /// Materialize a conversion from a set of types into one result type by |
297 | /// generating a cast sequence of some kind. See the respective |
298 | /// `add*Materialization` for more information on the context for these |
299 | /// methods. |
300 | Value materializeArgumentConversion(OpBuilder &builder, Location loc, |
301 | Type resultType, |
302 | ValueRange inputs) const { |
303 | return materializeConversion(materializations: argumentMaterializations, builder, loc, |
304 | resultType, inputs); |
305 | } |
306 | Value materializeSourceConversion(OpBuilder &builder, Location loc, |
307 | Type resultType, ValueRange inputs) const { |
308 | return materializeConversion(materializations: sourceMaterializations, builder, loc, |
309 | resultType, inputs); |
310 | } |
311 | Value materializeTargetConversion(OpBuilder &builder, Location loc, |
312 | Type resultType, ValueRange inputs) const { |
313 | return materializeConversion(materializations: targetMaterializations, builder, loc, |
314 | resultType, inputs); |
315 | } |
316 | |
317 | /// Convert an attribute present `attr` from within the type `type` using |
318 | /// the registered conversion functions. If no applicable conversion has been |
319 | /// registered, return std::nullopt. Note that the empty attribute/`nullptr` |
320 | /// is a valid return value for this function. |
321 | std::optional<Attribute> convertTypeAttribute(Type type, |
322 | Attribute attr) const; |
323 | |
324 | private: |
325 | /// The signature of the callback used to convert a type. If the new set of |
326 | /// types is empty, the type is removed and any usages of the existing value |
327 | /// are expected to be removed during conversion. |
328 | using ConversionCallbackFn = std::function<std::optional<LogicalResult>( |
329 | Type, SmallVectorImpl<Type> &)>; |
330 | |
331 | /// The signature of the callback used to materialize a conversion. |
332 | using MaterializationCallbackFn = std::function<std::optional<Value>( |
333 | OpBuilder &, Type, ValueRange, Location)>; |
334 | |
335 | /// The signature of the callback used to convert a type attribute. |
336 | using TypeAttributeConversionCallbackFn = |
337 | std::function<AttributeConversionResult(Type, Attribute)>; |
338 | |
339 | /// Attempt to materialize a conversion using one of the provided |
340 | /// materialization functions. |
341 | Value |
342 | materializeConversion(ArrayRef<MaterializationCallbackFn> materializations, |
343 | OpBuilder &builder, Location loc, Type resultType, |
344 | ValueRange inputs) const; |
345 | |
346 | /// Generate a wrapper for the given callback. This allows for accepting |
347 | /// different callback forms, that all compose into a single version. |
348 | /// With callback of form: `std::optional<Type>(T)` |
349 | template <typename T, typename FnT> |
350 | std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn> |
351 | wrapCallback(FnT &&callback) const { |
352 | return wrapCallback<T>([callback = std::forward<FnT>(callback)]( |
353 | T type, SmallVectorImpl<Type> &results) { |
354 | if (std::optional<Type> resultOpt = callback(type)) { |
355 | bool wasSuccess = static_cast<bool>(*resultOpt); |
356 | if (wasSuccess) |
357 | results.push_back(Elt: *resultOpt); |
358 | return std::optional<LogicalResult>(success(isSuccess: wasSuccess)); |
359 | } |
360 | return std::optional<LogicalResult>(); |
361 | }); |
362 | } |
363 | /// With callback of form: `std::optional<LogicalResult>( |
364 | /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`. |
365 | template <typename T, typename FnT> |
366 | std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>, |
367 | ConversionCallbackFn> |
368 | wrapCallback(FnT &&callback) const { |
369 | return [callback = std::forward<FnT>(callback)]( |
370 | Type type, |
371 | SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { |
372 | T derivedType = dyn_cast<T>(type); |
373 | if (!derivedType) |
374 | return std::nullopt; |
375 | return callback(derivedType, results); |
376 | }; |
377 | } |
378 | |
379 | /// Register a type conversion. |
380 | void registerConversion(ConversionCallbackFn callback) { |
381 | conversions.emplace_back(Args: std::move(callback)); |
382 | cachedDirectConversions.clear(); |
383 | cachedMultiConversions.clear(); |
384 | } |
385 | |
386 | /// Generate a wrapper for the given materialization callback. The callback |
387 | /// may take any subclass of `Type` and the wrapper will check for the target |
388 | /// type to be of the expected class before calling the callback. |
389 | template <typename T, typename FnT> |
390 | MaterializationCallbackFn wrapMaterialization(FnT &&callback) const { |
391 | return [callback = std::forward<FnT>(callback)]( |
392 | OpBuilder &builder, Type resultType, ValueRange inputs, |
393 | Location loc) -> std::optional<Value> { |
394 | if (T derivedType = dyn_cast<T>(resultType)) |
395 | return callback(builder, derivedType, inputs, loc); |
396 | return std::nullopt; |
397 | }; |
398 | } |
399 | |
400 | /// Generate a wrapper for the given memory space conversion callback. The |
401 | /// callback may take any subclass of `Attribute` and the wrapper will check |
402 | /// for the target attribute to be of the expected class before calling the |
403 | /// callback. |
404 | template <typename T, typename A, typename FnT> |
405 | TypeAttributeConversionCallbackFn |
406 | wrapTypeAttributeConversion(FnT &&callback) const { |
407 | return [callback = std::forward<FnT>(callback)]( |
408 | Type type, Attribute attr) -> AttributeConversionResult { |
409 | if (T derivedType = dyn_cast<T>(type)) { |
410 | if (A derivedAttr = dyn_cast_or_null<A>(attr)) |
411 | return callback(derivedType, derivedAttr); |
412 | } |
413 | return AttributeConversionResult::na(); |
414 | }; |
415 | } |
416 | |
417 | /// Register a memory space conversion, clearing caches. |
418 | void |
419 | registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) { |
420 | typeAttributeConversions.emplace_back(Args: std::move(callback)); |
421 | // Clear type conversions in case a memory space is lingering inside. |
422 | cachedDirectConversions.clear(); |
423 | cachedMultiConversions.clear(); |
424 | } |
425 | |
426 | /// The set of registered conversion functions. |
427 | SmallVector<ConversionCallbackFn, 4> conversions; |
428 | |
429 | /// The list of registered materialization functions. |
430 | SmallVector<MaterializationCallbackFn, 2> argumentMaterializations; |
431 | SmallVector<MaterializationCallbackFn, 2> sourceMaterializations; |
432 | SmallVector<MaterializationCallbackFn, 2> targetMaterializations; |
433 | |
434 | /// The list of registered type attribute conversion functions. |
435 | SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions; |
436 | |
437 | /// A set of cached conversions to avoid recomputing in the common case. |
438 | /// Direct 1-1 conversions are the most common, so this cache stores the |
439 | /// successful 1-1 conversions as well as all failed conversions. |
440 | mutable DenseMap<Type, Type> cachedDirectConversions; |
441 | /// This cache stores the successful 1->N conversions, where N != 1. |
442 | mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions; |
443 | /// A mutex used for cache access |
444 | mutable llvm::sys::SmartRWMutex<true> cacheMutex; |
445 | }; |
446 | |
447 | //===----------------------------------------------------------------------===// |
448 | // Conversion Patterns |
449 | //===----------------------------------------------------------------------===// |
450 | |
451 | /// Base class for the conversion patterns. This pattern class enables type |
452 | /// conversions, and other uses specific to the conversion framework. As such, |
453 | /// patterns of this type can only be used with the 'apply*' methods below. |
454 | class ConversionPattern : public RewritePattern { |
455 | public: |
456 | /// Hook for derived classes to implement rewriting. `op` is the (first) |
457 | /// operation matched by the pattern, `operands` is a list of the rewritten |
458 | /// operand values that are passed to `op`, `rewriter` can be used to emit the |
459 | /// new operations. This function should not fail. If some specific cases of |
460 | /// the operation are not supported, these cases should not be matched. |
461 | virtual void rewrite(Operation *op, ArrayRef<Value> operands, |
462 | ConversionPatternRewriter &rewriter) const { |
463 | llvm_unreachable("unimplemented rewrite" ); |
464 | } |
465 | |
466 | /// Hook for derived classes to implement combined matching and rewriting. |
467 | virtual LogicalResult |
468 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
469 | ConversionPatternRewriter &rewriter) const { |
470 | if (failed(result: match(op))) |
471 | return failure(); |
472 | rewrite(op, operands, rewriter); |
473 | return success(); |
474 | } |
475 | |
476 | /// Attempt to match and rewrite the IR root at the specified operation. |
477 | LogicalResult matchAndRewrite(Operation *op, |
478 | PatternRewriter &rewriter) const final; |
479 | |
480 | /// Return the type converter held by this pattern, or nullptr if the pattern |
481 | /// does not require type conversion. |
482 | const TypeConverter *getTypeConverter() const { return typeConverter; } |
483 | |
484 | template <typename ConverterTy> |
485 | std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value, |
486 | const ConverterTy *> |
487 | getTypeConverter() const { |
488 | return static_cast<const ConverterTy *>(typeConverter); |
489 | } |
490 | |
491 | protected: |
492 | /// See `RewritePattern::RewritePattern` for information on the other |
493 | /// available constructors. |
494 | using RewritePattern::RewritePattern; |
495 | /// Construct a conversion pattern with the given converter, and forward the |
496 | /// remaining arguments to RewritePattern. |
497 | template <typename... Args> |
498 | ConversionPattern(const TypeConverter &typeConverter, Args &&...args) |
499 | : RewritePattern(std::forward<Args>(args)...), |
500 | typeConverter(&typeConverter) {} |
501 | |
502 | protected: |
503 | /// An optional type converter for use by this pattern. |
504 | const TypeConverter *typeConverter = nullptr; |
505 | |
506 | private: |
507 | using RewritePattern::rewrite; |
508 | }; |
509 | |
510 | /// OpConversionPattern is a wrapper around ConversionPattern that allows for |
511 | /// matching and rewriting against an instance of a derived operation class as |
512 | /// opposed to a raw Operation. |
513 | template <typename SourceOp> |
514 | class OpConversionPattern : public ConversionPattern { |
515 | public: |
516 | using OpAdaptor = typename SourceOp::Adaptor; |
517 | |
518 | OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) |
519 | : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} |
520 | OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, |
521 | PatternBenefit benefit = 1) |
522 | : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit, |
523 | context) {} |
524 | |
525 | /// Wrappers around the ConversionPattern methods that pass the derived op |
526 | /// type. |
527 | LogicalResult match(Operation *op) const final { |
528 | return match(cast<SourceOp>(op)); |
529 | } |
530 | void rewrite(Operation *op, ArrayRef<Value> operands, |
531 | ConversionPatternRewriter &rewriter) const final { |
532 | auto sourceOp = cast<SourceOp>(op); |
533 | rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); |
534 | } |
535 | LogicalResult |
536 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
537 | ConversionPatternRewriter &rewriter) const final { |
538 | auto sourceOp = cast<SourceOp>(op); |
539 | return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); |
540 | } |
541 | |
542 | /// Rewrite and Match methods that operate on the SourceOp type. These must be |
543 | /// overridden by the derived pattern class. |
544 | virtual LogicalResult match(SourceOp op) const { |
545 | llvm_unreachable("must override match or matchAndRewrite" ); |
546 | } |
547 | virtual void rewrite(SourceOp op, OpAdaptor adaptor, |
548 | ConversionPatternRewriter &rewriter) const { |
549 | llvm_unreachable("must override matchAndRewrite or a rewrite method" ); |
550 | } |
551 | virtual LogicalResult |
552 | matchAndRewrite(SourceOp op, OpAdaptor adaptor, |
553 | ConversionPatternRewriter &rewriter) const { |
554 | if (failed(match(op))) |
555 | return failure(); |
556 | rewrite(op, adaptor, rewriter); |
557 | return success(); |
558 | } |
559 | |
560 | private: |
561 | using ConversionPattern::matchAndRewrite; |
562 | }; |
563 | |
564 | /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that |
565 | /// allows for matching and rewriting against an instance of an OpInterface |
566 | /// class as opposed to a raw Operation. |
567 | template <typename SourceOp> |
568 | class OpInterfaceConversionPattern : public ConversionPattern { |
569 | public: |
570 | OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) |
571 | : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(), |
572 | SourceOp::getInterfaceID(), benefit, context) {} |
573 | OpInterfaceConversionPattern(const TypeConverter &typeConverter, |
574 | MLIRContext *context, PatternBenefit benefit = 1) |
575 | : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(), |
576 | SourceOp::getInterfaceID(), benefit, context) {} |
577 | |
578 | /// Wrappers around the ConversionPattern methods that pass the derived op |
579 | /// type. |
580 | void rewrite(Operation *op, ArrayRef<Value> operands, |
581 | ConversionPatternRewriter &rewriter) const final { |
582 | rewrite(cast<SourceOp>(op), operands, rewriter); |
583 | } |
584 | LogicalResult |
585 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
586 | ConversionPatternRewriter &rewriter) const final { |
587 | return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); |
588 | } |
589 | |
590 | /// Rewrite and Match methods that operate on the SourceOp type. These must be |
591 | /// overridden by the derived pattern class. |
592 | virtual void rewrite(SourceOp op, ArrayRef<Value> operands, |
593 | ConversionPatternRewriter &rewriter) const { |
594 | llvm_unreachable("must override matchAndRewrite or a rewrite method" ); |
595 | } |
596 | virtual LogicalResult |
597 | matchAndRewrite(SourceOp op, ArrayRef<Value> operands, |
598 | ConversionPatternRewriter &rewriter) const { |
599 | if (failed(match(op))) |
600 | return failure(); |
601 | rewrite(op, operands, rewriter); |
602 | return success(); |
603 | } |
604 | |
605 | private: |
606 | using ConversionPattern::matchAndRewrite; |
607 | }; |
608 | |
609 | /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows |
610 | /// for matching and rewriting against instances of an operation that possess a |
611 | /// given trait. |
612 | template <template <typename> class TraitType> |
613 | class OpTraitConversionPattern : public ConversionPattern { |
614 | public: |
615 | OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) |
616 | : ConversionPattern(Pattern::MatchTraitOpTypeTag(), |
617 | TypeID::get<TraitType>(), benefit, context) {} |
618 | OpTraitConversionPattern(const TypeConverter &typeConverter, |
619 | MLIRContext *context, PatternBenefit benefit = 1) |
620 | : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(), |
621 | TypeID::get<TraitType>(), benefit, context) {} |
622 | }; |
623 | |
624 | /// Generic utility to convert op result types according to type converter |
625 | /// without knowing exact op type. |
626 | /// Clones existing op with new result types and returns it. |
627 | FailureOr<Operation *> |
628 | convertOpResultTypes(Operation *op, ValueRange operands, |
629 | const TypeConverter &converter, |
630 | ConversionPatternRewriter &rewriter); |
631 | |
632 | /// Add a pattern to the given pattern list to convert the signature of a |
633 | /// FunctionOpInterface op with the given type converter. This only supports |
634 | /// ops which use FunctionType to represent their type. |
635 | void populateFunctionOpInterfaceTypeConversionPattern( |
636 | StringRef functionLikeOpName, RewritePatternSet &patterns, |
637 | const TypeConverter &converter); |
638 | |
639 | template <typename FuncOpT> |
640 | void populateFunctionOpInterfaceTypeConversionPattern( |
641 | RewritePatternSet &patterns, const TypeConverter &converter) { |
642 | populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(), |
643 | patterns, converter); |
644 | } |
645 | |
646 | void populateAnyFunctionOpInterfaceTypeConversionPattern( |
647 | RewritePatternSet &patterns, const TypeConverter &converter); |
648 | |
649 | //===----------------------------------------------------------------------===// |
650 | // Conversion PatternRewriter |
651 | //===----------------------------------------------------------------------===// |
652 | |
653 | namespace detail { |
654 | struct ConversionPatternRewriterImpl; |
655 | } // namespace detail |
656 | |
657 | /// This class implements a pattern rewriter for use with ConversionPatterns. It |
658 | /// extends the base PatternRewriter and provides special conversion specific |
659 | /// hooks. |
660 | class ConversionPatternRewriter final : public PatternRewriter { |
661 | public: |
662 | ~ConversionPatternRewriter() override; |
663 | |
664 | /// Apply a signature conversion to the entry block of the given region. This |
665 | /// replaces the entry block with a new block containing the updated |
666 | /// signature. The new entry block to the region is returned for convenience. |
667 | /// If no block argument types are changing, the entry original block will be |
668 | /// left in place and returned. |
669 | /// |
670 | /// If provided, `converter` will be used for any materializations. |
671 | Block * |
672 | applySignatureConversion(Region *region, |
673 | TypeConverter::SignatureConversion &conversion, |
674 | const TypeConverter *converter = nullptr); |
675 | |
676 | /// Convert the types of block arguments within the given region. This |
677 | /// replaces each block with a new block containing the updated signature. If |
678 | /// an updated signature would match the current signature, the respective |
679 | /// block is left in place as is. |
680 | /// |
681 | /// The entry block may have a special conversion if `entryConversion` is |
682 | /// provided. On success, the new entry block to the region is returned for |
683 | /// convenience. Otherwise, failure is returned. |
684 | FailureOr<Block *> convertRegionTypes( |
685 | Region *region, const TypeConverter &converter, |
686 | TypeConverter::SignatureConversion *entryConversion = nullptr); |
687 | |
688 | /// Convert the types of block arguments within the given region except for |
689 | /// the entry region. This replaces each non-entry block with a new block |
690 | /// containing the updated signature. If an updated signature would match the |
691 | /// current signature, the respective block is left in place as is. |
692 | /// |
693 | /// If special conversion behavior is needed for the non-entry blocks (for |
694 | /// example, we need to convert only a subset of a BB arguments), such |
695 | /// behavior can be specified in blockConversions. |
696 | LogicalResult convertNonEntryRegionTypes( |
697 | Region *region, const TypeConverter &converter, |
698 | ArrayRef<TypeConverter::SignatureConversion> blockConversions); |
699 | |
700 | /// Replace all the uses of the block argument `from` with value `to`. |
701 | void replaceUsesOfBlockArgument(BlockArgument from, Value to); |
702 | |
703 | /// Return the converted value of 'key' with a type defined by the type |
704 | /// converter of the currently executing pattern. Return nullptr in the case |
705 | /// of failure, the remapped value otherwise. |
706 | Value getRemappedValue(Value key); |
707 | |
708 | /// Return the converted values that replace 'keys' with types defined by the |
709 | /// type converter of the currently executing pattern. Returns failure if the |
710 | /// remap failed, success otherwise. |
711 | LogicalResult getRemappedValues(ValueRange keys, |
712 | SmallVectorImpl<Value> &results); |
713 | |
714 | //===--------------------------------------------------------------------===// |
715 | // PatternRewriter Hooks |
716 | //===--------------------------------------------------------------------===// |
717 | |
718 | /// Indicate that the conversion rewriter can recover from rewrite failure. |
719 | /// Recovery is supported via rollback, allowing for continued processing of |
720 | /// patterns even if a failure is encountered during the rewrite step. |
721 | bool canRecoverFromRewriteFailure() const override { return true; } |
722 | |
723 | /// PatternRewriter hook for replacing an operation. |
724 | void replaceOp(Operation *op, ValueRange newValues) override; |
725 | |
726 | /// PatternRewriter hook for replacing an operation. |
727 | void replaceOp(Operation *op, Operation *newOp) override; |
728 | |
729 | /// PatternRewriter hook for erasing a dead operation. The uses of this |
730 | /// operation *must* be made dead by the end of the conversion process, |
731 | /// otherwise an assert will be issued. |
732 | void eraseOp(Operation *op) override; |
733 | |
734 | /// PatternRewriter hook for erase all operations in a block. This is not yet |
735 | /// implemented for dialect conversion. |
736 | void eraseBlock(Block *block) override; |
737 | |
738 | /// PatternRewriter hook for inlining the ops of a block into another block. |
739 | void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, |
740 | ValueRange argValues = std::nullopt) override; |
741 | using PatternRewriter::inlineBlockBefore; |
742 | |
743 | /// PatternRewriter hook for updating the given operation in-place. |
744 | /// Note: These methods only track updates to the given operation itself, |
745 | /// and not nested regions. Updates to regions will still require notification |
746 | /// through other more specific hooks above. |
747 | void startOpModification(Operation *op) override; |
748 | |
749 | /// PatternRewriter hook for updating the given operation in-place. |
750 | void finalizeOpModification(Operation *op) override; |
751 | |
752 | /// PatternRewriter hook for updating the given operation in-place. |
753 | void cancelOpModification(Operation *op) override; |
754 | |
755 | /// Return a reference to the internal implementation. |
756 | detail::ConversionPatternRewriterImpl &getImpl(); |
757 | |
758 | private: |
759 | // Allow OperationConverter to construct new rewriters. |
760 | friend struct OperationConverter; |
761 | |
762 | /// Conversion pattern rewriters must not be used outside of dialect |
763 | /// conversions. They apply some IR rewrites in a delayed fashion and could |
764 | /// bring the IR into an inconsistent state when used standalone. |
765 | explicit ConversionPatternRewriter(MLIRContext *ctx, |
766 | const ConversionConfig &config); |
767 | |
768 | // Hide unsupported pattern rewriter API. |
769 | using OpBuilder::setListener; |
770 | |
771 | std::unique_ptr<detail::ConversionPatternRewriterImpl> impl; |
772 | }; |
773 | |
774 | //===----------------------------------------------------------------------===// |
775 | // ConversionTarget |
776 | //===----------------------------------------------------------------------===// |
777 | |
778 | /// This class describes a specific conversion target. |
779 | class ConversionTarget { |
780 | public: |
781 | /// This enumeration corresponds to the specific action to take when |
782 | /// considering an operation legal for this conversion target. |
783 | enum class LegalizationAction { |
784 | /// The target supports this operation. |
785 | Legal, |
786 | |
787 | /// This operation has dynamic legalization constraints that must be checked |
788 | /// by the target. |
789 | Dynamic, |
790 | |
791 | /// The target explicitly does not support this operation. |
792 | Illegal, |
793 | }; |
794 | |
795 | /// A structure containing additional information describing a specific legal |
796 | /// operation instance. |
797 | struct LegalOpDetails { |
798 | /// A flag that indicates if this operation is 'recursively' legal. This |
799 | /// means that if an operation is legal, either statically or dynamically, |
800 | /// all of the operations nested within are also considered legal. |
801 | bool isRecursivelyLegal = false; |
802 | }; |
803 | |
804 | /// The signature of the callback used to determine if an operation is |
805 | /// dynamically legal on the target. |
806 | using DynamicLegalityCallbackFn = |
807 | std::function<std::optional<bool>(Operation *)>; |
808 | |
809 | ConversionTarget(MLIRContext &ctx) : ctx(ctx) {} |
810 | virtual ~ConversionTarget() = default; |
811 | |
812 | //===--------------------------------------------------------------------===// |
813 | // Legality Registration |
814 | //===--------------------------------------------------------------------===// |
815 | |
816 | /// Register a legality action for the given operation. |
817 | void setOpAction(OperationName op, LegalizationAction action); |
818 | template <typename OpT> |
819 | void setOpAction(LegalizationAction action) { |
820 | setOpAction(op: OperationName(OpT::getOperationName(), &ctx), action); |
821 | } |
822 | |
823 | /// Register the given operations as legal. |
824 | void addLegalOp(OperationName op) { |
825 | setOpAction(op, action: LegalizationAction::Legal); |
826 | } |
827 | template <typename OpT> |
828 | void addLegalOp() { |
829 | addLegalOp(op: OperationName(OpT::getOperationName(), &ctx)); |
830 | } |
831 | template <typename OpT, typename OpT2, typename... OpTs> |
832 | void addLegalOp() { |
833 | addLegalOp<OpT>(); |
834 | addLegalOp<OpT2, OpTs...>(); |
835 | } |
836 | |
837 | /// Register the given operation as dynamically legal and set the dynamic |
838 | /// legalization callback to the one provided. |
839 | void addDynamicallyLegalOp(OperationName op, |
840 | const DynamicLegalityCallbackFn &callback) { |
841 | setOpAction(op, action: LegalizationAction::Dynamic); |
842 | setLegalityCallback(name: op, callback); |
843 | } |
844 | template <typename OpT> |
845 | void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) { |
846 | addDynamicallyLegalOp(op: OperationName(OpT::getOperationName(), &ctx), |
847 | callback); |
848 | } |
849 | template <typename OpT, typename OpT2, typename... OpTs> |
850 | void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) { |
851 | addDynamicallyLegalOp<OpT>(callback); |
852 | addDynamicallyLegalOp<OpT2, OpTs...>(callback); |
853 | } |
854 | template <typename OpT, class Callable> |
855 | std::enable_if_t<!std::is_invocable_v<Callable, Operation *>> |
856 | addDynamicallyLegalOp(Callable &&callback) { |
857 | addDynamicallyLegalOp<OpT>( |
858 | [=](Operation *op) { return callback(cast<OpT>(op)); }); |
859 | } |
860 | |
861 | /// Register the given operation as illegal, i.e. this operation is known to |
862 | /// not be supported by this target. |
863 | void addIllegalOp(OperationName op) { |
864 | setOpAction(op, action: LegalizationAction::Illegal); |
865 | } |
866 | template <typename OpT> |
867 | void addIllegalOp() { |
868 | addIllegalOp(op: OperationName(OpT::getOperationName(), &ctx)); |
869 | } |
870 | template <typename OpT, typename OpT2, typename... OpTs> |
871 | void addIllegalOp() { |
872 | addIllegalOp<OpT>(); |
873 | addIllegalOp<OpT2, OpTs...>(); |
874 | } |
875 | |
876 | /// Mark an operation, that *must* have either been set as `Legal` or |
877 | /// `DynamicallyLegal`, as being recursively legal. This means that in |
878 | /// addition to the operation itself, all of the operations nested within are |
879 | /// also considered legal. An optional dynamic legality callback may be |
880 | /// provided to mark subsets of legal instances as recursively legal. |
881 | void markOpRecursivelyLegal(OperationName name, |
882 | const DynamicLegalityCallbackFn &callback); |
883 | template <typename OpT> |
884 | void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) { |
885 | OperationName opName(OpT::getOperationName(), &ctx); |
886 | markOpRecursivelyLegal(name: opName, callback); |
887 | } |
888 | template <typename OpT, typename OpT2, typename... OpTs> |
889 | void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) { |
890 | markOpRecursivelyLegal<OpT>(callback); |
891 | markOpRecursivelyLegal<OpT2, OpTs...>(callback); |
892 | } |
893 | template <typename OpT, class Callable> |
894 | std::enable_if_t<!std::is_invocable_v<Callable, Operation *>> |
895 | markOpRecursivelyLegal(Callable &&callback) { |
896 | markOpRecursivelyLegal<OpT>( |
897 | [=](Operation *op) { return callback(cast<OpT>(op)); }); |
898 | } |
899 | |
900 | /// Register a legality action for the given dialects. |
901 | void setDialectAction(ArrayRef<StringRef> dialectNames, |
902 | LegalizationAction action); |
903 | |
904 | /// Register the operations of the given dialects as legal. |
905 | template <typename... Names> |
906 | void addLegalDialect(StringRef name, Names... names) { |
907 | SmallVector<StringRef, 2> dialectNames({name, names...}); |
908 | setDialectAction(dialectNames, action: LegalizationAction::Legal); |
909 | } |
910 | template <typename... Args> |
911 | void addLegalDialect() { |
912 | SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...}); |
913 | setDialectAction(dialectNames, action: LegalizationAction::Legal); |
914 | } |
915 | |
916 | /// Register the operations of the given dialects as dynamically legal, i.e. |
917 | /// requiring custom handling by the callback. |
918 | template <typename... Names> |
919 | void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, |
920 | StringRef name, Names... names) { |
921 | SmallVector<StringRef, 2> dialectNames({name, names...}); |
922 | setDialectAction(dialectNames, action: LegalizationAction::Dynamic); |
923 | setLegalityCallback(dialects: dialectNames, callback); |
924 | } |
925 | template <typename... Args> |
926 | void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) { |
927 | addDynamicallyLegalDialect(std::move(callback), |
928 | Args::getDialectNamespace()...); |
929 | } |
930 | |
931 | /// Register unknown operations as dynamically legal. For operations(and |
932 | /// dialects) that do not have a set legalization action, treat them as |
933 | /// dynamically legal and invoke the given callback. |
934 | void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) { |
935 | setLegalityCallback(fn); |
936 | } |
937 | |
938 | /// Register the operations of the given dialects as illegal, i.e. |
939 | /// operations of this dialect are not supported by the target. |
940 | template <typename... Names> |
941 | void addIllegalDialect(StringRef name, Names... names) { |
942 | SmallVector<StringRef, 2> dialectNames({name, names...}); |
943 | setDialectAction(dialectNames, action: LegalizationAction::Illegal); |
944 | } |
945 | template <typename... Args> |
946 | void addIllegalDialect() { |
947 | SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...}); |
948 | setDialectAction(dialectNames, action: LegalizationAction::Illegal); |
949 | } |
950 | |
951 | //===--------------------------------------------------------------------===// |
952 | // Legality Querying |
953 | //===--------------------------------------------------------------------===// |
954 | |
955 | /// Get the legality action for the given operation. |
956 | std::optional<LegalizationAction> getOpAction(OperationName op) const; |
957 | |
958 | /// If the given operation instance is legal on this target, a structure |
959 | /// containing legality information is returned. If the operation is not |
960 | /// legal, std::nullopt is returned. Also returns std::nullopt if operation |
961 | /// legality wasn't registered by user or dynamic legality callbacks returned |
962 | /// None. |
963 | /// |
964 | /// Note: Legality is actually a 4-state: Legal(recursive=true), |
965 | /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated |
966 | /// either as Legal or Illegal depending on context. |
967 | std::optional<LegalOpDetails> isLegal(Operation *op) const; |
968 | |
969 | /// Returns true is operation instance is illegal on this target. Returns |
970 | /// false if operation is legal, operation legality wasn't registered by user |
971 | /// or dynamic legality callbacks returned None. |
972 | bool isIllegal(Operation *op) const; |
973 | |
974 | private: |
975 | /// Set the dynamic legality callback for the given operation. |
976 | void setLegalityCallback(OperationName name, |
977 | const DynamicLegalityCallbackFn &callback); |
978 | |
979 | /// Set the dynamic legality callback for the given dialects. |
980 | void setLegalityCallback(ArrayRef<StringRef> dialects, |
981 | const DynamicLegalityCallbackFn &callback); |
982 | |
983 | /// Set the dynamic legality callback for the unknown ops. |
984 | void setLegalityCallback(const DynamicLegalityCallbackFn &callback); |
985 | |
986 | /// The set of information that configures the legalization of an operation. |
987 | struct LegalizationInfo { |
988 | /// The legality action this operation was given. |
989 | LegalizationAction action = LegalizationAction::Illegal; |
990 | |
991 | /// If some legal instances of this operation may also be recursively legal. |
992 | bool isRecursivelyLegal = false; |
993 | |
994 | /// The legality callback if this operation is dynamically legal. |
995 | DynamicLegalityCallbackFn legalityFn; |
996 | }; |
997 | |
998 | /// Get the legalization information for the given operation. |
999 | std::optional<LegalizationInfo> getOpInfo(OperationName op) const; |
1000 | |
1001 | /// A deterministic mapping of operation name and its respective legality |
1002 | /// information. |
1003 | llvm::MapVector<OperationName, LegalizationInfo> legalOperations; |
1004 | |
1005 | /// A set of legality callbacks for given operation names that are used to |
1006 | /// check if an operation instance is recursively legal. |
1007 | DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns; |
1008 | |
1009 | /// A deterministic mapping of dialect name to the specific legality action to |
1010 | /// take. |
1011 | llvm::StringMap<LegalizationAction> legalDialects; |
1012 | |
1013 | /// A set of dynamic legality callbacks for given dialect names. |
1014 | llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns; |
1015 | |
1016 | /// An optional legality callback for unknown operations. |
1017 | DynamicLegalityCallbackFn unknownLegalityFn; |
1018 | |
1019 | /// The current context this target applies to. |
1020 | MLIRContext &ctx; |
1021 | }; |
1022 | |
1023 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
1024 | //===----------------------------------------------------------------------===// |
1025 | // PDL Configuration |
1026 | //===----------------------------------------------------------------------===// |
1027 | |
1028 | /// A PDL configuration that is used to supported dialect conversion |
1029 | /// functionality. |
1030 | class PDLConversionConfig final |
1031 | : public PDLPatternConfigBase<PDLConversionConfig> { |
1032 | public: |
1033 | PDLConversionConfig(const TypeConverter *converter) : converter(converter) {} |
1034 | ~PDLConversionConfig() final = default; |
1035 | |
1036 | /// Return the type converter used by this configuration, which may be nullptr |
1037 | /// if no type conversions are expected. |
1038 | const TypeConverter *getTypeConverter() const { return converter; } |
1039 | |
1040 | /// Hooks that are invoked at the beginning and end of a rewrite of a matched |
1041 | /// pattern. |
1042 | void notifyRewriteBegin(PatternRewriter &rewriter) final; |
1043 | void notifyRewriteEnd(PatternRewriter &rewriter) final; |
1044 | |
1045 | private: |
1046 | /// An optional type converter to use for the pattern. |
1047 | const TypeConverter *converter; |
1048 | }; |
1049 | |
1050 | /// Register the dialect conversion PDL functions with the given pattern set. |
1051 | void registerConversionPDLFunctions(RewritePatternSet &patterns); |
1052 | |
1053 | #else |
1054 | |
1055 | // Stubs for when PDL in rewriting is not enabled. |
1056 | |
1057 | inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {} |
1058 | |
1059 | class PDLConversionConfig final { |
1060 | public: |
1061 | PDLConversionConfig(const TypeConverter * /*converter*/) {} |
1062 | }; |
1063 | |
1064 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
1065 | |
1066 | //===----------------------------------------------------------------------===// |
1067 | // ConversionConfig |
1068 | //===----------------------------------------------------------------------===// |
1069 | |
1070 | /// Dialect conversion configuration. |
1071 | struct ConversionConfig { |
1072 | /// An optional callback used to notify about match failure diagnostics during |
1073 | /// the conversion. Diagnostics reported to this callback may only be |
1074 | /// available in debug mode. |
1075 | function_ref<void(Diagnostic &)> notifyCallback = nullptr; |
1076 | |
1077 | /// Partial conversion only. All operations that are found not to be |
1078 | /// legalizable are placed in this set. (Note that if there is an op |
1079 | /// explicitly marked as illegal, the conversion terminates and the set will |
1080 | /// not necessarily be complete.) |
1081 | DenseSet<Operation *> *unlegalizedOps = nullptr; |
1082 | |
1083 | /// Analysis conversion only. All operations that are found to be legalizable |
1084 | /// are placed in this set. Note that no actual rewrites are applied to the |
1085 | /// IR during an analysis conversion and only pre-existing operations are |
1086 | /// added to the set. |
1087 | DenseSet<Operation *> *legalizableOps = nullptr; |
1088 | |
1089 | /// An optional listener that is notified about all IR modifications in case |
1090 | /// dialect conversion succeeds. If the dialect conversion fails and no IR |
1091 | /// modifications are visible (i.e., they were all rolled back), no |
1092 | /// notifications are sent. |
1093 | /// |
1094 | /// Note: Notifications are sent in a delayed fashion, when the dialect |
1095 | /// conversion is guaranteed to succeed. At that point, some IR modifications |
1096 | /// may already have been materialized. Consequently, operations/blocks that |
1097 | /// are passed to listener callbacks should not be accessed. (Ops/blocks are |
1098 | /// guaranteed to be valid pointers and accessing op names is allowed. But |
1099 | /// there are no guarantees about the state of ops/blocks at the time that a |
1100 | /// callback is triggered.) |
1101 | /// |
1102 | /// Example: Consider a dialect conversion a new op ("test.foo") is created |
1103 | /// and inserted, and later moved to another block. (Moving ops also triggers |
1104 | /// "notifyOperationInserted".) |
1105 | /// |
1106 | /// (1) notifyOperationInserted: "test.foo" (into block "b1") |
1107 | /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2") |
1108 | /// |
1109 | /// When querying "op->getBlock()" during the first "notifyOperationInserted", |
1110 | /// "b2" would be returned because "moving an op" is a kind of rewrite that is |
1111 | /// immediately performed by the dialect conversion (and rolled back upon |
1112 | /// failure). |
1113 | // |
1114 | // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted" |
1115 | // callback, the previous region/block is provided to the callback, but not |
1116 | // the iterator pointing to the exact location within the region/block. That |
1117 | // is because these notifications are sent with a delay (after the IR has |
1118 | // already been modified) and iterators into past IR state cannot be |
1119 | // represented at the moment. |
1120 | RewriterBase::Listener *listener = nullptr; |
1121 | }; |
1122 | |
1123 | //===----------------------------------------------------------------------===// |
1124 | // Op Conversion Entry Points |
1125 | //===----------------------------------------------------------------------===// |
1126 | |
/// Below we define several entry points for operation conversion. Note that
/// the patterns provided to the conversion framework may be subject to
/// additional constraints: see the `PatternRewriter Hooks` section of
/// ConversionPatternRewriter for the constraints imposed on the use of the
/// PatternRewriter.
1132 | |
1133 | /// Apply a partial conversion on the given operations and all nested |
1134 | /// operations. This method converts as many operations to the target as |
1135 | /// possible, ignoring operations that failed to legalize. This method only |
1136 | /// returns failure if there ops explicitly marked as illegal. |
1137 | LogicalResult |
1138 | applyPartialConversion(ArrayRef<Operation *> ops, |
1139 | const ConversionTarget &target, |
1140 | const FrozenRewritePatternSet &patterns, |
1141 | ConversionConfig config = ConversionConfig()); |
1142 | LogicalResult |
1143 | applyPartialConversion(Operation *op, const ConversionTarget &target, |
1144 | const FrozenRewritePatternSet &patterns, |
1145 | ConversionConfig config = ConversionConfig()); |
1146 | |
1147 | /// Apply a complete conversion on the given operations, and all nested |
1148 | /// operations. This method returns failure if the conversion of any operation |
1149 | /// fails, or if there are unreachable blocks in any of the regions nested |
1150 | /// within 'ops'. |
1151 | LogicalResult applyFullConversion(ArrayRef<Operation *> ops, |
1152 | const ConversionTarget &target, |
1153 | const FrozenRewritePatternSet &patterns, |
1154 | ConversionConfig config = ConversionConfig()); |
1155 | LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target, |
1156 | const FrozenRewritePatternSet &patterns, |
1157 | ConversionConfig config = ConversionConfig()); |
1158 | |
1159 | /// Apply an analysis conversion on the given operations, and all nested |
1160 | /// operations. This method analyzes which operations would be successfully |
1161 | /// converted to the target if a conversion was applied. All operations that |
1162 | /// were found to be legalizable to the given 'target' are placed within the |
1163 | /// provided 'config.legalizableOps' set; note that no actual rewrites are |
1164 | /// applied to the operations on success. This method only returns failure if |
1165 | /// there are unreachable blocks in any of the regions nested within 'ops'. |
1166 | LogicalResult |
1167 | applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target, |
1168 | const FrozenRewritePatternSet &patterns, |
1169 | ConversionConfig config = ConversionConfig()); |
1170 | LogicalResult |
1171 | applyAnalysisConversion(Operation *op, ConversionTarget &target, |
1172 | const FrozenRewritePatternSet &patterns, |
1173 | ConversionConfig config = ConversionConfig()); |
1174 | } // namespace mlir |
1175 | |
1176 | #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ |
1177 | |