1 | //===- RegistryManager.cpp - Matcher registry -----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Matcher registry: registration, lookup, autocompletion and construction of
// query matchers.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "RegistryManager.h" |
14 | #include "mlir/Query/Matcher/Registry.h" |
15 | |
16 | #include <set> |
17 | #include <utility> |
18 | |
19 | namespace mlir::query::matcher { |
20 | namespace { |
21 | |
22 | // This is needed because these matchers are defined as overloaded functions. |
23 | using IsConstantOp = detail::constant_op_matcher(); |
24 | using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef); |
25 | using HasOpName = detail::NameOpMatcher(llvm::StringRef); |
26 | |
27 | // Enum to string for autocomplete. |
28 | static std::string asArgString(ArgKind kind) { |
29 | switch (kind) { |
30 | case ArgKind::Matcher: |
31 | return "Matcher" ; |
32 | case ArgKind::String: |
33 | return "String" ; |
34 | } |
35 | llvm_unreachable("Unhandled ArgKind" ); |
36 | } |
37 | |
38 | } // namespace |
39 | |
40 | void Registry::registerMatcherDescriptor( |
41 | llvm::StringRef matcherName, |
42 | std::unique_ptr<internal::MatcherDescriptor> callback) { |
43 | assert(!constructorMap.contains(matcherName)); |
44 | constructorMap[matcherName] = std::move(callback); |
45 | } |
46 | |
47 | std::optional<MatcherCtor> |
48 | RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName, |
49 | const Registry &matcherRegistry) { |
50 | auto it = matcherRegistry.constructors().find(Key: matcherName); |
51 | return it == matcherRegistry.constructors().end() |
52 | ? std::optional<MatcherCtor>() |
53 | : it->second.get(); |
54 | } |
55 | |
56 | std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes( |
57 | llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { |
58 | // Starting with the above seed of acceptable top-level matcher types, compute |
59 | // the acceptable type set for the argument indicated by each context element. |
60 | std::set<ArgKind> typeSet; |
61 | typeSet.insert(x: ArgKind::Matcher); |
62 | |
63 | for (const auto &ctxEntry : context) { |
64 | MatcherCtor ctor = ctxEntry.first; |
65 | unsigned argNumber = ctxEntry.second; |
66 | std::vector<ArgKind> nextTypeSet; |
67 | |
68 | if (argNumber < ctor->getNumArgs()) |
69 | ctor->getArgKinds(argNo: argNumber, argKinds&: nextTypeSet); |
70 | |
71 | typeSet.insert(first: nextTypeSet.begin(), last: nextTypeSet.end()); |
72 | } |
73 | |
74 | return std::vector<ArgKind>(typeSet.begin(), typeSet.end()); |
75 | } |
76 | |
77 | std::vector<MatcherCompletion> |
78 | RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, |
79 | const Registry &matcherRegistry) { |
80 | std::vector<MatcherCompletion> completions; |
81 | |
82 | // Search the registry for acceptable matchers. |
83 | for (const auto &m : matcherRegistry.constructors()) { |
84 | const internal::MatcherDescriptor &matcher = *m.getValue(); |
85 | llvm::StringRef name = m.getKey(); |
86 | |
87 | unsigned numArgs = matcher.getNumArgs(); |
88 | std::vector<std::vector<ArgKind>> argKinds(numArgs); |
89 | |
90 | for (const ArgKind &kind : acceptedTypes) { |
91 | if (kind != ArgKind::Matcher) |
92 | continue; |
93 | |
94 | for (unsigned arg = 0; arg != numArgs; ++arg) |
95 | matcher.getArgKinds(argNo: arg, argKinds&: argKinds[arg]); |
96 | } |
97 | |
98 | std::string decl; |
99 | llvm::raw_string_ostream os(decl); |
100 | |
101 | std::string typedText = std::string(name); |
102 | os << "Matcher: " << name << "(" ; |
103 | |
104 | for (const std::vector<ArgKind> &arg : argKinds) { |
105 | if (&arg != &argKinds[0]) |
106 | os << ", " ; |
107 | |
108 | bool firstArgKind = true; |
109 | // Two steps. First all non-matchers, then matchers only. |
110 | for (const ArgKind &argKind : arg) { |
111 | if (!firstArgKind) |
112 | os << "|" ; |
113 | |
114 | firstArgKind = false; |
115 | os << asArgString(kind: argKind); |
116 | } |
117 | } |
118 | |
119 | os << ")" ; |
120 | typedText += "(" ; |
121 | |
122 | if (argKinds.empty()) |
123 | typedText += ")" ; |
124 | else if (argKinds[0][0] == ArgKind::String) |
125 | typedText += "\"" ; |
126 | |
127 | completions.emplace_back(args&: typedText, args&: os.str()); |
128 | } |
129 | |
130 | return completions; |
131 | } |
132 | |
133 | VariantMatcher RegistryManager::constructMatcher( |
134 | MatcherCtor ctor, internal::SourceRange nameRange, |
135 | llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args, |
136 | internal::Diagnostics *error) { |
137 | VariantMatcher out = ctor->create(nameRange, args, error); |
138 | if (functionName.empty() || out.isNull()) |
139 | return out; |
140 | |
141 | if (std::optional<DynMatcher> result = out.getDynMatcher()) { |
142 | result->setFunctionName(functionName); |
143 | return VariantMatcher::SingleMatcher(matcher: *result); |
144 | } |
145 | |
146 | error->addError(range: nameRange, error: internal::ErrorType::RegistryNotBindable); |
147 | return {}; |
148 | } |
149 | |
150 | } // namespace mlir::query::matcher |
151 | |