1 | //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines helpers used in the op generators. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "OpGenHelpers.h" |
14 | #include "llvm/ADT/StringSet.h" |
15 | #include "llvm/Support/CommandLine.h" |
16 | #include "llvm/Support/FormatVariadic.h" |
17 | #include "llvm/Support/Regex.h" |
18 | #include "llvm/TableGen/Error.h" |
19 | |
20 | using namespace llvm; |
21 | using namespace mlir; |
22 | using namespace mlir::tblgen; |
23 | |
24 | cl::OptionCategory opDefGenCat("Options for op definition generators" ); |
25 | |
26 | static cl::opt<std::string> opIncFilter( |
27 | "op-include-regex" , |
28 | cl::desc("Regex of name of op's to include (no filter if empty)" ), |
29 | cl::cat(opDefGenCat)); |
30 | static cl::opt<std::string> opExcFilter( |
31 | "op-exclude-regex" , |
32 | cl::desc("Regex of name of op's to exclude (no filter if empty)" ), |
33 | cl::cat(opDefGenCat)); |
34 | static cl::opt<unsigned> opShardCount( |
35 | "op-shard-count" , |
36 | cl::desc("The number of shards into which the op classes will be divided" ), |
37 | cl::cat(opDefGenCat), cl::init(Val: 1)); |
38 | |
39 | static std::string getOperationName(const Record &def) { |
40 | auto prefix = def.getValueAsDef(FieldName: "opDialect" )->getValueAsString(FieldName: "name" ); |
41 | auto opName = def.getValueAsString(FieldName: "opName" ); |
42 | if (prefix.empty()) |
43 | return std::string(opName); |
44 | return std::string(formatv(Fmt: "{0}.{1}" , Vals&: prefix, Vals&: opName)); |
45 | } |
46 | |
47 | std::vector<const Record *> |
48 | mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) { |
49 | const Record *classDef = records.getClass(Name: "Op" ); |
50 | if (!classDef) |
51 | PrintFatalError(Msg: "ERROR: Couldn't find the 'Op' class!\n" ); |
52 | |
53 | Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); |
54 | std::vector<const Record *> defs; |
55 | for (const auto &def : records.getDefs()) { |
56 | if (!def.second->isSubClassOf(R: classDef)) |
57 | continue; |
58 | // Include if no include filter or include filter matches. |
59 | if (!opIncFilter.empty() && |
60 | !includeRegex.match(String: getOperationName(def: *def.second))) |
61 | continue; |
62 | // Unless there is an exclude filter and it matches. |
63 | if (!opExcFilter.empty() && |
64 | excludeRegex.match(String: getOperationName(def: *def.second))) |
65 | continue; |
66 | defs.push_back(x: def.second.get()); |
67 | } |
68 | |
69 | return defs; |
70 | } |
71 | |
72 | bool mlir::tblgen::isPythonReserved(StringRef str) { |
73 | static StringSet<> reserved({ |
74 | "False" , "None" , "True" , "and" , "as" , "assert" , "async" , |
75 | "await" , "break" , "class" , "continue" , "def" , "del" , "elif" , |
76 | "else" , "except" , "finally" , "for" , "from" , "global" , "if" , |
77 | "import" , "in" , "is" , "lambda" , "nonlocal" , "not" , "or" , |
78 | "pass" , "raise" , "return" , "try" , "while" , "with" , "yield" , |
79 | }); |
80 | // These aren't Python keywords but builtin functions that shouldn't/can't be |
81 | // shadowed. |
82 | reserved.insert(key: "callable" ); |
83 | reserved.insert(key: "issubclass" ); |
84 | reserved.insert(key: "type" ); |
85 | return reserved.contains(key: str); |
86 | } |
87 | |
88 | void mlir::tblgen::shardOpDefinitions( |
89 | ArrayRef<const Record *> defs, |
90 | SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) { |
91 | assert(opShardCount > 0 && "expected a positive shard count" ); |
92 | if (opShardCount == 1) { |
93 | shardedDefs.push_back(Elt: defs); |
94 | return; |
95 | } |
96 | |
97 | unsigned minShardSize = defs.size() / opShardCount; |
98 | unsigned numMissing = defs.size() - minShardSize * opShardCount; |
99 | shardedDefs.reserve(N: opShardCount); |
100 | for (unsigned i = 0, start = 0; i < opShardCount; ++i) { |
101 | unsigned size = minShardSize + (i < numMissing); |
102 | shardedDefs.push_back(Elt: defs.slice(N: start, M: size)); |
103 | start += size; |
104 | } |
105 | } |
106 | |