| 1 | //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines helpers used in the op generators. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "OpGenHelpers.h" |
| 14 | #include "llvm/ADT/StringSet.h" |
| 15 | #include "llvm/Support/CommandLine.h" |
| 16 | #include "llvm/Support/FormatVariadic.h" |
| 17 | #include "llvm/Support/Regex.h" |
| 18 | #include "llvm/TableGen/Error.h" |
| 19 | |
| 20 | using namespace llvm; |
| 21 | using namespace mlir; |
| 22 | using namespace mlir::tblgen; |
| 23 | |
| 24 | cl::OptionCategory opDefGenCat("Options for op definition generators" ); |
| 25 | |
| 26 | static cl::opt<std::string> opIncFilter( |
| 27 | "op-include-regex" , |
| 28 | cl::desc("Regex of name of op's to include (no filter if empty)" ), |
| 29 | cl::cat(opDefGenCat)); |
| 30 | static cl::opt<std::string> opExcFilter( |
| 31 | "op-exclude-regex" , |
| 32 | cl::desc("Regex of name of op's to exclude (no filter if empty)" ), |
| 33 | cl::cat(opDefGenCat)); |
| 34 | static cl::opt<unsigned> opShardCount( |
| 35 | "op-shard-count" , |
| 36 | cl::desc("The number of shards into which the op classes will be divided" ), |
| 37 | cl::cat(opDefGenCat), cl::init(Val: 1)); |
| 38 | |
| 39 | static std::string getOperationName(const Record &def) { |
| 40 | auto prefix = def.getValueAsDef(FieldName: "opDialect" )->getValueAsString(FieldName: "name" ); |
| 41 | auto opName = def.getValueAsString(FieldName: "opName" ); |
| 42 | if (prefix.empty()) |
| 43 | return std::string(opName); |
| 44 | return std::string(formatv(Fmt: "{0}.{1}" , Vals&: prefix, Vals&: opName)); |
| 45 | } |
| 46 | |
| 47 | std::vector<const Record *> |
| 48 | mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) { |
| 49 | const Record *classDef = records.getClass(Name: "Op" ); |
| 50 | if (!classDef) |
| 51 | PrintFatalError(Msg: "ERROR: Couldn't find the 'Op' class!\n" ); |
| 52 | |
| 53 | Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); |
| 54 | std::vector<const Record *> defs; |
| 55 | for (const auto &def : records.getDefs()) { |
| 56 | if (!def.second->isSubClassOf(R: classDef)) |
| 57 | continue; |
| 58 | // Include if no include filter or include filter matches. |
| 59 | if (!opIncFilter.empty() && |
| 60 | !includeRegex.match(String: getOperationName(def: *def.second))) |
| 61 | continue; |
| 62 | // Unless there is an exclude filter and it matches. |
| 63 | if (!opExcFilter.empty() && |
| 64 | excludeRegex.match(String: getOperationName(def: *def.second))) |
| 65 | continue; |
| 66 | defs.push_back(x: def.second.get()); |
| 67 | } |
| 68 | |
| 69 | return defs; |
| 70 | } |
| 71 | |
| 72 | bool mlir::tblgen::isPythonReserved(StringRef str) { |
| 73 | static StringSet<> reserved({ |
| 74 | "False" , "None" , "True" , "and" , "as" , "assert" , "async" , |
| 75 | "await" , "break" , "class" , "continue" , "def" , "del" , "elif" , |
| 76 | "else" , "except" , "finally" , "for" , "from" , "global" , "if" , |
| 77 | "import" , "in" , "is" , "lambda" , "nonlocal" , "not" , "or" , |
| 78 | "pass" , "raise" , "return" , "try" , "while" , "with" , "yield" , |
| 79 | }); |
| 80 | // These aren't Python keywords but builtin functions that shouldn't/can't be |
| 81 | // shadowed. |
| 82 | reserved.insert(key: "callable" ); |
| 83 | reserved.insert(key: "issubclass" ); |
| 84 | reserved.insert(key: "type" ); |
| 85 | return reserved.contains(key: str); |
| 86 | } |
| 87 | |
| 88 | void mlir::tblgen::shardOpDefinitions( |
| 89 | ArrayRef<const Record *> defs, |
| 90 | SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) { |
| 91 | assert(opShardCount > 0 && "expected a positive shard count" ); |
| 92 | if (opShardCount == 1) { |
| 93 | shardedDefs.push_back(Elt: defs); |
| 94 | return; |
| 95 | } |
| 96 | |
| 97 | unsigned minShardSize = defs.size() / opShardCount; |
| 98 | unsigned numMissing = defs.size() - minShardSize * opShardCount; |
| 99 | shardedDefs.reserve(N: opShardCount); |
| 100 | for (unsigned i = 0, start = 0; i < opShardCount; ++i) { |
| 101 | unsigned size = minShardSize + (i < numMissing); |
| 102 | shardedDefs.push_back(Elt: defs.slice(N: start, M: size)); |
| 103 | start += size; |
| 104 | } |
| 105 | } |
| 106 | |