| 1 | //===- RegistryManager.cpp - Matcher registry -----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Registry map populated at static initialization time. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "RegistryManager.h" |
| 14 | #include "mlir/Query/Matcher/Registry.h" |
| 15 | |
| 16 | #include <set> |
| 17 | #include <utility> |
| 18 | |
| 19 | namespace mlir::query::matcher { |
| 20 | namespace { |
| 21 | |
| 22 | // Enum to string for autocomplete. |
| 23 | static std::string asArgString(ArgKind kind) { |
| 24 | switch (kind) { |
| 25 | case ArgKind::Boolean: |
| 26 | return "Boolean" ; |
| 27 | case ArgKind::Matcher: |
| 28 | return "Matcher" ; |
| 29 | case ArgKind::Signed: |
| 30 | return "Signed" ; |
| 31 | case ArgKind::String: |
| 32 | return "String" ; |
| 33 | } |
| 34 | llvm_unreachable("Unhandled ArgKind" ); |
| 35 | } |
| 36 | |
| 37 | } // namespace |
| 38 | |
| 39 | void Registry::registerMatcherDescriptor( |
| 40 | llvm::StringRef matcherName, |
| 41 | std::unique_ptr<internal::MatcherDescriptor> callback) { |
| 42 | assert(!constructorMap.contains(matcherName)); |
| 43 | constructorMap[matcherName] = std::move(callback); |
| 44 | } |
| 45 | |
| 46 | std::optional<MatcherCtor> |
| 47 | RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName, |
| 48 | const Registry &matcherRegistry) { |
| 49 | auto it = matcherRegistry.constructors().find(Key: matcherName); |
| 50 | return it == matcherRegistry.constructors().end() |
| 51 | ? std::optional<MatcherCtor>() |
| 52 | : it->second.get(); |
| 53 | } |
| 54 | |
| 55 | std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes( |
| 56 | llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) { |
| 57 | // Starting with the above seed of acceptable top-level matcher types, compute |
| 58 | // the acceptable type set for the argument indicated by each context element. |
| 59 | std::set<ArgKind> typeSet; |
| 60 | typeSet.insert(x: ArgKind::Matcher); |
| 61 | |
| 62 | for (const auto &ctxEntry : context) { |
| 63 | MatcherCtor ctor = ctxEntry.first; |
| 64 | unsigned argNumber = ctxEntry.second; |
| 65 | std::vector<ArgKind> nextTypeSet; |
| 66 | |
| 67 | if (argNumber < ctor->getNumArgs()) |
| 68 | ctor->getArgKinds(argNo: argNumber, argKinds&: nextTypeSet); |
| 69 | |
| 70 | typeSet.insert(first: nextTypeSet.begin(), last: nextTypeSet.end()); |
| 71 | } |
| 72 | |
| 73 | return std::vector<ArgKind>(typeSet.begin(), typeSet.end()); |
| 74 | } |
| 75 | |
| 76 | std::vector<MatcherCompletion> |
| 77 | RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, |
| 78 | const Registry &matcherRegistry) { |
| 79 | std::vector<MatcherCompletion> completions; |
| 80 | |
| 81 | // Search the registry for acceptable matchers. |
| 82 | for (const auto &m : matcherRegistry.constructors()) { |
| 83 | const internal::MatcherDescriptor &matcher = *m.getValue(); |
| 84 | llvm::StringRef name = m.getKey(); |
| 85 | |
| 86 | unsigned numArgs = matcher.getNumArgs(); |
| 87 | std::vector<std::vector<ArgKind>> argKinds(numArgs); |
| 88 | |
| 89 | for (const ArgKind &kind : acceptedTypes) { |
| 90 | if (kind != ArgKind::Matcher) |
| 91 | continue; |
| 92 | |
| 93 | for (unsigned arg = 0; arg != numArgs; ++arg) |
| 94 | matcher.getArgKinds(argNo: arg, argKinds&: argKinds[arg]); |
| 95 | } |
| 96 | |
| 97 | std::string decl; |
| 98 | llvm::raw_string_ostream os(decl); |
| 99 | |
| 100 | std::string typedText = std::string(name); |
| 101 | os << "Matcher: " << name << "(" ; |
| 102 | |
| 103 | for (const std::vector<ArgKind> &arg : argKinds) { |
| 104 | if (&arg != &argKinds[0]) |
| 105 | os << ", " ; |
| 106 | |
| 107 | bool firstArgKind = true; |
| 108 | // Two steps. First all non-matchers, then matchers only. |
| 109 | for (const ArgKind &argKind : arg) { |
| 110 | if (!firstArgKind) |
| 111 | os << "|" ; |
| 112 | |
| 113 | firstArgKind = false; |
| 114 | os << asArgString(kind: argKind); |
| 115 | } |
| 116 | } |
| 117 | |
| 118 | os << ")" ; |
| 119 | typedText += "(" ; |
| 120 | |
| 121 | if (argKinds.empty()) |
| 122 | typedText += ")" ; |
| 123 | else if (argKinds[0][0] == ArgKind::String) |
| 124 | typedText += "\"" ; |
| 125 | |
| 126 | completions.emplace_back(args&: typedText, args&: decl); |
| 127 | } |
| 128 | |
| 129 | return completions; |
| 130 | } |
| 131 | |
| 132 | VariantMatcher RegistryManager::constructMatcher( |
| 133 | MatcherCtor ctor, internal::SourceRange nameRange, |
| 134 | llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args, |
| 135 | internal::Diagnostics *error) { |
| 136 | VariantMatcher out = ctor->create(nameRange, args, error); |
| 137 | if (functionName.empty() || out.isNull()) |
| 138 | return out; |
| 139 | |
| 140 | if (std::optional<DynMatcher> result = out.getDynMatcher()) { |
| 141 | result->setFunctionName(functionName); |
| 142 | return VariantMatcher::SingleMatcher(matcher: *result); |
| 143 | } |
| 144 | |
| 145 | error->addError(range: nameRange, error: internal::ErrorType::RegistryNotBindable); |
| 146 | return {}; |
| 147 | } |
| 148 | |
| 149 | } // namespace mlir::query::matcher |
| 150 | |