1 | //===- PDLPatternMatch.cpp - Base classes for PDL pattern match |
2 | //------------===// |
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #include "mlir/IR/IRMapping.h" |
11 | #include "mlir/IR/Iterators.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | #include "mlir/IR/RegionKindInterface.h" |
14 | #include "llvm/Support/InterleavedRange.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // PDLValue |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | void PDLValue::print(raw_ostream &os) const { |
23 | if (!value) { |
24 | os << "<NULL-PDLValue>" ; |
25 | return; |
26 | } |
27 | switch (kind) { |
28 | case Kind::Attribute: |
29 | os << cast<Attribute>(); |
30 | break; |
31 | case Kind::Operation: |
32 | os << *cast<Operation *>(); |
33 | break; |
34 | case Kind::Type: |
35 | os << cast<Type>(); |
36 | break; |
37 | case Kind::TypeRange: |
38 | os << llvm::interleaved(R: cast<TypeRange>()); |
39 | break; |
40 | case Kind::Value: |
41 | os << cast<Value>(); |
42 | break; |
43 | case Kind::ValueRange: |
44 | os << llvm::interleaved(R: cast<ValueRange>()); |
45 | break; |
46 | } |
47 | } |
48 | |
49 | void PDLValue::print(raw_ostream &os, Kind kind) { |
50 | switch (kind) { |
51 | case Kind::Attribute: |
52 | os << "Attribute" ; |
53 | break; |
54 | case Kind::Operation: |
55 | os << "Operation" ; |
56 | break; |
57 | case Kind::Type: |
58 | os << "Type" ; |
59 | break; |
60 | case Kind::TypeRange: |
61 | os << "TypeRange" ; |
62 | break; |
63 | case Kind::Value: |
64 | os << "Value" ; |
65 | break; |
66 | case Kind::ValueRange: |
67 | os << "ValueRange" ; |
68 | break; |
69 | } |
70 | } |
71 | |
72 | //===----------------------------------------------------------------------===// |
73 | // PDLPatternModule |
74 | //===----------------------------------------------------------------------===// |
75 | |
76 | void PDLPatternModule::mergeIn(PDLPatternModule &&other) { |
77 | // Ignore the other module if it has no patterns. |
78 | if (!other.pdlModule) |
79 | return; |
80 | |
81 | // Steal the functions and config of the other module. |
82 | for (auto &it : other.constraintFunctions) |
83 | registerConstraintFunction(name: it.first(), constraintFn: std::move(it.second)); |
84 | for (auto &it : other.rewriteFunctions) |
85 | registerRewriteFunction(name: it.first(), rewriteFn: std::move(it.second)); |
86 | for (auto &it : other.configs) |
87 | configs.emplace_back(Args: std::move(it)); |
88 | for (auto &it : other.configMap) |
89 | configMap.insert(KV: it); |
90 | |
91 | // Steal the other state if we have no patterns. |
92 | if (!pdlModule) { |
93 | pdlModule = std::move(other.pdlModule); |
94 | return; |
95 | } |
96 | |
97 | // Merge the pattern operations from the other module into this one. |
98 | Block *block = pdlModule->getBody(); |
99 | block->getOperations().splice(block->end(), |
100 | other.pdlModule->getBody()->getOperations()); |
101 | } |
102 | |
103 | void PDLPatternModule::attachConfigToPatterns(ModuleOp module, |
104 | PDLPatternConfigSet &configSet) { |
105 | // Attach the configuration to the symbols within the module. We only add |
106 | // to symbols to avoid hardcoding any specific operation names here (given |
107 | // that we don't depend on any PDL dialect). We can't use |
108 | // cast<SymbolOpInterface> here because patterns may be optional symbols. |
109 | module->walk([&](Operation *op) { |
110 | if (op->hasTrait<SymbolOpInterface::Trait>()) |
111 | configMap[op] = &configSet; |
112 | }); |
113 | } |
114 | |
115 | //===----------------------------------------------------------------------===// |
116 | // Function Registry |
117 | //===----------------------------------------------------------------------===// |
118 | |
119 | void PDLPatternModule::registerConstraintFunction( |
120 | StringRef name, PDLConstraintFunction constraintFn) { |
121 | // TODO: Is it possible to diagnose when `name` is already registered to |
122 | // a function that is not equivalent to `constraintFn`? |
123 | // Allow existing mappings in the case multiple patterns depend on the same |
124 | // constraint. |
125 | constraintFunctions.try_emplace(Key: name, Args: std::move(constraintFn)); |
126 | } |
127 | |
128 | void PDLPatternModule::registerRewriteFunction(StringRef name, |
129 | PDLRewriteFunction rewriteFn) { |
130 | // TODO: Is it possible to diagnose when `name` is already registered to |
131 | // a function that is not equivalent to `rewriteFn`? |
132 | // Allow existing mappings in the case multiple patterns depend on the same |
133 | // rewrite. |
134 | rewriteFunctions.try_emplace(Key: name, Args: std::move(rewriteFn)); |
135 | } |
136 | |