1 | //===- RegistryManager.cpp - Matcher registry -----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Matcher registry: descriptor registration, lookup by name, code-completion
// support, and construction of matchers from parsed arguments.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "RegistryManager.h" |
14 | #include "mlir/Query/Matcher/Registry.h" |
15 | |
16 | #include <set> |
17 | #include <utility> |
18 | |
19 | namespace mlir::query::matcher { |
20 | namespace { |
21 | |
22 | // Enum to string for autocomplete. |
23 | static std::string asArgString(ArgKind kind) { |
24 | switch (kind) { |
25 | case ArgKind::Boolean: |
26 | return "Boolean" ; |
27 | case ArgKind::Matcher: |
28 | return "Matcher" ; |
29 | case ArgKind::Signed: |
30 | return "Signed" ; |
31 | case ArgKind::String: |
32 | return "String" ; |
33 | } |
34 | llvm_unreachable("Unhandled ArgKind" ); |
35 | } |
36 | |
37 | } // namespace |
38 | |
39 | void Registry::registerMatcherDescriptor( |
40 | llvm::StringRef matcherName, |
41 | std::unique_ptr<internal::MatcherDescriptor> callback) { |
42 | assert(!constructorMap.contains(matcherName)); |
43 | constructorMap[matcherName] = std::move(callback); |
44 | } |
45 | |
46 | std::optional<MatcherCtor> |
47 | RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName, |
48 | const Registry &matcherRegistry) { |
49 | auto it = matcherRegistry.constructors().find(Key: matcherName); |
50 | return it == matcherRegistry.constructors().end() |
51 | ? std::optional<MatcherCtor>() |
52 | : it->second.get(); |
53 | } |
54 | |
55 | std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes( |
56 | llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { |
57 | // Starting with the above seed of acceptable top-level matcher types, compute |
58 | // the acceptable type set for the argument indicated by each context element. |
59 | std::set<ArgKind> typeSet; |
60 | typeSet.insert(x: ArgKind::Matcher); |
61 | |
62 | for (const auto &ctxEntry : context) { |
63 | MatcherCtor ctor = ctxEntry.first; |
64 | unsigned argNumber = ctxEntry.second; |
65 | std::vector<ArgKind> nextTypeSet; |
66 | |
67 | if (argNumber < ctor->getNumArgs()) |
68 | ctor->getArgKinds(argNo: argNumber, argKinds&: nextTypeSet); |
69 | |
70 | typeSet.insert(first: nextTypeSet.begin(), last: nextTypeSet.end()); |
71 | } |
72 | |
73 | return std::vector<ArgKind>(typeSet.begin(), typeSet.end()); |
74 | } |
75 | |
76 | std::vector<MatcherCompletion> |
77 | RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, |
78 | const Registry &matcherRegistry) { |
79 | std::vector<MatcherCompletion> completions; |
80 | |
81 | // Search the registry for acceptable matchers. |
82 | for (const auto &m : matcherRegistry.constructors()) { |
83 | const internal::MatcherDescriptor &matcher = *m.getValue(); |
84 | llvm::StringRef name = m.getKey(); |
85 | |
86 | unsigned numArgs = matcher.getNumArgs(); |
87 | std::vector<std::vector<ArgKind>> argKinds(numArgs); |
88 | |
89 | for (const ArgKind &kind : acceptedTypes) { |
90 | if (kind != ArgKind::Matcher) |
91 | continue; |
92 | |
93 | for (unsigned arg = 0; arg != numArgs; ++arg) |
94 | matcher.getArgKinds(argNo: arg, argKinds&: argKinds[arg]); |
95 | } |
96 | |
97 | std::string decl; |
98 | llvm::raw_string_ostream os(decl); |
99 | |
100 | std::string typedText = std::string(name); |
101 | os << "Matcher: " << name << "(" ; |
102 | |
103 | for (const std::vector<ArgKind> &arg : argKinds) { |
104 | if (&arg != &argKinds[0]) |
105 | os << ", " ; |
106 | |
107 | bool firstArgKind = true; |
108 | // Two steps. First all non-matchers, then matchers only. |
109 | for (const ArgKind &argKind : arg) { |
110 | if (!firstArgKind) |
111 | os << "|" ; |
112 | |
113 | firstArgKind = false; |
114 | os << asArgString(kind: argKind); |
115 | } |
116 | } |
117 | |
118 | os << ")" ; |
119 | typedText += "(" ; |
120 | |
121 | if (argKinds.empty()) |
122 | typedText += ")" ; |
123 | else if (argKinds[0][0] == ArgKind::String) |
124 | typedText += "\"" ; |
125 | |
126 | completions.emplace_back(args&: typedText, args&: decl); |
127 | } |
128 | |
129 | return completions; |
130 | } |
131 | |
132 | VariantMatcher RegistryManager::constructMatcher( |
133 | MatcherCtor ctor, internal::SourceRange nameRange, |
134 | llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args, |
135 | internal::Diagnostics *error) { |
136 | VariantMatcher out = ctor->create(nameRange, args, error); |
137 | if (functionName.empty() || out.isNull()) |
138 | return out; |
139 | |
140 | if (std::optional<DynMatcher> result = out.getDynMatcher()) { |
141 | result->setFunctionName(functionName); |
142 | return VariantMatcher::SingleMatcher(matcher: *result); |
143 | } |
144 | |
145 | error->addError(range: nameRange, error: internal::ErrorType::RegistryNotBindable); |
146 | return {}; |
147 | } |
148 | |
149 | } // namespace mlir::query::matcher |
150 | |