1 | //===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_PDLPATTERNMATCH_H |
10 | #define MLIR_IR_PDLPATTERNMATCH_H |
11 | |
12 | #include "mlir/Config/mlir-config.h" |
13 | |
14 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
15 | #include "mlir/IR/Builders.h" |
16 | #include "mlir/IR/BuiltinOps.h" |
17 | |
18 | namespace mlir { |
19 | //===----------------------------------------------------------------------===// |
20 | // PDL Patterns |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // PDLValue |
25 | |
26 | /// Storage type of byte-code interpreter values. These are passed to constraint |
27 | /// functions as arguments. |
28 | class PDLValue { |
29 | public: |
30 | /// The underlying kind of a PDL value. |
31 | enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange }; |
32 | |
33 | /// Construct a new PDL value. |
34 | PDLValue(const PDLValue &other) = default; |
35 | PDLValue(std::nullptr_t = nullptr) {} |
36 | PDLValue(Attribute value) |
37 | : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {} |
38 | PDLValue(Operation *value) : value(value), kind(Kind::Operation) {} |
39 | PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {} |
40 | PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {} |
41 | PDLValue(Value value) |
42 | : value(value.getAsOpaquePointer()), kind(Kind::Value) {} |
43 | PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {} |
44 | |
45 | /// Returns true if the type of the held value is `T`. |
46 | template <typename T> |
47 | bool isa() const { |
48 | assert(value && "isa<> used on a null value" ); |
49 | return kind == getKindOf<T>(); |
50 | } |
51 | |
52 | /// Attempt to dynamically cast this value to type `T`, returns null if this |
53 | /// value is not an instance of `T`. |
54 | template <typename T, |
55 | typename ResultT = std::conditional_t< |
56 | std::is_convertible<T, bool>::value, T, std::optional<T>>> |
57 | ResultT dyn_cast() const { |
58 | return isa<T>() ? castImpl<T>() : ResultT(); |
59 | } |
60 | |
61 | /// Cast this value to type `T`, asserts if this value is not an instance of |
62 | /// `T`. |
63 | template <typename T> |
64 | T cast() const { |
65 | assert(isa<T>() && "expected value to be of type `T`" ); |
66 | return castImpl<T>(); |
67 | } |
68 | |
69 | /// Get an opaque pointer to the value. |
70 | const void *getAsOpaquePointer() const { return value; } |
71 | |
72 | /// Return if this value is null or not. |
73 | explicit operator bool() const { return value; } |
74 | |
75 | /// Return the kind of this value. |
76 | Kind getKind() const { return kind; } |
77 | |
78 | /// Print this value to the provided output stream. |
79 | void print(raw_ostream &os) const; |
80 | |
81 | /// Print the specified value kind to an output stream. |
82 | static void print(raw_ostream &os, Kind kind); |
83 | |
84 | private: |
85 | /// Find the index of a given type in a range of other types. |
86 | template <typename...> |
87 | struct index_of_t; |
88 | template <typename T, typename... R> |
89 | struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {}; |
90 | template <typename T, typename F, typename... R> |
91 | struct index_of_t<T, F, R...> |
92 | : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {}; |
93 | |
94 | /// Return the kind used for the given T. |
95 | template <typename T> |
96 | static Kind getKindOf() { |
97 | return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type, |
98 | TypeRange, Value, ValueRange>::value); |
99 | } |
100 | |
101 | /// The internal implementation of `cast`, that returns the underlying value |
102 | /// as the given type `T`. |
103 | template <typename T> |
104 | std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T> |
105 | castImpl() const { |
106 | return T::getFromOpaquePointer(value); |
107 | } |
108 | template <typename T> |
109 | std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T> |
110 | castImpl() const { |
111 | return *reinterpret_cast<T *>(const_cast<void *>(value)); |
112 | } |
113 | template <typename T> |
114 | std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const { |
115 | return reinterpret_cast<T>(const_cast<void *>(value)); |
116 | } |
117 | |
118 | /// The internal opaque representation of a PDLValue. |
119 | const void *value{nullptr}; |
120 | /// The kind of the opaque value. |
121 | Kind kind{Kind::Attribute}; |
122 | }; |
123 | |
124 | inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { |
125 | value.print(os); |
126 | return os; |
127 | } |
128 | |
129 | inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { |
130 | PDLValue::print(os, kind); |
131 | return os; |
132 | } |
133 | |
134 | //===----------------------------------------------------------------------===// |
135 | // PDLResultList |
136 | |
137 | /// The class represents a list of PDL results, returned by a native rewrite |
138 | /// method. It provides the mechanism with which to pass PDLValues back to the |
139 | /// PDL bytecode. |
140 | class PDLResultList { |
141 | public: |
142 | /// Push a new Attribute value onto the result list. |
143 | void push_back(Attribute value) { results.push_back(Elt: value); } |
144 | |
145 | /// Push a new Operation onto the result list. |
146 | void push_back(Operation *value) { results.push_back(Elt: value); } |
147 | |
148 | /// Push a new Type onto the result list. |
149 | void push_back(Type value) { results.push_back(Elt: value); } |
150 | |
151 | /// Push a new TypeRange onto the result list. |
152 | void push_back(TypeRange value) { |
153 | // The lifetime of a TypeRange can't be guaranteed, so we'll need to |
154 | // allocate a storage for it. |
155 | llvm::OwningArrayRef<Type> storage(value.size()); |
156 | llvm::copy(Range&: value, Out: storage.begin()); |
157 | allocatedTypeRanges.emplace_back(Args: std::move(storage)); |
158 | typeRanges.push_back(Elt: allocatedTypeRanges.back()); |
159 | results.push_back(Elt: &typeRanges.back()); |
160 | } |
161 | void push_back(ValueTypeRange<OperandRange> value) { |
162 | typeRanges.push_back(Elt: value); |
163 | results.push_back(Elt: &typeRanges.back()); |
164 | } |
165 | void push_back(ValueTypeRange<ResultRange> value) { |
166 | typeRanges.push_back(Elt: value); |
167 | results.push_back(Elt: &typeRanges.back()); |
168 | } |
169 | |
170 | /// Push a new Value onto the result list. |
171 | void push_back(Value value) { results.push_back(Elt: value); } |
172 | |
173 | /// Push a new ValueRange onto the result list. |
174 | void push_back(ValueRange value) { |
175 | // The lifetime of a ValueRange can't be guaranteed, so we'll need to |
176 | // allocate a storage for it. |
177 | llvm::OwningArrayRef<Value> storage(value.size()); |
178 | llvm::copy(Range&: value, Out: storage.begin()); |
179 | allocatedValueRanges.emplace_back(Args: std::move(storage)); |
180 | valueRanges.push_back(Elt: allocatedValueRanges.back()); |
181 | results.push_back(Elt: &valueRanges.back()); |
182 | } |
183 | void push_back(OperandRange value) { |
184 | valueRanges.push_back(Elt: value); |
185 | results.push_back(Elt: &valueRanges.back()); |
186 | } |
187 | void push_back(ResultRange value) { |
188 | valueRanges.push_back(Elt: value); |
189 | results.push_back(Elt: &valueRanges.back()); |
190 | } |
191 | |
192 | protected: |
193 | /// Create a new result list with the expected number of results. |
194 | PDLResultList(unsigned maxNumResults) { |
195 | // For now just reserve enough space for all of the results. We could do |
196 | // separate counts per range type, but it isn't really worth it unless there |
197 | // are a "large" number of results. |
198 | typeRanges.reserve(N: maxNumResults); |
199 | valueRanges.reserve(N: maxNumResults); |
200 | } |
201 | |
202 | /// The PDL results held by this list. |
203 | SmallVector<PDLValue> results; |
204 | /// Memory used to store ranges held by the list. |
205 | SmallVector<TypeRange> typeRanges; |
206 | SmallVector<ValueRange> valueRanges; |
207 | /// Memory allocated to store ranges in the result list whose lifetime was |
208 | /// generated in the native function. |
209 | SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges; |
210 | SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges; |
211 | }; |
212 | |
213 | //===----------------------------------------------------------------------===// |
214 | // PDLPatternConfig |
215 | |
216 | /// An individual configuration for a pattern, which can be accessed by native |
217 | /// functions via the PDLPatternConfigSet. This allows for injecting additional |
218 | /// configuration into PDL patterns that is specific to certain compilation |
219 | /// flows. |
220 | class PDLPatternConfig { |
221 | public: |
222 | virtual ~PDLPatternConfig() = default; |
223 | |
224 | /// Hooks that are invoked at the beginning and end of a rewrite of a matched |
225 | /// pattern. These can be used to setup any specific state necessary for the |
226 | /// rewrite. |
227 | virtual void notifyRewriteBegin(PatternRewriter &rewriter) {} |
228 | virtual void notifyRewriteEnd(PatternRewriter &rewriter) {} |
229 | |
230 | /// Return the TypeID that represents this configuration. |
231 | TypeID getTypeID() const { return id; } |
232 | |
233 | protected: |
234 | PDLPatternConfig(TypeID id) : id(id) {} |
235 | |
236 | private: |
237 | TypeID id; |
238 | }; |
239 | |
240 | /// This class provides a base class for users implementing a type of pattern |
241 | /// configuration. |
242 | template <typename T> |
243 | class PDLPatternConfigBase : public PDLPatternConfig { |
244 | public: |
245 | /// Support LLVM style casting. |
246 | static bool classof(const PDLPatternConfig *config) { |
247 | return config->getTypeID() == getConfigID(); |
248 | } |
249 | |
250 | /// Return the type id used for this configuration. |
251 | static TypeID getConfigID() { return TypeID::get<T>(); } |
252 | |
253 | protected: |
254 | PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {} |
255 | }; |
256 | |
257 | /// This class contains a set of configurations for a specific pattern. |
258 | /// Configurations are uniqued by TypeID, meaning that only one configuration of |
259 | /// each type is allowed. |
260 | class PDLPatternConfigSet { |
261 | public: |
262 | PDLPatternConfigSet() = default; |
263 | |
264 | /// Construct a set with the given configurations. |
265 | template <typename... ConfigsT> |
266 | PDLPatternConfigSet(ConfigsT &&...configs) { |
267 | (addConfig(std::forward<ConfigsT>(configs)), ...); |
268 | } |
269 | |
270 | /// Get the configuration defined by the given type. Asserts that the |
271 | /// configuration of the provided type exists. |
272 | template <typename T> |
273 | const T &get() const { |
274 | const T *config = tryGet<T>(); |
275 | assert(config && "configuration not found" ); |
276 | return *config; |
277 | } |
278 | |
279 | /// Get the configuration defined by the given type, returns nullptr if the |
280 | /// configuration does not exist. |
281 | template <typename T> |
282 | const T *tryGet() const { |
283 | for (const auto &configIt : configs) |
284 | if (const T *config = dyn_cast<T>(configIt.get())) |
285 | return config; |
286 | return nullptr; |
287 | } |
288 | |
289 | /// Notify the configurations within this set at the beginning or end of a |
290 | /// rewrite of a matched pattern. |
291 | void notifyRewriteBegin(PatternRewriter &rewriter) { |
292 | for (const auto &config : configs) |
293 | config->notifyRewriteBegin(rewriter); |
294 | } |
295 | void notifyRewriteEnd(PatternRewriter &rewriter) { |
296 | for (const auto &config : configs) |
297 | config->notifyRewriteEnd(rewriter); |
298 | } |
299 | |
300 | protected: |
301 | /// Add a configuration to the set. |
302 | template <typename T> |
303 | void addConfig(T &&config) { |
304 | assert(!tryGet<std::decay_t<T>>() && "configuration already exists" ); |
305 | configs.emplace_back( |
306 | std::make_unique<std::decay_t<T>>(std::forward<T>(config))); |
307 | } |
308 | |
309 | /// The set of configurations for this pattern. This uses a vector instead of |
310 | /// a map with the expectation that the number of configurations per set is |
311 | /// small (<= 1). |
312 | SmallVector<std::unique_ptr<PDLPatternConfig>> configs; |
313 | }; |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // PDLPatternModule |
317 | |
318 | /// A generic PDL pattern constraint function. This function applies a |
319 | /// constraint to a given set of opaque PDLValue entities. Returns success if |
320 | /// the constraint successfully held, failure otherwise. |
321 | using PDLConstraintFunction = std::function<LogicalResult( |
322 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
323 | |
324 | /// A native PDL rewrite function. This function performs a rewrite on the |
325 | /// given set of values. Any results from this rewrite that should be passed |
326 | /// back to PDL should be added to the provided result list. This method is only |
327 | /// invoked when the corresponding match was successful. Returns failure if an |
328 | /// invariant of the rewrite was broken (certain rewriters may recover from |
329 | /// partial pattern application). |
330 | using PDLRewriteFunction = std::function<LogicalResult( |
331 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
332 | |
333 | namespace detail { |
334 | namespace pdl_function_builder { |
/// A variable template that is `false` for every instantiation. Its value
/// depends on the template arguments, so it can be used in a `static_assert`
/// that should fire only when an enclosing template is actually
/// instantiated. For example, a templated function that must never be called:
///
///   template <typename T>
///   void foo() {
///     static_assert(always_false<T>, "This function should never be called");
///   }
///
template <typename... T>
constexpr bool always_false = false;
347 | |
348 | //===----------------------------------------------------------------------===// |
349 | // PDL Function Builder: Type Processing |
350 | //===----------------------------------------------------------------------===// |
351 | |
/// Customization point describing how a C++ type `T` is exchanged with PDL.
/// A value of `T` is either unpacked from a PDLValue argument of a native
/// constraint/rewrite function, or packaged as a result of a native rewrite.
/// This lets such functions use rich C++ parameter and return types without
/// hand-written glue code.
///
/// To support `T` as an argument, a specialization provides:
///
///   static LogicalResult verifyAsArg(
///       function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///       size_t argIdx);
///
///     * Check that `pdlValue` can be used as a value of `T`, reporting
///       problems through `errorFn`.
///
///   static T processAsArg(PDLValue pdlValue);
///
///     * Convert `pdlValue` into a value of `T`.
///
/// To support `T` as a result, a specialization provides:
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               T value);
///
///     * Package `value` (taken by value or by const reference) into an
///       appropriate form and append it to `results`.
///
/// When `T` is naturally expressed in terms of another supported type,
/// deriving the specialization from `ProcessPDLValueBasedOn` removes most of
/// this boilerplate. The unnamed second parameter lets partial
/// specializations be selected with `std::enable_if_t`.
template <typename T, typename = void>
struct ProcessPDLValue;
383 | |
384 | /// This struct provides a simplified model for processing types that are based |
385 | /// on another type, e.g. APInt is based on the handling for IntegerAttr. This |
386 | /// allows for building the necessary processing functions on top of the base |
387 | /// value instead of a PDLValue. Derived users should implement the following |
388 | /// (which subsume the ProcessPDLValue variants): |
389 | /// |
390 | /// static LogicalResult verifyAsArg( |
391 | /// function_ref<LogicalResult(const Twine &)> errorFn, |
392 | /// const BaseT &baseValue, size_t argIdx); |
393 | /// |
394 | /// * This method verifies that the given PDLValue is valid for use as a |
395 | /// value of `T`. |
396 | /// |
397 | /// static T processAsArg(BaseT baseValue); |
398 | /// |
399 | /// * This method processes the given base value as a value of `T`. |
400 | /// |
401 | template <typename T, typename BaseT> |
402 | struct ProcessPDLValueBasedOn { |
403 | static LogicalResult |
404 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, |
405 | PDLValue pdlValue, size_t argIdx) { |
406 | // Verify the base class before continuing. |
407 | if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx))) |
408 | return failure(); |
409 | return ProcessPDLValue<T>::verifyAsArg( |
410 | errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx); |
411 | } |
412 | static T processAsArg(PDLValue pdlValue) { |
413 | return ProcessPDLValue<T>::processAsArg( |
414 | ProcessPDLValue<BaseT>::processAsArg(pdlValue)); |
415 | } |
416 | |
417 | /// Explicitly add the expected parent API to ensure the parent class |
418 | /// implements the necessary API (and doesn't implicitly inherit it from |
419 | /// somewhere else). |
420 | static LogicalResult |
421 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value, |
422 | size_t argIdx) { |
423 | return success(); |
424 | } |
425 | static T processAsArg(BaseT baseValue); |
426 | }; |
427 | |
428 | /// This struct provides a simplified model for processing types that have |
429 | /// "builtin" PDLValue support: |
430 | /// * Attribute, Operation *, Type, TypeRange, ValueRange |
431 | template <typename T> |
432 | struct ProcessBuiltinPDLValue { |
433 | static LogicalResult |
434 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, |
435 | PDLValue pdlValue, size_t argIdx) { |
436 | if (pdlValue) |
437 | return success(); |
438 | return errorFn("expected a non-null value for argument " + Twine(argIdx) + |
439 | " of type: " + llvm::getTypeName<T>()); |
440 | } |
441 | |
442 | static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); } |
443 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
444 | T value) { |
445 | results.push_back(value); |
446 | } |
447 | }; |
448 | |
449 | /// This struct provides a simplified model for processing types that inherit |
450 | /// from builtin PDLValue types. For example, derived attributes like |
451 | /// IntegerAttr, derived types like IntegerType, derived operations like |
452 | /// ModuleOp, Interfaces, etc. |
453 | template <typename T, typename BaseT> |
454 | struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> { |
455 | static LogicalResult |
456 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, |
457 | BaseT baseValue, size_t argIdx) { |
458 | return TypeSwitch<BaseT, LogicalResult>(baseValue) |
459 | .Case([&](T) { return success(); }) |
460 | .Default([&](BaseT) { |
461 | return errorFn("expected argument " + Twine(argIdx) + |
462 | " to be of type: " + llvm::getTypeName<T>()); |
463 | }); |
464 | } |
465 | using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg; |
466 | |
467 | static T processAsArg(BaseT baseValue) { |
468 | return baseValue.template cast<T>(); |
469 | } |
470 | using ProcessPDLValueBasedOn<T, BaseT>::processAsArg; |
471 | |
472 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
473 | T value) { |
474 | results.push_back(value); |
475 | } |
476 | }; |
477 | |
478 | //===----------------------------------------------------------------------===// |
479 | // Attribute |
480 | |
481 | template <> |
482 | struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {}; |
483 | template <typename T> |
484 | struct ProcessPDLValue<T, |
485 | std::enable_if_t<std::is_base_of<Attribute, T>::value>> |
486 | : public ProcessDerivedPDLValue<T, Attribute> {}; |
487 | |
488 | /// Handling for various Attribute value types. |
489 | template <> |
490 | struct ProcessPDLValue<StringRef> |
491 | : public ProcessPDLValueBasedOn<StringRef, StringAttr> { |
492 | static StringRef processAsArg(StringAttr value) { return value.getValue(); } |
493 | using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg; |
494 | |
495 | static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, |
496 | StringRef value) { |
497 | results.push_back(rewriter.getStringAttr(value)); |
498 | } |
499 | }; |
500 | template <> |
501 | struct ProcessPDLValue<std::string> |
502 | : public ProcessPDLValueBasedOn<std::string, StringAttr> { |
503 | template <typename T> |
504 | static std::string processAsArg(T value) { |
505 | static_assert(always_false<T>, |
506 | "`std::string` arguments require a string copy, use " |
507 | "`StringRef` for string-like arguments instead" ); |
508 | return {}; |
509 | } |
510 | static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, |
511 | StringRef value) { |
512 | results.push_back(rewriter.getStringAttr(value)); |
513 | } |
514 | }; |
515 | |
516 | //===----------------------------------------------------------------------===// |
517 | // Operation |
518 | |
519 | template <> |
520 | struct ProcessPDLValue<Operation *> |
521 | : public ProcessBuiltinPDLValue<Operation *> {}; |
522 | template <typename T> |
523 | struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>> |
524 | : public ProcessDerivedPDLValue<T, Operation *> { |
525 | static T processAsArg(Operation *value) { return cast<T>(value); } |
526 | }; |
527 | |
528 | //===----------------------------------------------------------------------===// |
529 | // Type |
530 | |
531 | template <> |
532 | struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {}; |
533 | template <typename T> |
534 | struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>> |
535 | : public ProcessDerivedPDLValue<T, Type> {}; |
536 | |
537 | //===----------------------------------------------------------------------===// |
538 | // TypeRange |
539 | |
540 | template <> |
541 | struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {}; |
542 | template <> |
543 | struct ProcessPDLValue<ValueTypeRange<OperandRange>> { |
544 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
545 | ValueTypeRange<OperandRange> types) { |
546 | results.push_back(value: types); |
547 | } |
548 | }; |
549 | template <> |
550 | struct ProcessPDLValue<ValueTypeRange<ResultRange>> { |
551 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
552 | ValueTypeRange<ResultRange> types) { |
553 | results.push_back(value: types); |
554 | } |
555 | }; |
556 | template <unsigned N> |
557 | struct ProcessPDLValue<SmallVector<Type, N>> { |
558 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
559 | SmallVector<Type, N> values) { |
560 | results.push_back(value: TypeRange(values)); |
561 | } |
562 | }; |
563 | |
564 | //===----------------------------------------------------------------------===// |
565 | // Value |
566 | |
567 | template <> |
568 | struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {}; |
569 | |
570 | //===----------------------------------------------------------------------===// |
571 | // ValueRange |
572 | |
573 | template <> |
574 | struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> { |
575 | }; |
576 | template <> |
577 | struct ProcessPDLValue<OperandRange> { |
578 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
579 | OperandRange values) { |
580 | results.push_back(value: values); |
581 | } |
582 | }; |
583 | template <> |
584 | struct ProcessPDLValue<ResultRange> { |
585 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
586 | ResultRange values) { |
587 | results.push_back(value: values); |
588 | } |
589 | }; |
590 | template <unsigned N> |
591 | struct ProcessPDLValue<SmallVector<Value, N>> { |
592 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
593 | SmallVector<Value, N> values) { |
594 | results.push_back(value: ValueRange(values)); |
595 | } |
596 | }; |
597 | |
598 | //===----------------------------------------------------------------------===// |
599 | // PDL Function Builder: Argument Handling |
600 | //===----------------------------------------------------------------------===// |
601 | |
602 | /// Validate the given PDLValues match the constraints defined by the argument |
603 | /// types of the given function. In the case of failure, a match failure |
604 | /// diagnostic is emitted. |
605 | /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL |
606 | /// does not currently preserve Constraint application ordering. |
607 | template <typename PDLFnT, std::size_t... I> |
608 | LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, |
609 | std::index_sequence<I...>) { |
610 | using FnTraitsT = llvm::function_traits<PDLFnT>; |
611 | |
612 | auto errorFn = [&](const Twine &msg) { |
613 | return rewriter.notifyMatchFailure(arg: rewriter.getUnknownLoc(), msg); |
614 | }; |
615 | return success( |
616 | (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: |
617 | verifyAsArg(errorFn, values[I], I)) && |
618 | ...)); |
619 | } |
620 | |
621 | /// Assert that the given PDLValues match the constraints defined by the |
622 | /// arguments of the given function. In the case of failure, a fatal error |
623 | /// is emitted. |
624 | template <typename PDLFnT, std::size_t... I> |
625 | void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, |
626 | std::index_sequence<I...>) { |
627 | // We only want to do verification in debug builds, same as with `assert`. |
628 | #if LLVM_ENABLE_ABI_BREAKING_CHECKS |
629 | using FnTraitsT = llvm::function_traits<PDLFnT>; |
630 | auto errorFn = [&](const Twine &msg) -> LogicalResult { |
631 | llvm::report_fatal_error(reason: msg); |
632 | }; |
633 | (void)errorFn; |
634 | assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: |
635 | verifyAsArg(errorFn, values[I], I)) && |
636 | ...)); |
637 | #endif |
638 | (void)values; |
639 | } |
640 | |
641 | //===----------------------------------------------------------------------===// |
642 | // PDL Function Builder: Results Handling |
643 | //===----------------------------------------------------------------------===// |
644 | |
645 | /// Store a single result within the result list. |
646 | template <typename T> |
647 | static LogicalResult processResults(PatternRewriter &rewriter, |
648 | PDLResultList &results, T &&value) { |
649 | ProcessPDLValue<T>::processAsResult(rewriter, results, |
650 | std::forward<T>(value)); |
651 | return success(); |
652 | } |
653 | |
654 | /// Store a std::pair<> as individual results within the result list. |
655 | template <typename T1, typename T2> |
656 | static LogicalResult processResults(PatternRewriter &rewriter, |
657 | PDLResultList &results, |
658 | std::pair<T1, T2> &&pair) { |
659 | if (failed(processResults(rewriter, results, std::move(pair.first))) || |
660 | failed(processResults(rewriter, results, std::move(pair.second)))) |
661 | return failure(); |
662 | return success(); |
663 | } |
664 | |
665 | /// Store a std::tuple<> as individual results within the result list. |
666 | template <typename... Ts> |
667 | static LogicalResult processResults(PatternRewriter &rewriter, |
668 | PDLResultList &results, |
669 | std::tuple<Ts...> &&tuple) { |
670 | auto applyFn = [&](auto &&...args) { |
671 | return (succeeded(processResults(rewriter, results, std::move(args))) && |
672 | ...); |
673 | }; |
674 | return success(std::apply(applyFn, std::move(tuple))); |
675 | } |
676 | |
677 | /// Handle LogicalResult propagation. |
678 | inline LogicalResult processResults(PatternRewriter &rewriter, |
679 | PDLResultList &results, |
680 | LogicalResult &&result) { |
681 | return result; |
682 | } |
683 | template <typename T> |
684 | static LogicalResult processResults(PatternRewriter &rewriter, |
685 | PDLResultList &results, |
686 | FailureOr<T> &&result) { |
687 | if (failed(result)) |
688 | return failure(); |
689 | return processResults(rewriter, results, std::move(*result)); |
690 | } |
691 | |
692 | //===----------------------------------------------------------------------===// |
693 | // PDL Constraint Builder |
694 | //===----------------------------------------------------------------------===// |
695 | |
696 | /// Process the arguments of a native constraint and invoke it. |
697 | template <typename PDLFnT, std::size_t... I, |
698 | typename FnTraitsT = llvm::function_traits<PDLFnT>> |
699 | typename FnTraitsT::result_t |
700 | processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, |
701 | ArrayRef<PDLValue> values, |
702 | std::index_sequence<I...>) { |
703 | return fn( |
704 | rewriter, |
705 | (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( |
706 | values[I]))...); |
707 | } |
708 | |
709 | /// Build a constraint function from the given function `ConstraintFnT`. This |
710 | /// allows for enabling the user to define simpler, more direct constraint |
711 | /// functions without needing to handle the low-level PDL goop. |
712 | /// |
713 | /// If the constraint function is already in the correct form, we just forward |
714 | /// it directly. |
715 | template <typename ConstraintFnT> |
716 | std::enable_if_t< |
717 | std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, |
718 | PDLConstraintFunction> |
719 | buildConstraintFn(ConstraintFnT &&constraintFn) { |
720 | return std::forward<ConstraintFnT>(constraintFn); |
721 | } |
722 | /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form |
723 | /// we desire. |
724 | template <typename ConstraintFnT> |
725 | std::enable_if_t< |
726 | !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, |
727 | PDLConstraintFunction> |
728 | buildConstraintFn(ConstraintFnT &&constraintFn) { |
729 | return [constraintFn = std::forward<ConstraintFnT>(constraintFn)]( |
730 | PatternRewriter &rewriter, PDLResultList &, |
731 | ArrayRef<PDLValue> values) -> LogicalResult { |
732 | auto argIndices = std::make_index_sequence< |
733 | llvm::function_traits<ConstraintFnT>::num_args - 1>(); |
734 | if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices))) |
735 | return failure(); |
736 | return processArgsAndInvokeConstraint(constraintFn, rewriter, values, |
737 | argIndices); |
738 | }; |
739 | } |
740 | |
741 | //===----------------------------------------------------------------------===// |
742 | // PDL Rewrite Builder |
743 | //===----------------------------------------------------------------------===// |
744 | |
745 | /// Process the arguments of a native rewrite and invoke it. |
746 | /// This overload handles the case of no return values. |
747 | template <typename PDLFnT, std::size_t... I, |
748 | typename FnTraitsT = llvm::function_traits<PDLFnT>> |
749 | std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value, |
750 | LogicalResult> |
751 | processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, |
752 | PDLResultList &, ArrayRef<PDLValue> values, |
753 | std::index_sequence<I...>) { |
754 | fn(rewriter, |
755 | (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( |
756 | values[I]))...); |
757 | return success(); |
758 | } |
759 | /// This overload handles the case of return values, which need to be packaged |
760 | /// into the result list. |
761 | template <typename PDLFnT, std::size_t... I, |
762 | typename FnTraitsT = llvm::function_traits<PDLFnT>> |
763 | std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value, |
764 | LogicalResult> |
765 | processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, |
766 | PDLResultList &results, ArrayRef<PDLValue> values, |
767 | std::index_sequence<I...>) { |
768 | return processResults( |
769 | rewriter, results, |
770 | fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: |
771 | processAsArg(values[I]))...)); |
772 | (void)values; |
773 | } |
774 | |
775 | /// Build a rewrite function from the given function `RewriteFnT`. This |
776 | /// allows for enabling the user to define simpler, more direct rewrite |
777 | /// functions without needing to handle the low-level PDL goop. |
778 | /// |
779 | /// If the rewrite function is already in the correct form, we just forward |
780 | /// it directly. |
781 | template <typename RewriteFnT> |
782 | std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, |
783 | PDLRewriteFunction> |
784 | buildRewriteFn(RewriteFnT &&rewriteFn) { |
785 | return std::forward<RewriteFnT>(rewriteFn); |
786 | } |
787 | /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form |
788 | /// we desire. |
789 | template <typename RewriteFnT> |
790 | std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, |
791 | PDLRewriteFunction> |
792 | buildRewriteFn(RewriteFnT &&rewriteFn) { |
793 | return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)]( |
794 | PatternRewriter &rewriter, PDLResultList &results, |
795 | ArrayRef<PDLValue> values) { |
796 | auto argIndices = |
797 | std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args - |
798 | 1>(); |
799 | assertArgs<RewriteFnT>(rewriter, values, argIndices); |
800 | return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, |
801 | argIndices); |
802 | }; |
803 | } |
804 | |
805 | } // namespace pdl_function_builder |
806 | } // namespace detail |
807 | |
808 | //===----------------------------------------------------------------------===// |
809 | // PDLPatternModule |
810 | |
811 | /// This class contains all of the necessary data for a set of PDL patterns, or |
812 | /// pattern rewrites specified in the form of the PDL dialect. This PDL module |
813 | /// contained by this pattern may contain any number of `pdl.pattern` |
814 | /// operations. |
815 | class PDLPatternModule { |
816 | public: |
817 | PDLPatternModule() = default; |
818 | |
819 | /// Construct a PDL pattern with the given module and configurations. |
820 | PDLPatternModule(OwningOpRef<ModuleOp> module) |
821 | : pdlModule(std::move(module)) {} |
822 | template <typename... ConfigsT> |
823 | PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs) |
824 | : PDLPatternModule(std::move(module)) { |
825 | auto configSet = std::make_unique<PDLPatternConfigSet>( |
826 | std::forward<ConfigsT>(patternConfigs)...); |
827 | attachConfigToPatterns(*pdlModule, *configSet); |
828 | configs.emplace_back(std::move(configSet)); |
829 | } |
830 | |
831 | /// Merge the state in `other` into this pattern module. |
832 | void mergeIn(PDLPatternModule &&other); |
833 | |
834 | /// Return the internal PDL module of this pattern. |
835 | ModuleOp getModule() { return pdlModule.get(); } |
836 | |
837 | /// Return the MLIR context of this pattern. |
838 | MLIRContext *getContext() { return getModule()->getContext(); } |
839 | |
840 | //===--------------------------------------------------------------------===// |
841 | // Function Registry |
842 | |
843 | /// Register a constraint function with PDL. A constraint function may be |
844 | /// specified in one of two ways: |
845 | /// |
846 | /// * `LogicalResult (PatternRewriter &, |
847 | /// PDLResultList &, |
848 | /// ArrayRef<PDLValue>)` |
849 | /// |
850 | /// In this overload the arguments of the constraint function are passed via |
851 | /// the low-level PDLValue form, and the results are manually appended to |
852 | /// the given result list. |
853 | /// |
854 | /// * `LogicalResult (PatternRewriter &, ValueTs... values)` |
855 | /// |
856 | /// In this form the arguments of the constraint function are passed via the |
857 | /// expected high level C++ type. In this form, the framework will |
858 | /// automatically unwrap PDLValues and convert them to the expected ValueTs. |
859 | /// For example, if the constraint function accepts a `Operation *`, the |
860 | /// framework will automatically cast the input PDLValue. In the case of a |
861 | /// `StringRef`, the framework will automatically unwrap the argument as a |
862 | /// StringAttr and pass the underlying string value. To see the full list of |
863 | /// supported types, or to see how to add handling for custom types, view |
864 | /// the definition of `ProcessPDLValue` above. |
865 | void registerConstraintFunction(StringRef name, |
866 | PDLConstraintFunction constraintFn); |
867 | template <typename ConstraintFnT> |
868 | void registerConstraintFunction(StringRef name, |
869 | ConstraintFnT &&constraintFn) { |
870 | registerConstraintFunction(name, |
871 | detail::pdl_function_builder::buildConstraintFn( |
872 | std::forward<ConstraintFnT>(constraintFn))); |
873 | } |
874 | |
875 | /// Register a rewrite function with PDL. A rewrite function may be specified |
876 | /// in one of two ways: |
877 | /// |
878 | /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)` |
879 | /// |
880 | /// In this overload the arguments of the constraint function are passed via |
881 | /// the low-level PDLValue form, and the results are manually appended to |
882 | /// the given result list. |
883 | /// |
884 | /// * `ResultT (PatternRewriter &, ValueTs... values)` |
885 | /// |
886 | /// In this form the arguments and result of the rewrite function are passed |
887 | /// via the expected high level C++ type. In this form, the framework will |
888 | /// automatically unwrap the PDLValues arguments and convert them to the |
889 | /// expected ValueTs. It will also automatically handle the processing and |
890 | /// packaging of the result value to the result list. For example, if the |
891 | /// rewrite function takes a `Operation *`, the framework will automatically |
892 | /// cast the input PDLValue. In the case of a `StringRef`, the framework |
893 | /// will automatically unwrap the argument as a StringAttr and pass the |
894 | /// underlying string value. In the reverse case, if the rewrite returns a |
895 | /// StringRef or std::string, it will automatically package this as a |
896 | /// StringAttr and append it to the result list. To see the full list of |
897 | /// supported types, or to see how to add handling for custom types, view |
898 | /// the definition of `ProcessPDLValue` above. |
899 | void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); |
900 | template <typename RewriteFnT> |
901 | void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) { |
902 | registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( |
903 | std::forward<RewriteFnT>(rewriteFn))); |
904 | } |
905 | |
906 | /// Return the set of the registered constraint functions. |
907 | const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { |
908 | return constraintFunctions; |
909 | } |
910 | llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() { |
911 | return constraintFunctions; |
912 | } |
913 | /// Return the set of the registered rewrite functions. |
914 | const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const { |
915 | return rewriteFunctions; |
916 | } |
917 | llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() { |
918 | return rewriteFunctions; |
919 | } |
920 | |
921 | /// Return the set of the registered pattern configs. |
922 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() { |
923 | return std::move(configs); |
924 | } |
925 | DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() { |
926 | return std::move(configMap); |
927 | } |
928 | |
929 | /// Clear out the patterns and functions within this module. |
930 | void clear() { |
931 | pdlModule = nullptr; |
932 | constraintFunctions.clear(); |
933 | rewriteFunctions.clear(); |
934 | } |
935 | |
936 | private: |
937 | /// Attach the given pattern config set to the patterns defined within the |
938 | /// given module. |
939 | void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet); |
940 | |
941 | /// The module containing the `pdl.pattern` operations. |
942 | OwningOpRef<ModuleOp> pdlModule; |
943 | |
944 | /// The set of configuration sets referenced by patterns within `pdlModule`. |
945 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs; |
946 | DenseMap<Operation *, PDLPatternConfigSet *> configMap; |
947 | |
948 | /// The external functions referenced from within the PDL module. |
949 | llvm::StringMap<PDLConstraintFunction> constraintFunctions; |
950 | llvm::StringMap<PDLRewriteFunction> rewriteFunctions; |
951 | }; |
952 | } // namespace mlir |
953 | |
954 | #else |
955 | |
956 | namespace mlir { |
957 | // Stubs for when PDL in pattern rewrites is not enabled. |
958 | |
/// Placeholder `PDLValue` used when PDL support in pattern rewrites is
/// disabled. It holds nothing, so every cast attempt yields a null result.
class PDLValue {
public:
  template <typename T>
  T dyn_cast() const { return nullptr; }
};
966 | class PDLResultList {}; |
967 | using PDLConstraintFunction = std::function<LogicalResult( |
968 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
969 | using PDLRewriteFunction = std::function<LogicalResult( |
970 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
971 | |
972 | class PDLPatternModule { |
973 | public: |
974 | PDLPatternModule() = default; |
975 | |
976 | PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {} |
977 | MLIRContext *getContext() { |
978 | llvm_unreachable("Error: PDL for rewrites when PDL is not enabled" ); |
979 | } |
980 | void mergeIn(PDLPatternModule &&other) {} |
981 | void clear() {} |
982 | template <typename ConstraintFnT> |
983 | void registerConstraintFunction(StringRef name, |
984 | ConstraintFnT &&constraintFn) {} |
985 | void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {} |
986 | template <typename RewriteFnT> |
987 | void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {} |
988 | const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { |
989 | return constraintFunctions; |
990 | } |
991 | |
992 | private: |
993 | llvm::StringMap<PDLConstraintFunction> constraintFunctions; |
994 | }; |
995 | |
996 | } // namespace mlir |
997 | #endif |
998 | |
999 | #endif // MLIR_IR_PDLPATTERNMATCH_H |
1000 | |