1 | //===- FrozenRewritePatternSet.h --------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_REWRITE_FROZENREWRITEPATTERNSET_H |
10 | #define MLIR_REWRITE_FROZENREWRITEPATTERNSET_H |
11 | |
12 | #include "mlir/IR/PatternMatch.h" |
13 | |
14 | namespace mlir { |
namespace detail {
/// Forward declaration of the compiled PDL bytecode. This header only refers
/// to it through pointers (`const PDLByteCode *` and
/// `std::unique_ptr<PDLByteCode>`), so the complete type is not needed here.
class PDLByteCode;
} // namespace detail
18 | |
19 | /// This class represents a frozen set of patterns that can be processed by a |
20 | /// pattern applicator. This class is designed to enable caching pattern lists |
21 | /// such that they need not be continuously recomputed. Note that all copies of |
22 | /// this class share the same compiled pattern list, allowing for a reduction in |
23 | /// the number of duplicated patterns that need to be created. |
24 | class FrozenRewritePatternSet { |
25 | using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; |
26 | |
27 | public: |
28 | /// A map of operation specific native patterns. |
29 | using OpSpecificNativePatternListT = |
30 | DenseMap<OperationName, std::vector<RewritePattern *>>; |
31 | |
32 | FrozenRewritePatternSet(); |
33 | FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default; |
34 | FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default; |
35 | FrozenRewritePatternSet & |
36 | operator=(const FrozenRewritePatternSet &patterns) = default; |
37 | FrozenRewritePatternSet & |
38 | operator=(FrozenRewritePatternSet &&patterns) = default; |
39 | ~FrozenRewritePatternSet(); |
40 | |
41 | /// Freeze the patterns held in `patterns`, and take ownership. |
42 | /// `disabledPatternLabels` is a set of labels used to filter out input |
43 | /// patterns with a debug label or debug name in this set. |
44 | /// `enabledPatternLabels` is a set of labels used to filter out input |
45 | /// patterns that do not have one of the labels in this set. Debug labels must |
46 | /// be set explicitly on patterns or when adding them with |
47 | /// `RewritePatternSet::addWithLabel`. Debug names may be empty, but patterns |
48 | /// created with `RewritePattern::create` have their default debug name set to |
49 | /// their type name. |
50 | FrozenRewritePatternSet( |
51 | RewritePatternSet &&patterns, |
52 | ArrayRef<std::string> disabledPatternLabels = std::nullopt, |
53 | ArrayRef<std::string> enabledPatternLabels = std::nullopt); |
54 | |
55 | /// Return the op specific native patterns held by this list. |
56 | const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const { |
57 | return impl->nativeOpSpecificPatternMap; |
58 | } |
59 | |
60 | /// Return the "match any" native patterns held by this list. |
61 | iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>> |
62 | getMatchAnyOpNativePatterns() const { |
63 | const NativePatternListT &nativeList = impl->nativeAnyOpPatterns; |
64 | return llvm::make_pointee_range(Range: nativeList); |
65 | } |
66 | |
67 | /// Return the compiled PDL bytecode held by this list. Returns null if |
68 | /// there are no PDL patterns within the list. |
69 | const detail::PDLByteCode *getPDLByteCode() const { |
70 | return impl->pdlByteCode.get(); |
71 | } |
72 | |
73 | private: |
74 | /// The internal implementation of the frozen pattern list. |
75 | struct Impl { |
76 | /// The set of native C++ rewrite patterns that are matched to specific |
77 | /// operation kinds. |
78 | OpSpecificNativePatternListT nativeOpSpecificPatternMap; |
79 | |
80 | /// The full op-specific native rewrite list. This allows for the map above |
81 | /// to contain duplicate patterns, e.g. for interfaces and traits. |
82 | NativePatternListT nativeOpSpecificPatternList; |
83 | |
84 | /// The set of native C++ rewrite patterns that are matched to "any" |
85 | /// operation. |
86 | NativePatternListT nativeAnyOpPatterns; |
87 | |
88 | /// The bytecode containing the compiled PDL patterns. |
89 | std::unique_ptr<detail::PDLByteCode> pdlByteCode; |
90 | }; |
91 | |
92 | /// A pointer to the internal pattern list. This uses a shared_ptr to avoid |
93 | /// the need to compile the same pattern list multiple times. For example, |
94 | /// during multi-threaded pass execution, all copies of a pass can share the |
95 | /// same pattern list. |
96 | std::shared_ptr<Impl> impl; |
97 | }; |
98 | |
99 | } // namespace mlir |
100 | |
101 | #endif // MLIR_REWRITE_FROZENREWRITEPATTERNSET_H |
102 | |