1//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PDLPATTERNMATCH_H
10#define MLIR_IR_PDLPATTERNMATCH_H
11
12#include "mlir/Config/mlir-config.h"
13
14#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/BuiltinOps.h"
17
18namespace mlir {
19//===----------------------------------------------------------------------===//
20// PDL Patterns
21//===----------------------------------------------------------------------===//
22
23//===----------------------------------------------------------------------===//
24// PDLValue
25
/// Storage type of byte-code interpreter values. These are passed as arguments
/// to native constraint and rewrite functions.
28class PDLValue {
29public:
30 /// The underlying kind of a PDL value.
31 enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
32
33 /// Construct a new PDL value.
34 PDLValue(const PDLValue &other) = default;
35 PDLValue(std::nullptr_t = nullptr) {}
36 PDLValue(Attribute value)
37 : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
38 PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
39 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
40 PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
41 PDLValue(Value value)
42 : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
43 PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
44
45 /// Returns true if the type of the held value is `T`.
46 template <typename T>
47 bool isa() const {
48 assert(value && "isa<> used on a null value");
49 return kind == getKindOf<T>();
50 }
51
52 /// Attempt to dynamically cast this value to type `T`, returns null if this
53 /// value is not an instance of `T`.
54 template <typename T,
55 typename ResultT = std::conditional_t<
56 std::is_convertible<T, bool>::value, T, std::optional<T>>>
57 ResultT dyn_cast() const {
58 return isa<T>() ? castImpl<T>() : ResultT();
59 }
60
61 /// Cast this value to type `T`, asserts if this value is not an instance of
62 /// `T`.
63 template <typename T>
64 T cast() const {
65 assert(isa<T>() && "expected value to be of type `T`");
66 return castImpl<T>();
67 }
68
69 /// Get an opaque pointer to the value.
70 const void *getAsOpaquePointer() const { return value; }
71
72 /// Return if this value is null or not.
73 explicit operator bool() const { return value; }
74
75 /// Return the kind of this value.
76 Kind getKind() const { return kind; }
77
78 /// Print this value to the provided output stream.
79 void print(raw_ostream &os) const;
80
81 /// Print the specified value kind to an output stream.
82 static void print(raw_ostream &os, Kind kind);
83
84private:
85 /// Find the index of a given type in a range of other types.
86 template <typename...>
87 struct index_of_t;
88 template <typename T, typename... R>
89 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
90 template <typename T, typename F, typename... R>
91 struct index_of_t<T, F, R...>
92 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
93
94 /// Return the kind used for the given T.
95 template <typename T>
96 static Kind getKindOf() {
97 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
98 TypeRange, Value, ValueRange>::value);
99 }
100
101 /// The internal implementation of `cast`, that returns the underlying value
102 /// as the given type `T`.
103 template <typename T>
104 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
105 castImpl() const {
106 return T::getFromOpaquePointer(value);
107 }
108 template <typename T>
109 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
110 castImpl() const {
111 return *reinterpret_cast<T *>(const_cast<void *>(value));
112 }
113 template <typename T>
114 std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
115 return reinterpret_cast<T>(const_cast<void *>(value));
116 }
117
118 /// The internal opaque representation of a PDLValue.
119 const void *value{nullptr};
120 /// The kind of the opaque value.
121 Kind kind{Kind::Attribute};
122};
123
124inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
125 value.print(os);
126 return os;
127}
128
129inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
130 PDLValue::print(os, kind);
131 return os;
132}
133
134//===----------------------------------------------------------------------===//
135// PDLResultList
136
/// This class represents a list of PDL results returned by a native rewrite
/// method. It provides the mechanism with which to pass PDLValues back to the
/// PDL bytecode.
140class PDLResultList {
141public:
142 /// Push a new Attribute value onto the result list.
143 void push_back(Attribute value) { results.push_back(Elt: value); }
144
145 /// Push a new Operation onto the result list.
146 void push_back(Operation *value) { results.push_back(Elt: value); }
147
148 /// Push a new Type onto the result list.
149 void push_back(Type value) { results.push_back(Elt: value); }
150
151 /// Push a new TypeRange onto the result list.
152 void push_back(TypeRange value) {
153 // The lifetime of a TypeRange can't be guaranteed, so we'll need to
154 // allocate a storage for it.
155 llvm::OwningArrayRef<Type> storage(value.size());
156 llvm::copy(Range&: value, Out: storage.begin());
157 allocatedTypeRanges.emplace_back(Args: std::move(storage));
158 typeRanges.push_back(Elt: allocatedTypeRanges.back());
159 results.push_back(Elt: &typeRanges.back());
160 }
161 void push_back(ValueTypeRange<OperandRange> value) {
162 typeRanges.push_back(Elt: value);
163 results.push_back(Elt: &typeRanges.back());
164 }
165 void push_back(ValueTypeRange<ResultRange> value) {
166 typeRanges.push_back(Elt: value);
167 results.push_back(Elt: &typeRanges.back());
168 }
169
170 /// Push a new Value onto the result list.
171 void push_back(Value value) { results.push_back(Elt: value); }
172
173 /// Push a new ValueRange onto the result list.
174 void push_back(ValueRange value) {
175 // The lifetime of a ValueRange can't be guaranteed, so we'll need to
176 // allocate a storage for it.
177 llvm::OwningArrayRef<Value> storage(value.size());
178 llvm::copy(Range&: value, Out: storage.begin());
179 allocatedValueRanges.emplace_back(Args: std::move(storage));
180 valueRanges.push_back(Elt: allocatedValueRanges.back());
181 results.push_back(Elt: &valueRanges.back());
182 }
183 void push_back(OperandRange value) {
184 valueRanges.push_back(Elt: value);
185 results.push_back(Elt: &valueRanges.back());
186 }
187 void push_back(ResultRange value) {
188 valueRanges.push_back(Elt: value);
189 results.push_back(Elt: &valueRanges.back());
190 }
191
192protected:
193 /// Create a new result list with the expected number of results.
194 PDLResultList(unsigned maxNumResults) {
195 // For now just reserve enough space for all of the results. We could do
196 // separate counts per range type, but it isn't really worth it unless there
197 // are a "large" number of results.
198 typeRanges.reserve(N: maxNumResults);
199 valueRanges.reserve(N: maxNumResults);
200 }
201
202 /// The PDL results held by this list.
203 SmallVector<PDLValue> results;
204 /// Memory used to store ranges held by the list.
205 SmallVector<TypeRange> typeRanges;
206 SmallVector<ValueRange> valueRanges;
207 /// Memory allocated to store ranges in the result list whose lifetime was
208 /// generated in the native function.
209 SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
210 SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
211};
212
213//===----------------------------------------------------------------------===//
214// PDLPatternConfig
215
216/// An individual configuration for a pattern, which can be accessed by native
217/// functions via the PDLPatternConfigSet. This allows for injecting additional
218/// configuration into PDL patterns that is specific to certain compilation
219/// flows.
220class PDLPatternConfig {
221public:
222 virtual ~PDLPatternConfig() = default;
223
224 /// Hooks that are invoked at the beginning and end of a rewrite of a matched
225 /// pattern. These can be used to setup any specific state necessary for the
226 /// rewrite.
227 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
228 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
229
230 /// Return the TypeID that represents this configuration.
231 TypeID getTypeID() const { return id; }
232
233protected:
234 PDLPatternConfig(TypeID id) : id(id) {}
235
236private:
237 TypeID id;
238};
239
240/// This class provides a base class for users implementing a type of pattern
241/// configuration.
242template <typename T>
243class PDLPatternConfigBase : public PDLPatternConfig {
244public:
245 /// Support LLVM style casting.
246 static bool classof(const PDLPatternConfig *config) {
247 return config->getTypeID() == getConfigID();
248 }
249
250 /// Return the type id used for this configuration.
251 static TypeID getConfigID() { return TypeID::get<T>(); }
252
253protected:
254 PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
255};
256
257/// This class contains a set of configurations for a specific pattern.
258/// Configurations are uniqued by TypeID, meaning that only one configuration of
259/// each type is allowed.
260class PDLPatternConfigSet {
261public:
262 PDLPatternConfigSet() = default;
263
264 /// Construct a set with the given configurations.
265 template <typename... ConfigsT>
266 PDLPatternConfigSet(ConfigsT &&...configs) {
267 (addConfig(std::forward<ConfigsT>(configs)), ...);
268 }
269
270 /// Get the configuration defined by the given type. Asserts that the
271 /// configuration of the provided type exists.
272 template <typename T>
273 const T &get() const {
274 const T *config = tryGet<T>();
275 assert(config && "configuration not found");
276 return *config;
277 }
278
279 /// Get the configuration defined by the given type, returns nullptr if the
280 /// configuration does not exist.
281 template <typename T>
282 const T *tryGet() const {
283 for (const auto &configIt : configs)
284 if (const T *config = dyn_cast<T>(configIt.get()))
285 return config;
286 return nullptr;
287 }
288
289 /// Notify the configurations within this set at the beginning or end of a
290 /// rewrite of a matched pattern.
291 void notifyRewriteBegin(PatternRewriter &rewriter) {
292 for (const auto &config : configs)
293 config->notifyRewriteBegin(rewriter);
294 }
295 void notifyRewriteEnd(PatternRewriter &rewriter) {
296 for (const auto &config : configs)
297 config->notifyRewriteEnd(rewriter);
298 }
299
300protected:
301 /// Add a configuration to the set.
302 template <typename T>
303 void addConfig(T &&config) {
304 assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
305 configs.emplace_back(
306 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
307 }
308
309 /// The set of configurations for this pattern. This uses a vector instead of
310 /// a map with the expectation that the number of configurations per set is
311 /// small (<= 1).
312 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
313};
314
315//===----------------------------------------------------------------------===//
316// PDLPatternModule
317
318/// A generic PDL pattern constraint function. This function applies a
319/// constraint to a given set of opaque PDLValue entities. Returns success if
320/// the constraint successfully held, failure otherwise.
321using PDLConstraintFunction = std::function<LogicalResult(
322 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
323
324/// A native PDL rewrite function. This function performs a rewrite on the
325/// given set of values. Any results from this rewrite that should be passed
326/// back to PDL should be added to the provided result list. This method is only
327/// invoked when the corresponding match was successful. Returns failure if an
328/// invariant of the rewrite was broken (certain rewriters may recover from
329/// partial pattern application).
330using PDLRewriteFunction = std::function<LogicalResult(
331 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
332
333namespace detail {
334namespace pdl_function_builder {
335/// A utility variable that always resolves to false. This is useful for static
336/// asserts that are always false, but only should fire in certain templated
337/// constructs. For example, if a templated function should never be called, the
338/// function could be defined as:
339///
340/// template <typename T>
341/// void foo() {
342/// static_assert(always_false<T>, "This function should never be called");
343/// }
344///
// `inline` (C++17) gives this header-defined variable template a single
// definition shared across translation units.
template <class... T>
inline constexpr bool always_false = false;
347
348//===----------------------------------------------------------------------===//
349// PDL Function Builder: Type Processing
350//===----------------------------------------------------------------------===//
351
352/// This struct provides a convenient way to determine how to process a given
353/// type as either a PDL parameter, or a result value. This allows for
354/// supporting complex types in constraint and rewrite functions, without
355/// requiring the user to hand-write the necessary glue code themselves.
356/// Specializations of this class should implement the following methods to
357/// enable support as a PDL argument or result type:
358///
359/// static LogicalResult verifyAsArg(
360/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
361/// size_t argIdx);
362///
363/// * This method verifies that the given PDLValue is valid for use as a
364/// value of `T`.
365///
366/// static T processAsArg(PDLValue pdlValue);
367///
368/// * This method processes the given PDLValue as a value of `T`.
369///
370/// static void processAsResult(PatternRewriter &, PDLResultList &results,
371/// const T &value);
372///
373/// * This method processes the given value of `T` as the result of a
374/// function invocation. The method should package the value into an
375/// appropriate form and append it to the given result list.
376///
377/// If the type `T` is based on a higher order value, consider using
378/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
379/// the implementation.
380///
// Primary template: intentionally left undefined so that using an unsupported
// type fails at compile time. `Enable` is an SFINAE hook for specializations.
template <class T, class Enable = void>
struct ProcessPDLValue;
383
384/// This struct provides a simplified model for processing types that are based
385/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
386/// allows for building the necessary processing functions on top of the base
387/// value instead of a PDLValue. Derived users should implement the following
388/// (which subsume the ProcessPDLValue variants):
389///
390/// static LogicalResult verifyAsArg(
391/// function_ref<LogicalResult(const Twine &)> errorFn,
392/// const BaseT &baseValue, size_t argIdx);
393///
///   * This method verifies that the given base value is valid for use as a
///     value of `T`.
396///
397/// static T processAsArg(BaseT baseValue);
398///
399/// * This method processes the given base value as a value of `T`.
400///
401template <typename T, typename BaseT>
402struct ProcessPDLValueBasedOn {
403 static LogicalResult
404 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
405 PDLValue pdlValue, size_t argIdx) {
406 // Verify the base class before continuing.
407 if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
408 return failure();
409 return ProcessPDLValue<T>::verifyAsArg(
410 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
411 }
412 static T processAsArg(PDLValue pdlValue) {
413 return ProcessPDLValue<T>::processAsArg(
414 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
415 }
416
417 /// Explicitly add the expected parent API to ensure the parent class
418 /// implements the necessary API (and doesn't implicitly inherit it from
419 /// somewhere else).
420 static LogicalResult
421 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
422 size_t argIdx) {
423 return success();
424 }
425 static T processAsArg(BaseT baseValue);
426};
427
/// This struct provides a simplified model for processing types that have
/// "builtin" PDLValue support:
///   * Attribute, Operation *, Type, TypeRange, Value, ValueRange
431template <typename T>
432struct ProcessBuiltinPDLValue {
433 static LogicalResult
434 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
435 PDLValue pdlValue, size_t argIdx) {
436 if (pdlValue)
437 return success();
438 return errorFn("expected a non-null value for argument " + Twine(argIdx) +
439 " of type: " + llvm::getTypeName<T>());
440 }
441
442 static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
443 static void processAsResult(PatternRewriter &, PDLResultList &results,
444 T value) {
445 results.push_back(value);
446 }
447};
448
449/// This struct provides a simplified model for processing types that inherit
450/// from builtin PDLValue types. For example, derived attributes like
451/// IntegerAttr, derived types like IntegerType, derived operations like
452/// ModuleOp, Interfaces, etc.
453template <typename T, typename BaseT>
454struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
455 static LogicalResult
456 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
457 BaseT baseValue, size_t argIdx) {
458 return TypeSwitch<BaseT, LogicalResult>(baseValue)
459 .Case([&](T) { return success(); })
460 .Default([&](BaseT) {
461 return errorFn("expected argument " + Twine(argIdx) +
462 " to be of type: " + llvm::getTypeName<T>());
463 });
464 }
465 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
466
467 static T processAsArg(BaseT baseValue) {
468 return baseValue.template cast<T>();
469 }
470 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
471
472 static void processAsResult(PatternRewriter &, PDLResultList &results,
473 T value) {
474 results.push_back(value);
475 }
476};
477
478//===----------------------------------------------------------------------===//
479// Attribute
480
481template <>
482struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
483template <typename T>
484struct ProcessPDLValue<T,
485 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
486 : public ProcessDerivedPDLValue<T, Attribute> {};
487
488/// Handling for various Attribute value types.
489template <>
490struct ProcessPDLValue<StringRef>
491 : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
492 static StringRef processAsArg(StringAttr value) { return value.getValue(); }
493 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
494
495 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
496 StringRef value) {
497 results.push_back(rewriter.getStringAttr(value));
498 }
499};
500template <>
501struct ProcessPDLValue<std::string>
502 : public ProcessPDLValueBasedOn<std::string, StringAttr> {
503 template <typename T>
504 static std::string processAsArg(T value) {
505 static_assert(always_false<T>,
506 "`std::string` arguments require a string copy, use "
507 "`StringRef` for string-like arguments instead");
508 return {};
509 }
510 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
511 StringRef value) {
512 results.push_back(rewriter.getStringAttr(value));
513 }
514};
515
516//===----------------------------------------------------------------------===//
517// Operation
518
519template <>
520struct ProcessPDLValue<Operation *>
521 : public ProcessBuiltinPDLValue<Operation *> {};
522template <typename T>
523struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
524 : public ProcessDerivedPDLValue<T, Operation *> {
525 static T processAsArg(Operation *value) { return cast<T>(value); }
526};
527
528//===----------------------------------------------------------------------===//
529// Type
530
531template <>
532struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
533template <typename T>
534struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
535 : public ProcessDerivedPDLValue<T, Type> {};
536
537//===----------------------------------------------------------------------===//
538// TypeRange
539
540template <>
541struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
542template <>
543struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
544 static void processAsResult(PatternRewriter &, PDLResultList &results,
545 ValueTypeRange<OperandRange> types) {
546 results.push_back(value: types);
547 }
548};
549template <>
550struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
551 static void processAsResult(PatternRewriter &, PDLResultList &results,
552 ValueTypeRange<ResultRange> types) {
553 results.push_back(value: types);
554 }
555};
556template <unsigned N>
557struct ProcessPDLValue<SmallVector<Type, N>> {
558 static void processAsResult(PatternRewriter &, PDLResultList &results,
559 SmallVector<Type, N> values) {
560 results.push_back(value: TypeRange(values));
561 }
562};
563
564//===----------------------------------------------------------------------===//
565// Value
566
567template <>
568struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
569
570//===----------------------------------------------------------------------===//
571// ValueRange
572
573template <>
574struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
575};
576template <>
577struct ProcessPDLValue<OperandRange> {
578 static void processAsResult(PatternRewriter &, PDLResultList &results,
579 OperandRange values) {
580 results.push_back(value: values);
581 }
582};
583template <>
584struct ProcessPDLValue<ResultRange> {
585 static void processAsResult(PatternRewriter &, PDLResultList &results,
586 ResultRange values) {
587 results.push_back(value: values);
588 }
589};
590template <unsigned N>
591struct ProcessPDLValue<SmallVector<Value, N>> {
592 static void processAsResult(PatternRewriter &, PDLResultList &results,
593 SmallVector<Value, N> values) {
594 results.push_back(value: ValueRange(values));
595 }
596};
597
598//===----------------------------------------------------------------------===//
599// PDL Function Builder: Argument Handling
600//===----------------------------------------------------------------------===//
601
602/// Validate the given PDLValues match the constraints defined by the argument
603/// types of the given function. In the case of failure, a match failure
604/// diagnostic is emitted.
605/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
606/// does not currently preserve Constraint application ordering.
607template <typename PDLFnT, std::size_t... I>
608LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
609 std::index_sequence<I...>) {
610 using FnTraitsT = llvm::function_traits<PDLFnT>;
611
612 auto errorFn = [&](const Twine &msg) {
613 return rewriter.notifyMatchFailure(arg: rewriter.getUnknownLoc(), msg);
614 };
615 return success(
616 (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
617 verifyAsArg(errorFn, values[I], I)) &&
618 ...));
619}
620
621/// Assert that the given PDLValues match the constraints defined by the
622/// arguments of the given function. In the case of failure, a fatal error
623/// is emitted.
624template <typename PDLFnT, std::size_t... I>
625void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
626 std::index_sequence<I...>) {
627 // We only want to do verification in debug builds, same as with `assert`.
628#if LLVM_ENABLE_ABI_BREAKING_CHECKS
629 using FnTraitsT = llvm::function_traits<PDLFnT>;
630 auto errorFn = [&](const Twine &msg) -> LogicalResult {
631 llvm::report_fatal_error(reason: msg);
632 };
633 (void)errorFn;
634 assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
635 verifyAsArg(errorFn, values[I], I)) &&
636 ...));
637#endif
638 (void)values;
639}
640
641//===----------------------------------------------------------------------===//
642// PDL Function Builder: Results Handling
643//===----------------------------------------------------------------------===//
644
645/// Store a single result within the result list.
646template <typename T>
647static LogicalResult processResults(PatternRewriter &rewriter,
648 PDLResultList &results, T &&value) {
649 ProcessPDLValue<T>::processAsResult(rewriter, results,
650 std::forward<T>(value));
651 return success();
652}
653
654/// Store a std::pair<> as individual results within the result list.
655template <typename T1, typename T2>
656static LogicalResult processResults(PatternRewriter &rewriter,
657 PDLResultList &results,
658 std::pair<T1, T2> &&pair) {
659 if (failed(processResults(rewriter, results, std::move(pair.first))) ||
660 failed(processResults(rewriter, results, std::move(pair.second))))
661 return failure();
662 return success();
663}
664
665/// Store a std::tuple<> as individual results within the result list.
666template <typename... Ts>
667static LogicalResult processResults(PatternRewriter &rewriter,
668 PDLResultList &results,
669 std::tuple<Ts...> &&tuple) {
670 auto applyFn = [&](auto &&...args) {
671 return (succeeded(processResults(rewriter, results, std::move(args))) &&
672 ...);
673 };
674 return success(std::apply(applyFn, std::move(tuple)));
675}
676
677/// Handle LogicalResult propagation.
678inline LogicalResult processResults(PatternRewriter &rewriter,
679 PDLResultList &results,
680 LogicalResult &&result) {
681 return result;
682}
683template <typename T>
684static LogicalResult processResults(PatternRewriter &rewriter,
685 PDLResultList &results,
686 FailureOr<T> &&result) {
687 if (failed(result))
688 return failure();
689 return processResults(rewriter, results, std::move(*result));
690}
691
692//===----------------------------------------------------------------------===//
693// PDL Constraint Builder
694//===----------------------------------------------------------------------===//
695
696/// Process the arguments of a native constraint and invoke it.
697template <typename PDLFnT, std::size_t... I,
698 typename FnTraitsT = llvm::function_traits<PDLFnT>>
699typename FnTraitsT::result_t
700processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
701 ArrayRef<PDLValue> values,
702 std::index_sequence<I...>) {
703 return fn(
704 rewriter,
705 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
706 values[I]))...);
707}
708
709/// Build a constraint function from the given function `ConstraintFnT`. This
710/// allows for enabling the user to define simpler, more direct constraint
711/// functions without needing to handle the low-level PDL goop.
712///
713/// If the constraint function is already in the correct form, we just forward
714/// it directly.
715template <typename ConstraintFnT>
716std::enable_if_t<
717 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
718 PDLConstraintFunction>
719buildConstraintFn(ConstraintFnT &&constraintFn) {
720 return std::forward<ConstraintFnT>(constraintFn);
721}
722/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
723/// we desire.
724template <typename ConstraintFnT>
725std::enable_if_t<
726 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
727 PDLConstraintFunction>
728buildConstraintFn(ConstraintFnT &&constraintFn) {
729 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
730 PatternRewriter &rewriter, PDLResultList &,
731 ArrayRef<PDLValue> values) -> LogicalResult {
732 auto argIndices = std::make_index_sequence<
733 llvm::function_traits<ConstraintFnT>::num_args - 1>();
734 if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
735 return failure();
736 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
737 argIndices);
738 };
739}
740
741//===----------------------------------------------------------------------===//
742// PDL Rewrite Builder
743//===----------------------------------------------------------------------===//
744
745/// Process the arguments of a native rewrite and invoke it.
746/// This overload handles the case of no return values.
747template <typename PDLFnT, std::size_t... I,
748 typename FnTraitsT = llvm::function_traits<PDLFnT>>
749std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
750 LogicalResult>
751processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
752 PDLResultList &, ArrayRef<PDLValue> values,
753 std::index_sequence<I...>) {
754 fn(rewriter,
755 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
756 values[I]))...);
757 return success();
758}
759/// This overload handles the case of return values, which need to be packaged
760/// into the result list.
761template <typename PDLFnT, std::size_t... I,
762 typename FnTraitsT = llvm::function_traits<PDLFnT>>
763std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
764 LogicalResult>
765processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
766 PDLResultList &results, ArrayRef<PDLValue> values,
767 std::index_sequence<I...>) {
768 return processResults(
769 rewriter, results,
770 fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
771 processAsArg(values[I]))...));
772 (void)values;
773}
774
775/// Build a rewrite function from the given function `RewriteFnT`. This
776/// allows for enabling the user to define simpler, more direct rewrite
777/// functions without needing to handle the low-level PDL goop.
778///
779/// If the rewrite function is already in the correct form, we just forward
780/// it directly.
781template <typename RewriteFnT>
782std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
783 PDLRewriteFunction>
784buildRewriteFn(RewriteFnT &&rewriteFn) {
785 return std::forward<RewriteFnT>(rewriteFn);
786}
787/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
788/// we desire.
789template <typename RewriteFnT>
790std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
791 PDLRewriteFunction>
792buildRewriteFn(RewriteFnT &&rewriteFn) {
793 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
794 PatternRewriter &rewriter, PDLResultList &results,
795 ArrayRef<PDLValue> values) {
796 auto argIndices =
797 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
798 1>();
799 assertArgs<RewriteFnT>(rewriter, values, argIndices);
800 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
801 argIndices);
802 };
803}
804
805} // namespace pdl_function_builder
806} // namespace detail
807
808//===----------------------------------------------------------------------===//
809// PDLPatternModule
810
/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. The PDL module
/// contained by this pattern may contain any number of `pdl.pattern`
/// operations.
815class PDLPatternModule {
816public:
817 PDLPatternModule() = default;
818
819 /// Construct a PDL pattern with the given module and configurations.
820 PDLPatternModule(OwningOpRef<ModuleOp> module)
821 : pdlModule(std::move(module)) {}
822 template <typename... ConfigsT>
823 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
824 : PDLPatternModule(std::move(module)) {
825 auto configSet = std::make_unique<PDLPatternConfigSet>(
826 std::forward<ConfigsT>(patternConfigs)...);
827 attachConfigToPatterns(*pdlModule, *configSet);
828 configs.emplace_back(std::move(configSet));
829 }
830
831 /// Merge the state in `other` into this pattern module.
832 void mergeIn(PDLPatternModule &&other);
833
834 /// Return the internal PDL module of this pattern.
835 ModuleOp getModule() { return pdlModule.get(); }
836
837 /// Return the MLIR context of this pattern.
838 MLIRContext *getContext() { return getModule()->getContext(); }
839
840 //===--------------------------------------------------------------------===//
841 // Function Registry
842
843 /// Register a constraint function with PDL. A constraint function may be
844 /// specified in one of two ways:
845 ///
846 /// * `LogicalResult (PatternRewriter &,
847 /// PDLResultList &,
848 /// ArrayRef<PDLValue>)`
849 ///
850 /// In this overload the arguments of the constraint function are passed via
851 /// the low-level PDLValue form, and the results are manually appended to
852 /// the given result list.
853 ///
854 /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
855 ///
856 /// In this form the arguments of the constraint function are passed via the
857 /// expected high level C++ type. In this form, the framework will
858 /// automatically unwrap PDLValues and convert them to the expected ValueTs.
859 /// For example, if the constraint function accepts a `Operation *`, the
860 /// framework will automatically cast the input PDLValue. In the case of a
861 /// `StringRef`, the framework will automatically unwrap the argument as a
862 /// StringAttr and pass the underlying string value. To see the full list of
863 /// supported types, or to see how to add handling for custom types, view
864 /// the definition of `ProcessPDLValue` above.
865 void registerConstraintFunction(StringRef name,
866 PDLConstraintFunction constraintFn);
867 template <typename ConstraintFnT>
868 void registerConstraintFunction(StringRef name,
869 ConstraintFnT &&constraintFn) {
870 registerConstraintFunction(name,
871 detail::pdl_function_builder::buildConstraintFn(
872 std::forward<ConstraintFnT>(constraintFn)));
873 }
874
875 /// Register a rewrite function with PDL. A rewrite function may be specified
876 /// in one of two ways:
877 ///
878 /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
879 ///
880 /// In this overload the arguments of the constraint function are passed via
881 /// the low-level PDLValue form, and the results are manually appended to
882 /// the given result list.
883 ///
884 /// * `ResultT (PatternRewriter &, ValueTs... values)`
885 ///
886 /// In this form the arguments and result of the rewrite function are passed
887 /// via the expected high level C++ type. In this form, the framework will
888 /// automatically unwrap the PDLValues arguments and convert them to the
889 /// expected ValueTs. It will also automatically handle the processing and
890 /// packaging of the result value to the result list. For example, if the
891 /// rewrite function takes a `Operation *`, the framework will automatically
892 /// cast the input PDLValue. In the case of a `StringRef`, the framework
893 /// will automatically unwrap the argument as a StringAttr and pass the
894 /// underlying string value. In the reverse case, if the rewrite returns a
895 /// StringRef or std::string, it will automatically package this as a
896 /// StringAttr and append it to the result list. To see the full list of
897 /// supported types, or to see how to add handling for custom types, view
898 /// the definition of `ProcessPDLValue` above.
899 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
900 template <typename RewriteFnT>
901 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
902 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
903 std::forward<RewriteFnT>(rewriteFn)));
904 }
905
906 /// Return the set of the registered constraint functions.
907 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
908 return constraintFunctions;
909 }
910 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
911 return constraintFunctions;
912 }
913 /// Return the set of the registered rewrite functions.
914 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
915 return rewriteFunctions;
916 }
917 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
918 return rewriteFunctions;
919 }
920
921 /// Return the set of the registered pattern configs.
922 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
923 return std::move(configs);
924 }
925 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
926 return std::move(configMap);
927 }
928
929 /// Clear out the patterns and functions within this module.
930 void clear() {
931 pdlModule = nullptr;
932 constraintFunctions.clear();
933 rewriteFunctions.clear();
934 }
935
936private:
937 /// Attach the given pattern config set to the patterns defined within the
938 /// given module.
939 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
940
941 /// The module containing the `pdl.pattern` operations.
942 OwningOpRef<ModuleOp> pdlModule;
943
944 /// The set of configuration sets referenced by patterns within `pdlModule`.
945 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
946 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
947
948 /// The external functions referenced from within the PDL module.
949 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
950 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
951};
952} // namespace mlir
953
954#else
955
956namespace mlir {
957// Stubs for when PDL in pattern rewrites is not enabled.
958
959class PDLValue {
960public:
961 template <typename T>
962 T dyn_cast() const {
963 return nullptr;
964 }
965};
966class PDLResultList {};
967using PDLConstraintFunction = std::function<LogicalResult(
968 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
969using PDLRewriteFunction = std::function<LogicalResult(
970 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
971
972class PDLPatternModule {
973public:
974 PDLPatternModule() = default;
975
976 PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}
977 MLIRContext *getContext() {
978 llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
979 }
980 void mergeIn(PDLPatternModule &&other) {}
981 void clear() {}
982 template <typename ConstraintFnT>
983 void registerConstraintFunction(StringRef name,
984 ConstraintFnT &&constraintFn) {}
985 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
986 template <typename RewriteFnT>
987 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
988 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
989 return constraintFunctions;
990 }
991
992private:
993 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
994};
995
996} // namespace mlir
997#endif
998
999#endif // MLIR_IR_PDLPATTERNMATCH_H
1000

source code of mlir/include/mlir/IR/PDLPatternMatch.h.inc