//===- PDLPatternMatch.cpp - Base classes for PDL pattern match -----------===//
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #include "mlir/IR/IRMapping.h" |
11 | #include "mlir/IR/Iterators.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | #include "mlir/IR/RegionKindInterface.h" |
14 | |
15 | using namespace mlir; |
16 | |
17 | //===----------------------------------------------------------------------===// |
18 | // PDLValue |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | void PDLValue::print(raw_ostream &os) const { |
22 | if (!value) { |
23 | os << "<NULL-PDLValue>" ; |
24 | return; |
25 | } |
26 | switch (kind) { |
27 | case Kind::Attribute: |
28 | os << cast<Attribute>(); |
29 | break; |
30 | case Kind::Operation: |
31 | os << *cast<Operation *>(); |
32 | break; |
33 | case Kind::Type: |
34 | os << cast<Type>(); |
35 | break; |
36 | case Kind::TypeRange: |
37 | llvm::interleaveComma(c: cast<TypeRange>(), os); |
38 | break; |
39 | case Kind::Value: |
40 | os << cast<Value>(); |
41 | break; |
42 | case Kind::ValueRange: |
43 | llvm::interleaveComma(c: cast<ValueRange>(), os); |
44 | break; |
45 | } |
46 | } |
47 | |
48 | void PDLValue::print(raw_ostream &os, Kind kind) { |
49 | switch (kind) { |
50 | case Kind::Attribute: |
51 | os << "Attribute" ; |
52 | break; |
53 | case Kind::Operation: |
54 | os << "Operation" ; |
55 | break; |
56 | case Kind::Type: |
57 | os << "Type" ; |
58 | break; |
59 | case Kind::TypeRange: |
60 | os << "TypeRange" ; |
61 | break; |
62 | case Kind::Value: |
63 | os << "Value" ; |
64 | break; |
65 | case Kind::ValueRange: |
66 | os << "ValueRange" ; |
67 | break; |
68 | } |
69 | } |
70 | |
71 | //===----------------------------------------------------------------------===// |
72 | // PDLPatternModule |
73 | //===----------------------------------------------------------------------===// |
74 | |
75 | void PDLPatternModule::mergeIn(PDLPatternModule &&other) { |
76 | // Ignore the other module if it has no patterns. |
77 | if (!other.pdlModule) |
78 | return; |
79 | |
80 | // Steal the functions and config of the other module. |
81 | for (auto &it : other.constraintFunctions) |
82 | registerConstraintFunction(name: it.first(), constraintFn: std::move(it.second)); |
83 | for (auto &it : other.rewriteFunctions) |
84 | registerRewriteFunction(name: it.first(), rewriteFn: std::move(it.second)); |
85 | for (auto &it : other.configs) |
86 | configs.emplace_back(Args: std::move(it)); |
87 | for (auto &it : other.configMap) |
88 | configMap.insert(KV: it); |
89 | |
90 | // Steal the other state if we have no patterns. |
91 | if (!pdlModule) { |
92 | pdlModule = std::move(other.pdlModule); |
93 | return; |
94 | } |
95 | |
96 | // Merge the pattern operations from the other module into this one. |
97 | Block *block = pdlModule->getBody(); |
98 | block->getOperations().splice(block->end(), |
99 | other.pdlModule->getBody()->getOperations()); |
100 | } |
101 | |
102 | void PDLPatternModule::attachConfigToPatterns(ModuleOp module, |
103 | PDLPatternConfigSet &configSet) { |
104 | // Attach the configuration to the symbols within the module. We only add |
105 | // to symbols to avoid hardcoding any specific operation names here (given |
106 | // that we don't depend on any PDL dialect). We can't use |
107 | // cast<SymbolOpInterface> here because patterns may be optional symbols. |
108 | module->walk([&](Operation *op) { |
109 | if (op->hasTrait<SymbolOpInterface::Trait>()) |
110 | configMap[op] = &configSet; |
111 | }); |
112 | } |
113 | |
//===----------------------------------------------------------------------===//
// Function Registry
//===----------------------------------------------------------------------===//
116 | |
117 | void PDLPatternModule::registerConstraintFunction( |
118 | StringRef name, PDLConstraintFunction constraintFn) { |
119 | // TODO: Is it possible to diagnose when `name` is already registered to |
120 | // a function that is not equivalent to `constraintFn`? |
121 | // Allow existing mappings in the case multiple patterns depend on the same |
122 | // constraint. |
123 | constraintFunctions.try_emplace(Key: name, Args: std::move(constraintFn)); |
124 | } |
125 | |
126 | void PDLPatternModule::registerRewriteFunction(StringRef name, |
127 | PDLRewriteFunction rewriteFn) { |
128 | // TODO: Is it possible to diagnose when `name` is already registered to |
129 | // a function that is not equivalent to `rewriteFn`? |
130 | // Allow existing mappings in the case multiple patterns depend on the same |
131 | // rewrite. |
132 | rewriteFunctions.try_emplace(Key: name, Args: std::move(rewriteFn)); |
133 | } |
134 | |