1//===- RegistryManager.cpp - Matcher registry -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Registry map populated at static initialization time.
10//
11//===----------------------------------------------------------------------===//
12
13#include "RegistryManager.h"
14#include "mlir/Query/Matcher/Registry.h"
15
16#include <set>
17#include <utility>
18
19namespace mlir::query::matcher {
20namespace {
21
22// This is needed because these matchers are defined as overloaded functions.
23using IsConstantOp = detail::constant_op_matcher();
24using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
25using HasOpName = detail::NameOpMatcher(llvm::StringRef);
26
27// Enum to string for autocomplete.
28static std::string asArgString(ArgKind kind) {
29 switch (kind) {
30 case ArgKind::Matcher:
31 return "Matcher";
32 case ArgKind::String:
33 return "String";
34 }
35 llvm_unreachable("Unhandled ArgKind");
36}
37
38} // namespace
39
40void Registry::registerMatcherDescriptor(
41 llvm::StringRef matcherName,
42 std::unique_ptr<internal::MatcherDescriptor> callback) {
43 assert(!constructorMap.contains(matcherName));
44 constructorMap[matcherName] = std::move(callback);
45}
46
47std::optional<MatcherCtor>
48RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
49 const Registry &matcherRegistry) {
50 auto it = matcherRegistry.constructors().find(Key: matcherName);
51 return it == matcherRegistry.constructors().end()
52 ? std::optional<MatcherCtor>()
53 : it->second.get();
54}
55
56std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
57 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
58 // Starting with the above seed of acceptable top-level matcher types, compute
59 // the acceptable type set for the argument indicated by each context element.
60 std::set<ArgKind> typeSet;
61 typeSet.insert(x: ArgKind::Matcher);
62
63 for (const auto &ctxEntry : context) {
64 MatcherCtor ctor = ctxEntry.first;
65 unsigned argNumber = ctxEntry.second;
66 std::vector<ArgKind> nextTypeSet;
67
68 if (argNumber < ctor->getNumArgs())
69 ctor->getArgKinds(argNo: argNumber, argKinds&: nextTypeSet);
70
71 typeSet.insert(first: nextTypeSet.begin(), last: nextTypeSet.end());
72 }
73
74 return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
75}
76
77std::vector<MatcherCompletion>
78RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
79 const Registry &matcherRegistry) {
80 std::vector<MatcherCompletion> completions;
81
82 // Search the registry for acceptable matchers.
83 for (const auto &m : matcherRegistry.constructors()) {
84 const internal::MatcherDescriptor &matcher = *m.getValue();
85 llvm::StringRef name = m.getKey();
86
87 unsigned numArgs = matcher.getNumArgs();
88 std::vector<std::vector<ArgKind>> argKinds(numArgs);
89
90 for (const ArgKind &kind : acceptedTypes) {
91 if (kind != ArgKind::Matcher)
92 continue;
93
94 for (unsigned arg = 0; arg != numArgs; ++arg)
95 matcher.getArgKinds(argNo: arg, argKinds&: argKinds[arg]);
96 }
97
98 std::string decl;
99 llvm::raw_string_ostream os(decl);
100
101 std::string typedText = std::string(name);
102 os << "Matcher: " << name << "(";
103
104 for (const std::vector<ArgKind> &arg : argKinds) {
105 if (&arg != &argKinds[0])
106 os << ", ";
107
108 bool firstArgKind = true;
109 // Two steps. First all non-matchers, then matchers only.
110 for (const ArgKind &argKind : arg) {
111 if (!firstArgKind)
112 os << "|";
113
114 firstArgKind = false;
115 os << asArgString(kind: argKind);
116 }
117 }
118
119 os << ")";
120 typedText += "(";
121
122 if (argKinds.empty())
123 typedText += ")";
124 else if (argKinds[0][0] == ArgKind::String)
125 typedText += "\"";
126
127 completions.emplace_back(args&: typedText, args&: os.str());
128 }
129
130 return completions;
131}
132
133VariantMatcher RegistryManager::constructMatcher(
134 MatcherCtor ctor, internal::SourceRange nameRange,
135 llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136 internal::Diagnostics *error) {
137 VariantMatcher out = ctor->create(nameRange, args, error);
138 if (functionName.empty() || out.isNull())
139 return out;
140
141 if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142 result->setFunctionName(functionName);
143 return VariantMatcher::SingleMatcher(matcher: *result);
144 }
145
146 error->addError(range: nameRange, error: internal::ErrorType::RegistryNotBindable);
147 return {};
148}
149
150} // namespace mlir::query::matcher
151

source code of mlir/lib/Query/Matcher/RegistryManager.cpp