| 1 | //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Support/IndentedOstream.h" |
| 14 | #include "mlir/TableGen/Argument.h" |
| 15 | #include "mlir/TableGen/Attribute.h" |
| 16 | #include "mlir/TableGen/CodeGenHelpers.h" |
| 17 | #include "mlir/TableGen/Format.h" |
| 18 | #include "mlir/TableGen/GenInfo.h" |
| 19 | #include "mlir/TableGen/Operator.h" |
| 20 | #include "mlir/TableGen/Pattern.h" |
| 21 | #include "mlir/TableGen/Predicate.h" |
| 22 | #include "mlir/TableGen/Property.h" |
| 23 | #include "mlir/TableGen/Type.h" |
| 24 | #include "llvm/ADT/FunctionExtras.h" |
| 25 | #include "llvm/ADT/SetVector.h" |
| 26 | #include "llvm/ADT/StringExtras.h" |
| 27 | #include "llvm/ADT/StringSet.h" |
| 28 | #include "llvm/Support/CommandLine.h" |
| 29 | #include "llvm/Support/Debug.h" |
| 30 | #include "llvm/Support/FormatAdapters.h" |
| 31 | #include "llvm/Support/PrettyStackTrace.h" |
| 32 | #include "llvm/Support/Signals.h" |
| 33 | #include "llvm/TableGen/Error.h" |
| 34 | #include "llvm/TableGen/Main.h" |
| 35 | #include "llvm/TableGen/Record.h" |
| 36 | #include "llvm/TableGen/TableGenBackend.h" |
| 37 | |
| 38 | using namespace mlir; |
| 39 | using namespace mlir::tblgen; |
| 40 | |
| 41 | using llvm::formatv; |
| 42 | using llvm::Record; |
| 43 | using llvm::RecordKeeper; |
| 44 | |
| 45 | #define DEBUG_TYPE "mlir-tblgen-rewritergen" |
| 46 | |
| 47 | namespace llvm { |
| 48 | template <> |
| 49 | struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { |
| 50 | static void format(const mlir::tblgen::Pattern::IdentifierLine &v, |
| 51 | raw_ostream &os, StringRef style) { |
| 52 | os << v.first << ":" << v.second; |
| 53 | } |
| 54 | }; |
| 55 | } // namespace llvm |
| 56 | |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | // PatternEmitter |
| 59 | //===----------------------------------------------------------------------===// |
| 60 | |
| 61 | namespace { |
| 62 | |
| 63 | class StaticMatcherHelper; |
| 64 | |
| 65 | class PatternEmitter { |
| 66 | public: |
| 67 | PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os, |
| 68 | StaticMatcherHelper &helper); |
| 69 | |
| 70 | // Emits the mlir::RewritePattern struct named `rewriteName`. |
| 71 | void emit(StringRef rewriteName); |
| 72 | |
| 73 | // Emits the static function of DAG matcher. |
| 74 | void emitStaticMatcher(DagNode tree, std::string funcName); |
| 75 | |
| 76 | private: |
| 77 | // Emits the code for matching ops. |
| 78 | void emitMatchLogic(DagNode tree, StringRef opName); |
| 79 | |
| 80 | // Emits the code for rewriting ops. |
| 81 | void emitRewriteLogic(); |
| 82 | |
| 83 | //===--------------------------------------------------------------------===// |
| 84 | // Match utilities |
| 85 | //===--------------------------------------------------------------------===// |
| 86 | |
| 87 | // Emits C++ statements for matching the DAG structure. |
| 88 | void emitMatch(DagNode tree, StringRef name, int depth); |
| 89 | |
| 90 | // Emit C++ function call to static DAG matcher. |
| 91 | void emitStaticMatchCall(DagNode tree, StringRef name); |
| 92 | |
| 93 | // Emit C++ function call to static type/attribute constraint function. |
| 94 | void emitStaticVerifierCall(StringRef funcName, StringRef opName, |
| 95 | StringRef arg, StringRef failureStr); |
| 96 | |
| 97 | // Emits C++ statements for matching using a native code call. |
| 98 | void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); |
| 99 | |
| 100 | // Emits C++ statements for matching the op constrained by the given DAG |
| 101 | // `tree` returning the op's variable name. |
| 102 | void emitOpMatch(DagNode tree, StringRef opName, int depth); |
| 103 | |
| 104 | // Emits C++ statements for matching the `argIndex`-th argument of the given |
| 105 | // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the |
| 106 | // bound name and the constraint of the operand respectively. |
| 107 | void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, |
| 108 | int operandIndex, DagLeaf operandMatcher, |
| 109 | StringRef argName, int argIndex, |
| 110 | std::optional<int> variadicSubIndex); |
| 111 | |
| 112 | // Emits C++ statements for matching the operands which can be matched in |
| 113 | // either order. |
| 114 | void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, |
| 115 | StringRef opName, int argIndex, int &operandIndex, |
| 116 | int depth); |
| 117 | |
| 118 | // Emits C++ statements for matching a variadic operand. |
| 119 | void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, |
| 120 | StringRef opName, int argIndex, |
| 121 | int &operandIndex, int depth); |
| 122 | |
| 123 | // Emits C++ statements for matching the `argIndex`-th argument of the given |
| 124 | // DAG `tree` as an attribute. |
| 125 | void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, |
| 126 | int depth); |
| 127 | |
| 128 | // Emits C++ for checking a match with a corresponding match failure |
| 129 | // diagnostic. |
| 130 | void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, |
| 131 | const llvm::formatv_object_base &failureFmt); |
| 132 | |
| 133 | // Emits C++ for checking a match with a corresponding match failure |
| 134 | // diagnostics. |
| 135 | void emitMatchCheck(StringRef opName, const std::string &matchStr, |
| 136 | const std::string &failureStr); |
| 137 | |
| 138 | //===--------------------------------------------------------------------===// |
| 139 | // Rewrite utilities |
| 140 | //===--------------------------------------------------------------------===// |
| 141 | |
| 142 | // The entry point for handling a result pattern rooted at `resultTree`. This |
| 143 | // method dispatches to concrete handlers according to `resultTree`'s kind and |
| 144 | // returns a symbol representing the whole value pack. Callers are expected to |
| 145 | // further resolve the symbol according to the specific use case. |
| 146 | // |
| 147 | // `depth` is the nesting level of `resultTree`; 0 means top-level result |
| 148 | // pattern. For top-level result pattern, `resultIndex` indicates which result |
| 149 | // of the matched root op this pattern is intended to replace, which can be |
| 150 | // used to deduce the result type of the op generated from this result |
| 151 | // pattern. |
| 152 | std::string handleResultPattern(DagNode resultTree, int resultIndex, |
| 153 | int depth); |
| 154 | |
| 155 | // Emits the C++ statement to replace the matched DAG with a value built via |
| 156 | // calling native C++ code. |
| 157 | std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); |
| 158 | |
| 159 | // Returns the symbol of the old value serving as the replacement. |
| 160 | StringRef handleReplaceWithValue(DagNode tree); |
| 161 | |
| 162 | // Emits the C++ statement to replace the matched DAG with an array of |
| 163 | // matched values. |
| 164 | std::string handleVariadic(DagNode tree, int depth); |
| 165 | |
| 166 | // Trailing directives are used at the end of DAG node argument lists to |
| 167 | // specify additional behaviour for op matchers and creators, etc. |
| 168 | struct TrailingDirectives { |
| 169 | // DAG node containing the `location` directive. Null if there is none. |
| 170 | DagNode location; |
| 171 | |
| 172 | // DAG node containing the `returnType` directive. Null if there is none. |
| 173 | DagNode returnType; |
| 174 | |
| 175 | // Number of found trailing directives. |
| 176 | int numDirectives; |
| 177 | }; |
| 178 | |
| 179 | // Collect any trailing directives. |
| 180 | TrailingDirectives getTrailingDirectives(DagNode tree); |
| 181 | |
| 182 | // Returns the location value to use. |
| 183 | std::string getLocation(TrailingDirectives &tail); |
| 184 | |
| 185 | // Returns the location value to use. |
| 186 | std::string handleLocationDirective(DagNode tree); |
| 187 | |
| 188 | // Emit return type argument. |
| 189 | std::string handleReturnTypeArg(DagNode returnType, int i, int depth); |
| 190 | |
| 191 | // Emits the C++ statement to build a new op out of the given DAG `tree` and |
| 192 | // returns the variable name that this op is assigned to. If the root op in |
| 193 | // DAG `tree` has a specified name, the created op will be assigned to a |
| 194 | // variable of the given name. Otherwise, a unique name will be used as the |
| 195 | // result value name. |
| 196 | std::string handleOpCreation(DagNode tree, int resultIndex, int depth); |
| 197 | |
| 198 | using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; |
| 199 | |
| 200 | // Emits a local variable for each value and attribute to be used for creating |
| 201 | // an op. |
| 202 | void createSeparateLocalVarsForOpArgs(DagNode node, |
| 203 | ChildNodeIndexNameMap &childNodeNames); |
| 204 | |
| 205 | // Emits the concrete arguments used to call an op's builder. |
| 206 | void supplyValuesForOpArgs(DagNode node, |
| 207 | const ChildNodeIndexNameMap &childNodeNames, |
| 208 | int depth); |
| 209 | |
| 210 | // Emits the local variables for holding all values as a whole and all named |
| 211 | // attributes as a whole to be used for creating an op. |
| 212 | void createAggregateLocalVarsForOpArgs( |
| 213 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); |
| 214 | |
| 215 | // Returns the C++ expression to construct a constant attribute of the given |
| 216 | // `value` for the given attribute kind `attr`. |
| 217 | std::string handleConstantAttr(Attribute attr, const Twine &value); |
| 218 | |
| 219 | // Returns the C++ expression to build an argument from the given DAG `leaf`. |
| 220 | // `patArgName` is used to bound the argument to the source pattern. |
| 221 | std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); |
| 222 | |
| 223 | //===--------------------------------------------------------------------===// |
| 224 | // General utilities |
| 225 | //===--------------------------------------------------------------------===// |
| 226 | |
| 227 | // Collects all of the operations within the given dag tree. |
| 228 | void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); |
| 229 | |
| 230 | // Returns a unique symbol for a local variable of the given `op`. |
| 231 | std::string getUniqueSymbol(const Operator *op); |
| 232 | |
| 233 | //===--------------------------------------------------------------------===// |
| 234 | // Symbol utilities |
| 235 | //===--------------------------------------------------------------------===// |
| 236 | |
| 237 | // Returns how many static values the given DAG `node` correspond to. |
| 238 | int getNodeValueCount(DagNode node); |
| 239 | |
| 240 | private: |
| 241 | // Pattern instantiation location followed by the location of multiclass |
| 242 | // prototypes used. This is intended to be used as a whole to |
| 243 | // PrintFatalError() on errors. |
| 244 | ArrayRef<SMLoc> loc; |
| 245 | |
| 246 | // Op's TableGen Record to wrapper object. |
| 247 | RecordOperatorMap *opMap; |
| 248 | |
| 249 | // Handy wrapper for pattern being emitted. |
| 250 | Pattern pattern; |
| 251 | |
| 252 | // Map for all bound symbols' info. |
| 253 | SymbolInfoMap symbolInfoMap; |
| 254 | |
| 255 | StaticMatcherHelper &staticMatcherHelper; |
| 256 | |
| 257 | // The next unused ID for newly created values. |
| 258 | unsigned nextValueId = 0; |
| 259 | |
| 260 | raw_indented_ostream os; |
| 261 | |
| 262 | // Format contexts containing placeholder substitutions. |
| 263 | FmtContext fmtCtx; |
| 264 | }; |
| 265 | |
| 266 | // Tracks DagNode's reference multiple times across patterns. Enables generating |
| 267 | // static matcher functions for DagNode's referenced multiple times rather than |
| 268 | // inlining them. |
| 269 | class StaticMatcherHelper { |
| 270 | public: |
| 271 | StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records, |
| 272 | RecordOperatorMap &mapper); |
| 273 | |
| 274 | // Determine if we should inline the match logic or delegate to a static |
| 275 | // function. |
| 276 | bool useStaticMatcher(DagNode node) { |
| 277 | // either/variadic node must be associated to the parentOp, thus we can't |
| 278 | // emit a static matcher rooted at them. |
| 279 | if (node.isEither() || node.isVariadic()) |
| 280 | return false; |
| 281 | |
| 282 | return refStats[node] > kStaticMatcherThreshold; |
| 283 | } |
| 284 | |
| 285 | // Get the name of the static DAG matcher function corresponding to the node. |
| 286 | std::string getMatcherName(DagNode node) { |
| 287 | assert(useStaticMatcher(node)); |
| 288 | return matcherNames[node]; |
| 289 | } |
| 290 | |
| 291 | // Get the name of static type/attribute verification function. |
| 292 | StringRef getVerifierName(DagLeaf leaf); |
| 293 | |
| 294 | // Collect the `Record`s, i.e., the DRR, so that we can get the information of |
| 295 | // the duplicated DAGs. |
| 296 | void addPattern(const Record *record); |
| 297 | |
| 298 | // Emit all static functions of DAG Matcher. |
| 299 | void populateStaticMatchers(raw_ostream &os); |
| 300 | |
| 301 | // Emit all static functions for Constraints. |
| 302 | void populateStaticConstraintFunctions(raw_ostream &os); |
| 303 | |
| 304 | private: |
| 305 | static constexpr unsigned kStaticMatcherThreshold = 1; |
| 306 | |
| 307 | // Consider two patterns as down below, |
| 308 | // DagNode_Root_A DagNode_Root_B |
| 309 | // \ \ |
| 310 | // DagNode_C DagNode_C |
| 311 | // \ \ |
| 312 | // DagNode_D DagNode_D |
| 313 | // |
| 314 | // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of |
| 315 | // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced |
| 316 | // multiple times so we'll have static matchers for both of them. When we're |
| 317 | // emitting the match logic for DagNode_C, we will check if DagNode_D has the |
| 318 | // static matcher generated. If so, then we'll generate a call to the |
| 319 | // function, inline otherwise. In this case, inlining is not what we want. As |
| 320 | // a result, generate the static matcher in topological order to ensure all |
| 321 | // the dependent static matchers are generated and we can avoid accidentally |
| 322 | // inlining. |
| 323 | // |
| 324 | // The topological order of all the DagNodes among all patterns. |
| 325 | SmallVector<std::pair<DagNode, const Record *>> topologicalOrder; |
| 326 | |
| 327 | RecordOperatorMap &opMap; |
| 328 | |
| 329 | // Records of the static function name of each DagNode |
| 330 | DenseMap<DagNode, std::string> matcherNames; |
| 331 | |
| 332 | // After collecting all the DagNode in each pattern, `refStats` records the |
| 333 | // number of users for each DagNode. We will generate the static matcher for a |
| 334 | // DagNode while the number of users exceeds a certain threshold. |
| 335 | DenseMap<DagNode, unsigned> refStats; |
| 336 | |
| 337 | // Number of static matcher generated. This is used to generate a unique name |
| 338 | // for each DagNode. |
| 339 | int staticMatcherCounter = 0; |
| 340 | |
| 341 | // The DagLeaf which contains type or attr constraint. |
| 342 | SetVector<DagLeaf> constraints; |
| 343 | |
| 344 | // Static type/attribute verification function emitter. |
| 345 | StaticVerifierFunctionEmitter staticVerifierEmitter; |
| 346 | }; |
| 347 | |
| 348 | } // namespace |
| 349 | |
| 350 | PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper, |
| 351 | raw_ostream &os, StaticMatcherHelper &helper) |
| 352 | : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), |
| 353 | symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { |
| 354 | fmtCtx.withBuilder(subst: "rewriter" ); |
| 355 | } |
| 356 | |
| 357 | std::string PatternEmitter::handleConstantAttr(Attribute attr, |
| 358 | const Twine &value) { |
| 359 | if (!attr.isConstBuildable()) |
| 360 | PrintFatalError(ErrorLoc: loc, Msg: "Attribute " + attr.getAttrDefName() + |
| 361 | " does not have the 'constBuilderCall' field" ); |
| 362 | |
| 363 | // TODO: Verify the constants here |
| 364 | return std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, vals: value)); |
| 365 | } |
| 366 | |
| 367 | void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { |
| 368 | os << formatv( |
| 369 | Fmt: "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " |
| 370 | "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " |
| 371 | "*, 4> &tblgen_ops" , |
| 372 | Vals&: funcName); |
| 373 | |
| 374 | // We pass the reference of the variables that need to be captured. Hence we |
| 375 | // need to collect all the symbols in the tree first. |
| 376 | pattern.collectBoundSymbols(tree, infoMap&: symbolInfoMap, /*isSrcPattern=*/true); |
| 377 | symbolInfoMap.assignUniqueAlternativeNames(); |
| 378 | for (const auto &info : symbolInfoMap) |
| 379 | os << formatv(Fmt: ", {0}" , Vals: info.second.getArgDecl(name: info.first)); |
| 380 | |
| 381 | os << ") {\n" ; |
| 382 | os.indent(); |
| 383 | os << "(void)tblgen_ops;\n" ; |
| 384 | |
| 385 | // Note that a static matcher is considered at least one step from the match |
| 386 | // entry. |
| 387 | emitMatch(tree, name: "op0" , /*depth=*/1); |
| 388 | |
| 389 | os << "return ::mlir::success();\n" ; |
| 390 | os.unindent(); |
| 391 | os << "}\n\n" ; |
| 392 | } |
| 393 | |
| 394 | // Helper function to match patterns. |
| 395 | void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { |
| 396 | if (tree.isNativeCodeCall()) { |
| 397 | emitNativeCodeMatch(tree, name, depth); |
| 398 | return; |
| 399 | } |
| 400 | |
| 401 | if (tree.isOperation()) { |
| 402 | emitOpMatch(tree, opName: name, depth); |
| 403 | return; |
| 404 | } |
| 405 | |
| 406 | PrintFatalError(ErrorLoc: loc, Msg: "encountered non-op, non-NativeCodeCall match." ); |
| 407 | } |
| 408 | |
| 409 | void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { |
| 410 | std::string funcName = staticMatcherHelper.getMatcherName(node: tree); |
| 411 | os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, tblgen_ops" , Vals&: funcName, |
| 412 | Vals&: opName); |
| 413 | |
| 414 | // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in |
| 415 | // one pass. |
| 416 | |
| 417 | // In general, bound symbol should have the unique name in the pattern but |
| 418 | // for the operand, binding same symbol to multiple operands imply a |
| 419 | // constraint at the same time. In this case, we will rename those operands |
| 420 | // with different names. As a result, we need to collect all the symbolInfos |
| 421 | // from the DagNode then get the updated name of the local variables from the |
| 422 | // global symbolInfoMap. |
| 423 | |
| 424 | // Collect all the bound symbols in the Dag |
| 425 | SymbolInfoMap localSymbolMap(loc); |
| 426 | pattern.collectBoundSymbols(tree, infoMap&: localSymbolMap, /*isSrcPattern=*/true); |
| 427 | |
| 428 | for (const auto &info : localSymbolMap) { |
| 429 | auto name = info.first; |
| 430 | auto symboInfo = info.second; |
| 431 | auto ret = symbolInfoMap.findBoundSymbol(key: name, symbolInfo: symboInfo); |
| 432 | os << formatv(Fmt: ", {0}" , Vals: ret->second.getVarName(name)); |
| 433 | } |
| 434 | |
| 435 | os << "))) {\n" ; |
| 436 | os.scope().os << "return ::mlir::failure();\n" ; |
| 437 | os << "}\n" ; |
| 438 | } |
| 439 | |
| 440 | void PatternEmitter::emitStaticVerifierCall(StringRef funcName, |
| 441 | StringRef opName, StringRef arg, |
| 442 | StringRef failureStr) { |
| 443 | os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n" , |
| 444 | Vals&: funcName, Vals&: opName, Vals&: arg, Vals&: failureStr); |
| 445 | os.scope().os << "return ::mlir::failure();\n" ; |
| 446 | os << "}\n" ; |
| 447 | } |
| 448 | |
| 449 | // Helper function to match patterns. |
| 450 | void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, |
| 451 | int depth) { |
| 452 | LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: " ); |
| 453 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
| 454 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
| 455 | |
| 456 | // The order of generating static matcher follows the topological order so |
| 457 | // that for every dependent DagNode already have their static matcher |
| 458 | // generated if needed. The reason we check if `getMatcherName(tree).empty()` |
| 459 | // is when we are generating the static matcher for a DagNode itself. In this |
| 460 | // case, we need to emit the function body rather than a function call. |
| 461 | if (staticMatcherHelper.useStaticMatcher(node: tree) && |
| 462 | !staticMatcherHelper.getMatcherName(node: tree).empty()) { |
| 463 | emitStaticMatchCall(tree, opName); |
| 464 | |
| 465 | // NativeCodeCall will never be at depth 0 so that we don't need to catch |
| 466 | // the root operation as emitOpMatch(); |
| 467 | |
| 468 | return; |
| 469 | } |
| 470 | |
| 471 | // TODO(suderman): iterate through arguments, determine their types, output |
| 472 | // names. |
| 473 | SmallVector<std::string, 8> capture; |
| 474 | |
| 475 | raw_indented_ostream::DelimitedScope scope(os); |
| 476 | |
| 477 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
| 478 | std::string argName = formatv(Fmt: "arg{0}_{1}" , Vals&: depth, Vals&: i); |
| 479 | if (DagNode argTree = tree.getArgAsNestedDag(index: i)) { |
| 480 | if (argTree.isEither()) |
| 481 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `either` operands" ); |
| 482 | if (argTree.isVariadic()) |
| 483 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `variadic` operands" ); |
| 484 | |
| 485 | os << "::mlir::Value " << argName << ";\n" ; |
| 486 | } else { |
| 487 | auto leaf = tree.getArgAsLeaf(index: i); |
| 488 | if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { |
| 489 | os << "::mlir::Attribute " << argName << ";\n" ; |
| 490 | } else { |
| 491 | os << "::mlir::Value " << argName << ";\n" ; |
| 492 | } |
| 493 | } |
| 494 | |
| 495 | capture.push_back(Elt: std::move(argName)); |
| 496 | } |
| 497 | |
| 498 | auto tail = getTrailingDirectives(tree); |
| 499 | if (tail.returnType) |
| 500 | PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier" ); |
| 501 | auto locToUse = getLocation(tail); |
| 502 | |
| 503 | auto fmt = tree.getNativeCodeTemplate(); |
| 504 | if (fmt.count(Str: "$_self" ) != 1) |
| 505 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall must have $_self as argument for " |
| 506 | "passing the defining Operation" ); |
| 507 | |
| 508 | auto nativeCodeCall = std::string( |
| 509 | tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc" , subst: locToUse).withSelf(subst: opName.str()), |
| 510 | params: static_cast<ArrayRef<std::string>>(capture))); |
| 511 | |
| 512 | emitMatchCheck(opName, matchStr: formatv(Fmt: "!::mlir::failed({0})" , Vals&: nativeCodeCall), |
| 513 | failureStr: formatv(Fmt: "\"{0} return ::mlir::failure\"" , Vals&: nativeCodeCall)); |
| 514 | |
| 515 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
| 516 | auto name = tree.getArgName(index: i); |
| 517 | if (!name.empty() && name != "_" ) { |
| 518 | os << formatv(Fmt: "{0} = {1};\n" , Vals&: name, Vals&: capture[i]); |
| 519 | } |
| 520 | } |
| 521 | |
| 522 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
| 523 | std::string argName = capture[i]; |
| 524 | |
| 525 | // Handle nested DAG construct first |
| 526 | if (tree.getArgAsNestedDag(index: i)) { |
| 527 | PrintFatalError( |
| 528 | ErrorLoc: loc, Msg: formatv(Fmt: "Matching nested tree in NativeCodecall not support for " |
| 529 | "{0} as arg {1}" , |
| 530 | Vals&: argName, Vals&: i)); |
| 531 | } |
| 532 | |
| 533 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
| 534 | |
| 535 | // The parameter for native function doesn't bind any constraints. |
| 536 | if (leaf.isUnspecified()) |
| 537 | continue; |
| 538 | |
| 539 | auto constraint = leaf.getAsConstraint(); |
| 540 | |
| 541 | std::string self; |
| 542 | if (leaf.isAttrMatcher() || leaf.isConstantAttr()) |
| 543 | self = argName; |
| 544 | else |
| 545 | self = formatv(Fmt: "{0}.getType()" , Vals&: argName); |
| 546 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf); |
| 547 | emitStaticVerifierCall( |
| 548 | funcName: verifier, opName, arg: self, |
| 549 | failureStr: formatv(Fmt: "\"operand {0} of native code call '{1}' failed to satisfy " |
| 550 | "constraint: " |
| 551 | "'{2}'\"" , |
| 552 | Vals&: i, Vals: tree.getNativeCodeTemplate(), |
| 553 | Vals: escapeString(value: constraint.getSummary())) |
| 554 | .str()); |
| 555 | } |
| 556 | |
| 557 | LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n" ); |
| 558 | } |
| 559 | |
| 560 | // Helper function to match patterns. |
| 561 | void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { |
| 562 | Operator &op = tree.getDialectOp(mapper: opMap); |
| 563 | LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" |
| 564 | << op.getOperationName() << "' at depth " << depth |
| 565 | << '\n'); |
| 566 | |
| 567 | auto getCastedName = [depth]() -> std::string { |
| 568 | return formatv(Fmt: "castedOp{0}" , Vals: depth); |
| 569 | }; |
| 570 | |
| 571 | // The order of generating static matcher follows the topological order so |
| 572 | // that for every dependent DagNode already have their static matcher |
| 573 | // generated if needed. The reason we check if `getMatcherName(tree).empty()` |
| 574 | // is when we are generating the static matcher for a DagNode itself. In this |
| 575 | // case, we need to emit the function body rather than a function call. |
| 576 | if (staticMatcherHelper.useStaticMatcher(node: tree) && |
| 577 | !staticMatcherHelper.getMatcherName(node: tree).empty()) { |
| 578 | emitStaticMatchCall(tree, opName); |
| 579 | // In the codegen of rewriter, we suppose that castedOp0 will capture the |
| 580 | // root operation. Manually add it if the root DagNode is a static matcher. |
| 581 | if (depth == 0) |
| 582 | os << formatv(Fmt: "auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " |
| 583 | "(void){2};\n" , |
| 584 | Vals&: opName, Vals: op.getQualCppClassName(), Vals: getCastedName()); |
| 585 | return; |
| 586 | } |
| 587 | |
| 588 | std::string castedName = getCastedName(); |
| 589 | os << formatv(Fmt: "auto {0} = ::llvm::dyn_cast<{2}>({1}); " |
| 590 | "(void){0};\n" , |
| 591 | Vals&: castedName, Vals&: opName, Vals: op.getQualCppClassName()); |
| 592 | |
| 593 | // Skip the operand matching at depth 0 as the pattern rewriter already does. |
| 594 | if (depth != 0) |
| 595 | emitMatchCheck(opName, /*matchStr=*/castedName, |
| 596 | failureStr: formatv(Fmt: "\"{0} is not {1} type\"" , Vals&: castedName, |
| 597 | Vals: op.getQualCppClassName())); |
| 598 | |
| 599 | // If the operand's name is set, set to that variable. |
| 600 | auto name = tree.getSymbol(); |
| 601 | if (!name.empty()) |
| 602 | os << formatv(Fmt: "{0} = {1};\n" , Vals&: name, Vals&: castedName); |
| 603 | |
| 604 | for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; |
| 605 | ++i, ++opArgIdx) { |
| 606 | auto opArg = op.getArg(index: opArgIdx); |
| 607 | std::string argName = formatv(Fmt: "op{0}" , Vals: depth + 1); |
| 608 | |
| 609 | // Handle nested DAG construct first |
| 610 | if (DagNode argTree = tree.getArgAsNestedDag(index: i)) { |
| 611 | if (argTree.isEither()) { |
| 612 | emitEitherOperandMatch(tree, eitherArgTree: argTree, opName: castedName, argIndex: opArgIdx, operandIndex&: nextOperand, |
| 613 | depth); |
| 614 | ++opArgIdx; |
| 615 | continue; |
| 616 | } |
| 617 | if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) { |
| 618 | if (argTree.isVariadic()) { |
| 619 | if (!operand->isVariadic()) { |
| 620 | auto error = formatv(Fmt: "variadic DAG construct can't match op {0}'s " |
| 621 | "non-variadic operand #{1}" , |
| 622 | Vals: op.getOperationName(), Vals&: opArgIdx); |
| 623 | PrintFatalError(ErrorLoc: loc, Msg: error); |
| 624 | } |
| 625 | emitVariadicOperandMatch(tree, variadicArgTree: argTree, opName: castedName, argIndex: opArgIdx, |
| 626 | operandIndex&: nextOperand, depth); |
| 627 | ++nextOperand; |
| 628 | continue; |
| 629 | } |
| 630 | if (operand->isVariableLength()) { |
| 631 | auto error = formatv(Fmt: "use nested DAG construct to match op {0}'s " |
| 632 | "variadic operand #{1} unsupported now" , |
| 633 | Vals: op.getOperationName(), Vals&: opArgIdx); |
| 634 | PrintFatalError(ErrorLoc: loc, Msg: error); |
| 635 | } |
| 636 | } |
| 637 | |
| 638 | os << "{\n" ; |
| 639 | |
| 640 | // Attributes don't count for getODSOperands. |
| 641 | // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. |
| 642 | os.indent() << formatv( |
| 643 | Fmt: "auto *{0} = " |
| 644 | "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n" , |
| 645 | Vals&: argName, Vals&: castedName, Vals&: nextOperand); |
| 646 | // Null check of operand's definingOp |
| 647 | emitMatchCheck( |
| 648 | opName: castedName, /*matchStr=*/argName, |
| 649 | failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"" , |
| 650 | Vals: nextOperand++, Vals&: castedName)); |
| 651 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
| 652 | os << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
| 653 | os.unindent() << "}\n" ; |
| 654 | continue; |
| 655 | } |
| 656 | |
| 657 | // Next handle DAG leaf: operand or attribute |
| 658 | if (isa<NamedTypeConstraint *>(Val: opArg)) { |
| 659 | auto operandName = |
| 660 | formatv(Fmt: "{0}.getODSOperands({1})" , Vals&: castedName, Vals&: nextOperand); |
| 661 | emitOperandMatch(tree, opName: castedName, operandName: operandName.str(), operandIndex: nextOperand, |
| 662 | /*operandMatcher=*/tree.getArgAsLeaf(index: i), |
| 663 | /*argName=*/tree.getArgName(index: i), argIndex: opArgIdx, |
| 664 | /*variadicSubIndex=*/std::nullopt); |
| 665 | ++nextOperand; |
| 666 | } else if (isa<NamedAttribute *>(Val: opArg)) { |
| 667 | emitAttributeMatch(tree, castedName, argIndex: opArgIdx, depth); |
| 668 | } else { |
| 669 | PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when matching op" ); |
| 670 | } |
| 671 | } |
| 672 | LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" |
| 673 | << op.getOperationName() << "' at depth " << depth |
| 674 | << '\n'); |
| 675 | } |
| 676 | |
| 677 | void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, |
| 678 | StringRef operandName, int operandIndex, |
| 679 | DagLeaf operandMatcher, StringRef argName, |
| 680 | int argIndex, |
| 681 | std::optional<int> variadicSubIndex) { |
| 682 | Operator &op = tree.getDialectOp(mapper: opMap); |
| 683 | NamedTypeConstraint operand = op.getOperand(index: operandIndex); |
| 684 | |
| 685 | // If a constraint is specified, we need to generate C++ statements to |
| 686 | // check the constraint. |
| 687 | if (!operandMatcher.isUnspecified()) { |
| 688 | if (!operandMatcher.isOperandMatcher()) |
| 689 | PrintFatalError( |
| 690 | ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an operand" , |
| 691 | Vals: op.getOperationName(), Vals: argIndex + 1)); |
| 692 | |
| 693 | // Only need to verify if the matcher's type is different from the one |
| 694 | // of op definition. |
| 695 | Constraint constraint = operandMatcher.getAsConstraint(); |
| 696 | if (operand.constraint != constraint) { |
| 697 | if (operand.isVariableLength()) { |
| 698 | auto error = formatv( |
| 699 | Fmt: "further constrain op {0}'s variadic operand #{1} unsupported now" , |
| 700 | Vals: op.getOperationName(), Vals&: argIndex); |
| 701 | PrintFatalError(ErrorLoc: loc, Msg: error); |
| 702 | } |
| 703 | auto self = formatv(Fmt: "(*{0}.begin()).getType()" , Vals&: operandName); |
| 704 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf: operandMatcher); |
| 705 | emitStaticVerifierCall( |
| 706 | funcName: verifier, opName, arg: self.str(), |
| 707 | failureStr: formatv( |
| 708 | Fmt: "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"" , |
| 709 | Vals&: operandIndex, Vals: op.getOperationName(), |
| 710 | Vals: escapeString(value: constraint.getSummary())) |
| 711 | .str()); |
| 712 | } |
| 713 | } |
| 714 | |
| 715 | // Capture the value |
| 716 | // `$_` is a special symbol to ignore op argument matching. |
| 717 | if (!argName.empty() && argName != "_" ) { |
| 718 | auto res = symbolInfoMap.findBoundSymbol(key: argName, node: tree, op, argIndex, |
| 719 | variadicSubIndex); |
| 720 | if (res == symbolInfoMap.end()) |
| 721 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}" , Vals&: argName)); |
| 722 | |
| 723 | os << formatv(Fmt: "{0} = {1};\n" , Vals: res->second.getVarName(name: argName), Vals&: operandName); |
| 724 | } |
| 725 | } |
| 726 | |
| 727 | void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, |
| 728 | StringRef opName, int argIndex, |
| 729 | int &operandIndex, int depth) { |
| 730 | constexpr int numEitherArgs = 2; |
| 731 | if (eitherArgTree.getNumArgs() != numEitherArgs) |
| 732 | PrintFatalError(ErrorLoc: loc, Msg: "`either` only supports grouping two operands" ); |
| 733 | |
| 734 | Operator &op = tree.getDialectOp(mapper: opMap); |
| 735 | |
| 736 | std::string codeBuffer; |
| 737 | llvm::raw_string_ostream tblgenOps(codeBuffer); |
| 738 | |
| 739 | std::string lambda = formatv(Fmt: "eitherLambda{0}" , Vals&: depth); |
| 740 | os << formatv( |
| 741 | Fmt: "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n" , |
| 742 | Vals&: lambda); |
| 743 | |
| 744 | os.indent(); |
| 745 | |
| 746 | for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { |
| 747 | if (DagNode argTree = eitherArgTree.getArgAsNestedDag(index: i)) { |
| 748 | if (argTree.isEither()) |
| 749 | PrintFatalError(ErrorLoc: loc, Msg: "either cannot be nested" ); |
| 750 | |
| 751 | std::string argName = formatv(Fmt: "local_op_{0}" , Vals&: i).str(); |
| 752 | |
| 753 | os << formatv(Fmt: "auto {0} = (*v{1}.begin()).getDefiningOp();\n" , Vals&: argName, |
| 754 | Vals&: i); |
| 755 | |
| 756 | // Indent emitMatchCheck and emitMatch because they declare local |
| 757 | // variables. |
| 758 | os << "{\n" ; |
| 759 | os.indent(); |
| 760 | |
| 761 | emitMatchCheck( |
| 762 | opName, /*matchStr=*/argName, |
| 763 | failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"" , |
| 764 | Vals: operandIndex++, Vals&: opName)); |
| 765 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
| 766 | |
| 767 | os.unindent() << "}\n" ; |
| 768 | |
| 769 | // `tblgen_ops` is used to collect the matched operations. In either, we |
| 770 | // need to queue the operation only if the matching success. Thus we emit |
| 771 | // the code at the end. |
| 772 | tblgenOps << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
| 773 | } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) { |
| 774 | emitOperandMatch(tree, opName, /*operandName=*/formatv(Fmt: "v{0}" , Vals&: i).str(), |
| 775 | operandIndex, |
| 776 | /*operandMatcher=*/eitherArgTree.getArgAsLeaf(index: i), |
| 777 | /*argName=*/eitherArgTree.getArgName(index: i), argIndex, |
| 778 | /*variadicSubIndex=*/std::nullopt); |
| 779 | ++operandIndex; |
| 780 | } else { |
| 781 | PrintFatalError(ErrorLoc: loc, Msg: "either can only be applied on operand" ); |
| 782 | } |
| 783 | } |
| 784 | |
| 785 | os << tblgenOps.str(); |
| 786 | os << "return ::mlir::success();\n" ; |
| 787 | os.unindent() << "};\n" ; |
| 788 | |
| 789 | os << "{\n" ; |
| 790 | os.indent(); |
| 791 | |
| 792 | os << formatv(Fmt: "auto eitherOperand0 = {0}.getODSOperands({1});\n" , Vals&: opName, |
| 793 | Vals: operandIndex - 2); |
| 794 | os << formatv(Fmt: "auto eitherOperand1 = {0}.getODSOperands({1});\n" , Vals&: opName, |
| 795 | Vals: operandIndex - 1); |
| 796 | |
| 797 | os << formatv(Fmt: "if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " |
| 798 | "::mlir::failed({0}(eitherOperand1, " |
| 799 | "eitherOperand0)))\n" , |
| 800 | Vals&: lambda); |
| 801 | os.indent() << "return ::mlir::failure();\n" ; |
| 802 | |
| 803 | os.unindent().unindent() << "}\n" ; |
| 804 | } |
| 805 | |
| 806 | void PatternEmitter::emitVariadicOperandMatch(DagNode tree, |
| 807 | DagNode variadicArgTree, |
| 808 | StringRef opName, int argIndex, |
| 809 | int &operandIndex, int depth) { |
| 810 | Operator &op = tree.getDialectOp(mapper: opMap); |
| 811 | |
| 812 | os << "{\n" ; |
| 813 | os.indent(); |
| 814 | |
| 815 | os << formatv(Fmt: "auto variadic_operand_range = {0}.getODSOperands({1});\n" , |
| 816 | Vals&: opName, Vals&: operandIndex); |
| 817 | os << formatv(Fmt: "if (variadic_operand_range.size() != {0}) " |
| 818 | "return ::mlir::failure();\n" , |
| 819 | Vals: variadicArgTree.getNumArgs()); |
| 820 | |
| 821 | StringRef variadicTreeName = variadicArgTree.getSymbol(); |
| 822 | if (!variadicTreeName.empty()) { |
| 823 | auto res = |
| 824 | symbolInfoMap.findBoundSymbol(key: variadicTreeName, node: tree, op, argIndex, |
| 825 | /*variadicSubIndex=*/std::nullopt); |
| 826 | if (res == symbolInfoMap.end()) |
| 827 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}" , Vals&: variadicTreeName)); |
| 828 | |
| 829 | os << formatv(Fmt: "{0} = variadic_operand_range;\n" , |
| 830 | Vals: res->second.getVarName(name: variadicTreeName)); |
| 831 | } |
| 832 | |
| 833 | for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) { |
| 834 | if (DagNode argTree = variadicArgTree.getArgAsNestedDag(index: i)) { |
| 835 | if (!argTree.isOperation()) |
| 836 | PrintFatalError(ErrorLoc: loc, Msg: "variadic only accepts operation sub-dags" ); |
| 837 | |
| 838 | os << "{\n" ; |
| 839 | os.indent(); |
| 840 | |
| 841 | std::string argName = formatv(Fmt: "local_op_{0}" , Vals&: i).str(); |
| 842 | os << formatv(Fmt: "auto *{0} = " |
| 843 | "variadic_operand_range[{1}].getDefiningOp();\n" , |
| 844 | Vals&: argName, Vals&: i); |
| 845 | emitMatchCheck( |
| 846 | opName, /*matchStr=*/argName, |
| 847 | failureStr: formatv(Fmt: "\"There's no operation that defines variadic operand " |
| 848 | "{0} (variadic sub-opearnd #{1}) of {2}\"" , |
| 849 | Vals&: operandIndex, Vals&: i, Vals&: opName)); |
| 850 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
| 851 | os << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
| 852 | |
| 853 | os.unindent() << "}\n" ; |
| 854 | } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) { |
| 855 | auto operandName = formatv(Fmt: "variadic_operand_range.slice({0}, 1)" , Vals&: i); |
| 856 | emitOperandMatch(tree, opName, operandName: operandName.str(), operandIndex, |
| 857 | /*operandMatcher=*/variadicArgTree.getArgAsLeaf(index: i), |
| 858 | /*argName=*/variadicArgTree.getArgName(index: i), argIndex, variadicSubIndex: i); |
| 859 | } else { |
| 860 | PrintFatalError(ErrorLoc: loc, Msg: "variadic can only be applied on operand" ); |
| 861 | } |
| 862 | } |
| 863 | |
| 864 | os.unindent() << "}\n" ; |
| 865 | } |
| 866 | |
| 867 | void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName, |
| 868 | int argIndex, int depth) { |
| 869 | Operator &op = tree.getDialectOp(mapper: opMap); |
| 870 | auto *namedAttr = cast<NamedAttribute *>(Val: op.getArg(index: argIndex)); |
| 871 | const auto &attr = namedAttr->attr; |
| 872 | |
| 873 | os << "{\n" ; |
| 874 | if (op.getDialect().usePropertiesForAttributes()) { |
| 875 | os.indent() << formatv( |
| 876 | Fmt: "[[maybe_unused]] auto tblgen_attr = {0}.getProperties().{1}();\n" , |
| 877 | Vals&: castedName, Vals: op.getGetterName(name: namedAttr->name)); |
| 878 | } else { |
| 879 | os.indent() << formatv(Fmt: "[[maybe_unused]] auto tblgen_attr = " |
| 880 | "{0}->getAttrOfType<{1}>(\"{2}\");\n" , |
| 881 | Vals&: castedName, Vals: attr.getStorageType(), Vals&: namedAttr->name); |
| 882 | } |
| 883 | |
| 884 | // TODO: This should use getter method to avoid duplication. |
| 885 | if (attr.hasDefaultValue()) { |
| 886 | os << "if (!tblgen_attr) tblgen_attr = " |
| 887 | << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, |
| 888 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fmtCtx))) |
| 889 | << ";\n" ; |
| 890 | } else if (attr.isOptional()) { |
| 891 | // For a missing attribute that is optional according to definition, we |
| 892 | // should just capture a mlir::Attribute() to signal the missing state. |
| 893 | // That is precisely what getDiscardableAttr() returns on missing |
| 894 | // attributes. |
| 895 | } else { |
| 896 | emitMatchCheck(opName: castedName, matchFmt: tgfmt(fmt: "tblgen_attr" , ctx: &fmtCtx), |
| 897 | failureFmt: formatv(Fmt: "\"expected op '{0}' to have attribute '{1}' " |
| 898 | "of type '{2}'\"" , |
| 899 | Vals: op.getOperationName(), Vals&: namedAttr->name, |
| 900 | Vals: attr.getStorageType())); |
| 901 | } |
| 902 | |
| 903 | auto matcher = tree.getArgAsLeaf(index: argIndex); |
| 904 | if (!matcher.isUnspecified()) { |
| 905 | if (!matcher.isAttrMatcher()) { |
| 906 | PrintFatalError( |
| 907 | ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an attribute" , |
| 908 | Vals: op.getOperationName(), Vals: argIndex + 1)); |
| 909 | } |
| 910 | |
| 911 | // If a constraint is specified, we need to generate function call to its |
| 912 | // static verifier. |
| 913 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf: matcher); |
| 914 | if (attr.isOptional()) { |
| 915 | // Avoid dereferencing null attribute. This is using a simple heuristic to |
| 916 | // avoid common cases of attempting to dereference null attribute. This |
| 917 | // will return where there is no check if attribute is null unless the |
| 918 | // attribute's value is not used. |
| 919 | // FIXME: This could be improved as some null dereferences could slip |
| 920 | // through. |
| 921 | if (!StringRef(matcher.getConditionTemplate()).contains(Other: "!$_self" ) && |
| 922 | StringRef(matcher.getConditionTemplate()).contains(Other: "$_self" )) { |
| 923 | os << "if (!tblgen_attr) return ::mlir::failure();\n" ; |
| 924 | } |
| 925 | } |
| 926 | emitStaticVerifierCall( |
| 927 | funcName: verifier, opName: castedName, arg: "tblgen_attr" , |
| 928 | failureStr: formatv(Fmt: "\"op '{0}' attribute '{1}' failed to satisfy constraint: " |
| 929 | "'{2}'\"" , |
| 930 | Vals: op.getOperationName(), Vals&: namedAttr->name, |
| 931 | Vals: escapeString(value: matcher.getAsConstraint().getSummary())) |
| 932 | .str()); |
| 933 | } |
| 934 | |
| 935 | // Capture the value |
| 936 | auto name = tree.getArgName(index: argIndex); |
| 937 | // `$_` is a special symbol to ignore op argument matching. |
| 938 | if (!name.empty() && name != "_" ) { |
| 939 | os << formatv(Fmt: "{0} = tblgen_attr;\n" , Vals&: name); |
| 940 | } |
| 941 | |
| 942 | os.unindent() << "}\n" ; |
| 943 | } |
| 944 | |
| 945 | void PatternEmitter::emitMatchCheck( |
| 946 | StringRef opName, const FmtObjectBase &matchFmt, |
| 947 | const llvm::formatv_object_base &failureFmt) { |
| 948 | emitMatchCheck(opName, matchStr: matchFmt.str(), failureStr: failureFmt.str()); |
| 949 | } |
| 950 | |
| 951 | void PatternEmitter::emitMatchCheck(StringRef opName, |
| 952 | const std::string &matchStr, |
| 953 | const std::string &failureStr) { |
| 954 | |
| 955 | os << "if (!(" << matchStr << "))" ; |
| 956 | os.scope(open: "{\n" , close: "\n}\n" ).os << "return rewriter.notifyMatchFailure(" << opName |
| 957 | << ", [&](::mlir::Diagnostic &diag) {\n diag << " |
| 958 | << failureStr << ";\n});" ; |
| 959 | } |
| 960 | |
| 961 | void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { |
| 962 | LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n" ); |
| 963 | int depth = 0; |
| 964 | emitMatch(tree, name: opName, depth); |
| 965 | |
| 966 | for (auto &appliedConstraint : pattern.getConstraints()) { |
| 967 | auto &constraint = appliedConstraint.constraint; |
| 968 | auto &entities = appliedConstraint.entities; |
| 969 | |
| 970 | auto condition = constraint.getConditionTemplate(); |
| 971 | if (isa<TypeConstraint>(Val: constraint)) { |
| 972 | if (entities.size() != 1) |
| 973 | PrintFatalError(ErrorLoc: loc, Msg: "type constraint requires exactly one argument" ); |
| 974 | |
| 975 | auto self = formatv(Fmt: "({0}.getType())" , |
| 976 | Vals: symbolInfoMap.getValueAndRangeUse(symbol: entities.front())); |
| 977 | emitMatchCheck( |
| 978 | opName, matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self.str())), |
| 979 | failureFmt: formatv(Fmt: "\"value entity '{0}' failed to satisfy constraint: '{1}'\"" , |
| 980 | Vals&: entities.front(), Vals: escapeString(value: constraint.getSummary()))); |
| 981 | |
| 982 | } else if (isa<AttrConstraint>(Val: constraint)) { |
| 983 | PrintFatalError( |
| 984 | ErrorLoc: loc, Msg: "cannot use AttrConstraint in Pattern multi-entity constraints" ); |
| 985 | } else { |
| 986 | // TODO: replace formatv arguments with the exact specified |
| 987 | // args. |
| 988 | if (entities.size() > 4) { |
| 989 | PrintFatalError(ErrorLoc: loc, Msg: "only support up to 4-entity constraints now" ); |
| 990 | } |
| 991 | SmallVector<std::string, 4> names; |
| 992 | int i = 0; |
| 993 | for (int e = entities.size(); i < e; ++i) |
| 994 | names.push_back(Elt: symbolInfoMap.getValueAndRangeUse(symbol: entities[i])); |
| 995 | std::string self = appliedConstraint.self; |
| 996 | if (!self.empty()) |
| 997 | self = symbolInfoMap.getValueAndRangeUse(symbol: self); |
| 998 | for (; i < 4; ++i) |
| 999 | names.push_back(Elt: "<unused>" ); |
| 1000 | emitMatchCheck(opName, |
| 1001 | matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self), vals&: names[0], |
| 1002 | vals&: names[1], vals&: names[2], vals&: names[3]), |
| 1003 | failureFmt: formatv(Fmt: "\"entities '{0}' failed to satisfy constraint: " |
| 1004 | "'{1}'\"" , |
| 1005 | Vals: llvm::join(R&: entities, Separator: ", " ), |
| 1006 | Vals: escapeString(value: constraint.getSummary()))); |
| 1007 | } |
| 1008 | } |
| 1009 | |
| 1010 | // Some of the operands could be bound to the same symbol name, we need |
| 1011 | // to enforce equality constraint on those. |
| 1012 | // TODO: we should be able to emit equality checks early |
| 1013 | // and short circuit unnecessary work if vars are not equal. |
| 1014 | for (auto symbolInfoIt = symbolInfoMap.begin(); |
| 1015 | symbolInfoIt != symbolInfoMap.end();) { |
| 1016 | auto range = symbolInfoMap.getRangeOfEqualElements(key: symbolInfoIt->first); |
| 1017 | auto startRange = range.first; |
| 1018 | auto endRange = range.second; |
| 1019 | |
| 1020 | auto firstOperand = symbolInfoIt->second.getVarName(name: symbolInfoIt->first); |
| 1021 | for (++startRange; startRange != endRange; ++startRange) { |
| 1022 | auto secondOperand = startRange->second.getVarName(name: symbolInfoIt->first); |
| 1023 | emitMatchCheck( |
| 1024 | opName, |
| 1025 | matchStr: formatv(Fmt: "*{0}.begin() == *{1}.begin()" , Vals&: firstOperand, Vals&: secondOperand), |
| 1026 | failureStr: formatv(Fmt: "\"Operands '{0}' and '{1}' must be equal\"" , Vals&: firstOperand, |
| 1027 | Vals&: secondOperand)); |
| 1028 | } |
| 1029 | |
| 1030 | symbolInfoIt = endRange; |
| 1031 | } |
| 1032 | |
| 1033 | LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n" ); |
| 1034 | } |
| 1035 | |
| 1036 | void PatternEmitter::collectOps(DagNode tree, |
| 1037 | llvm::SmallPtrSetImpl<const Operator *> &ops) { |
| 1038 | // Check if this tree is an operation. |
| 1039 | if (tree.isOperation()) { |
| 1040 | const Operator &op = tree.getDialectOp(mapper: opMap); |
| 1041 | LLVM_DEBUG(llvm::dbgs() |
| 1042 | << "found operation " << op.getOperationName() << '\n'); |
| 1043 | ops.insert(Ptr: &op); |
| 1044 | } |
| 1045 | |
| 1046 | // Recurse the arguments of the tree. |
| 1047 | for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) |
| 1048 | if (auto child = tree.getArgAsNestedDag(index: i)) |
| 1049 | collectOps(tree: child, ops); |
| 1050 | } |
| 1051 | |
/// Emits the C++ definition of one rewrite pattern to `os`.
///
/// The output is a struct named `rewriteName` that derives from
/// ::mlir::RewritePattern. It consists of a constructor and a
/// matchAndRewrite() override whose body holds, in order: declarations for
/// captured symbols, the match logic, the rewrite logic, and a final
/// `return ::mlir::success();`.
| 1052 | void PatternEmitter::emit(StringRef rewriteName) {
| 1053 | // Get the DAG tree for the source pattern.
| 1054 | DagNode sourceTree = pattern.getSourcePattern();
| 1055 |
| 1056 | const Operator &rootOp = pattern.getSourceRootOp();
| 1057 | auto rootName = rootOp.getOperationName();
| 1058 |
| 1059 | // Collect the set of result operations.
| 1060 | llvm::SmallPtrSet<const Operator *, 4> resultOps;
| 1061 | LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n" );
| 1062 | for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
| 1063 | collectOps(tree: pattern.getResultPattern(index: i), ops&: resultOps);
| 1064 | }
| 1065 | LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n" );
| 1066 |
| 1067 | // Emit RewritePattern for Pattern.
// A leading comment in the generated code lists the pattern's source
// locations, iterated in reverse order.
| 1068 | auto locs = pattern.getLocation();
| 1069 | os << formatv(Fmt: "/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n" ,
| 1070 | Vals: llvm::reverse(C&: locs));
// The constructor passes the root op name, the pattern benefit, the context,
// and (below) a braced list of result op names to ::mlir::RewritePattern.
| 1071 | os << formatv(Fmt: R"(struct {0} : public ::mlir::RewritePattern {
| 1072 | {0}(::mlir::MLIRContext *context)
| 1073 | : ::mlir::RewritePattern("{1}", {2}, context, {{)" ,
| 1074 | Vals&: rewriteName, Vals&: rootName, Vals: pattern.getBenefit());
| 1075 | // Sort result operators by name.
// The ops come out of a pointer set; sorting by operation name fixes the
// order in which the names are printed.
| 1076 | llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
| 1077 | resultOps.end());
| 1078 | llvm::sort(C&: sortedResultOps, Comp: [&](const Operator *lhs, const Operator *rhs) {
| 1079 | return lhs->getOperationName() < rhs->getOperationName();
| 1080 | });
| 1081 | llvm::interleaveComma(c: sortedResultOps, os, each_fn: [&](const Operator *op) {
| 1082 | os << '"' << op->getOperationName() << '"';
| 1083 | });
| 1084 | os << "}) {}\n" ;
| 1085 |
| 1086 | // Emit matchAndRewrite() function.
| 1087 | {
// NOTE(review): `classScope` and `functionScope` are scope objects returned
// by os.scope(); presumably they control the indentation of the emitted
// class and function bodies for the lifetime of these C++ blocks - confirm
// against the raw_indented_ostream implementation.
| 1088 | auto classScope = os.scope();
| 1089 | os.printReindented(str: R"(
| 1090 | ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0,
| 1091 | ::mlir::PatternRewriter &rewriter) const override {)" )
| 1092 | << '\n';
| 1093 | {
| 1094 | auto functionScope = os.scope();
| 1095 |
| 1096 | // Register all symbols bound in the source pattern.
| 1097 | pattern.collectSourcePatternBoundSymbols(infoMap&: symbolInfoMap);
| 1098 |
| 1099 | LLVM_DEBUG(llvm::dbgs()
| 1100 | << "start creating local variables for capturing matches\n" );
| 1101 | os << "// Variables for capturing values and attributes used while "
| 1102 | "creating ops\n" ;
| 1103 | // Create local variables for storing the arguments and results bound
| 1104 | // to symbols.
| 1105 | for (const auto &symbolInfoPair : symbolInfoMap) {
| 1106 | const auto &symbol = symbolInfoPair.first;
| 1107 | const auto &info = symbolInfoPair.second;
| 1108 |
| 1109 | os << info.getVarDecl(name: symbol);
| 1110 | }
| 1111 | // TODO: capture ops with consistent numbering so that it can be
| 1112 | // reused for fused loc.
| 1113 | os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n" ;
| 1114 | LLVM_DEBUG(llvm::dbgs()
| 1115 | << "done creating local variables for capturing matches\n" );
| 1116 |
// The root op `op0` is pushed into `tblgen_ops` before any match code is
// emitted, so it is always the first element of that vector.
| 1117 | os << "// Match\n" ;
| 1118 | os << "tblgen_ops.push_back(op0);\n" ;
| 1119 | emitMatchLogic(tree: sourceTree, opName: "op0" );
| 1120 |
| 1121 | os << "\n// Rewrite\n" ;
| 1122 | emitRewriteLogic();
| 1123 |
| 1124 | os << "return ::mlir::success();\n" ;
| 1125 | }
| 1126 | os << "}\n" ;
| 1127 | }
| 1128 | os << "};\n\n" ;
| 1129 | }
| 1130 | |
/// Emits the rewrite section of matchAndRewrite().
///
/// The result patterns are divided into auxiliary patterns (emitted first)
/// and replacement patterns, whose values are collected in
/// `tblgen_repl_values` and passed to `rewriter.replaceOp`. When the root op
/// has no results, the op is erased instead. Supplemental patterns are
/// emitted right before the erase/replace statement.
| 1131 | void PatternEmitter::emitRewriteLogic() {
| 1132 | LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n" );
| 1133 | const Operator &rootOp = pattern.getSourceRootOp();
| 1134 | int numExpectedResults = rootOp.getNumResults();
| 1135 | int numResultPatterns = pattern.getNumResultPatterns();
| 1136 |
| 1137 | // First register all symbols bound to ops generated in result patterns.
| 1138 | pattern.collectResultPatternBoundSymbols(infoMap&: symbolInfoMap);
| 1139 |
| 1140 | // Only the last N static values generated are used to replace the matched
| 1141 | // root N-result op. We need to calculate the starting index (of the results
| 1142 | // of the matched op) each result pattern is to replace.
// `offsets` has one extra trailing entry, and every entry starts out as
// `numExpectedResults`; the loop below fills entries from the back, each one
// being the next entry minus that pattern's value count.
| 1143 | SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
| 1144 | // If we don't need to replace any value at all, set the replacement starting
| 1145 | // index as the number of result patterns so we skip all of them when trying
| 1146 | // to replace the matched op's results.
// -1 means "not determined yet"; it is set at most once, to the highest
// pattern index whose offset is exactly 0.
| 1147 | int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
| 1148 | for (int i = numResultPatterns - 1; i >= 0; --i) {
| 1149 | auto numValues = getNodeValueCount(node: pattern.getResultPattern(index: i));
| 1150 | offsets[i] = offsets[i + 1] - numValues;
| 1151 | if (offsets[i] == 0) {
| 1152 | if (replStartIndex == -1)
| 1153 | replStartIndex = i;
| 1154 | } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
// The offsets went from positive to negative without hitting 0: this
// pattern's values would straddle the auxiliary/replacement boundary.
| 1155 | auto error = formatv(
| 1156 | Fmt: "cannot use the same multi-result op '{0}' to generate both "
| 1157 | "auxiliary values and values to be used for replacing the matched op" ,
| 1158 | Vals: pattern.getResultPattern(index: i).getSymbol());
| 1159 | PrintFatalError(ErrorLoc: loc, Msg: error);
| 1160 | }
| 1161 | }
| 1162 |
// A positive first offset means the result patterns together yield fewer
// values than the root op has results.
| 1163 | if (offsets.front() > 0) {
| 1164 | const char error[] =
| 1165 | "not enough values generated to replace the matched op" ;
| 1166 | PrintFatalError(ErrorLoc: loc, Msg: error);
| 1167 | }
| 1168 |
// `odsLoc` fuses the locations of the first getNumOps() entries of
// `tblgen_ops`; the trailing `(void)odsLoc;` is emitted alongside it.
| 1169 | os << "auto odsLoc = rewriter.getFusedLoc({" ;
| 1170 | for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
| 1171 | os << (i ? ", " : "" ) << "tblgen_ops[" << i << "]->getLoc()" ;
| 1172 | }
| 1173 | os << "}); (void)odsLoc;\n" ;
| 1174 |
| 1175 | // Process auxiliary result patterns.
| 1176 | for (int i = 0; i < replStartIndex; ++i) {
| 1177 | DagNode resultTree = pattern.getResultPattern(index: i);
| 1178 | auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0);
| 1179 | // Normal op creation will be streamed to `os` by the above call; but
| 1180 | // NativeCodeCall will only be materialized to `os` if it is used. Here
| 1181 | // we are handling auxiliary patterns so we want the side effect even if
| 1182 | // NativeCodeCall is not replacing matched root op's results.
| 1183 | if (resultTree.isNativeCodeCall() &&
| 1184 | resultTree.getNumReturnsOfNativeCode() == 0)
| 1185 | os << val << ";\n" ;
| 1186 | }
| 1187 |
// Emits every supplemental pattern. The result index passed for pattern i
// is `-numSupplementalPatterns + i`, i.e. always negative. As above, a
// NativeCodeCall with zero returns is written out as a statement.
| 1188 | auto processSupplementalPatterns = [&]() {
| 1189 | int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
| 1190 | for (int i = 0, offset = -numSupplementalPatterns;
| 1191 | i < numSupplementalPatterns; ++i) {
| 1192 | DagNode resultTree = pattern.getSupplementalPattern(index: i);
| 1193 | auto val = handleResultPattern(resultTree, resultIndex: offset++, depth: 0);
| 1194 | if (resultTree.isNativeCodeCall() &&
| 1195 | resultTree.getNumReturnsOfNativeCode() == 0)
| 1196 | os << val << ";\n" ;
| 1197 | }
| 1198 | };
| 1199 |
| 1200 | if (numExpectedResults == 0) {
| 1201 | assert(replStartIndex >= numResultPatterns &&
| 1202 | "invalid auxiliary vs. replacement pattern division!" );
| 1203 | processSupplementalPatterns();
| 1204 | // No result to replace. Just erase the op.
| 1205 | os << "rewriter.eraseOp(op0);\n" ;
| 1206 | } else {
| 1207 | // Process replacement result patterns.
| 1208 | os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n" ;
| 1209 | for (int i = replStartIndex; i < numResultPatterns; ++i) {
| 1210 | DagNode resultTree = pattern.getResultPattern(index: i);
| 1211 | auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0);
| 1212 | os << "\n" ;
| 1213 | // Resolve each symbol for all range use so that we can loop over them.
| 1214 | // We need an explicit cast to `SmallVector` to capture the cases where
| 1215 | // `{0}` resolves to an `Operation::result_range` as well as cases that
| 1216 | // are not iterable (e.g. vector that gets wrapped in additional braces by
| 1217 | // RewriterGen).
| 1218 | // TODO: Revisit the need for materializing a vector.
| 1219 | os << symbolInfoMap.getAllRangeUse(
| 1220 | symbol: val,
| 1221 | fmt: "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
| 1222 | " tblgen_repl_values.push_back(v);\n}\n" ,
| 1223 | separator: "\n" );
| 1224 | }
| 1225 | processSupplementalPatterns();
| 1226 | os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n" ;
| 1227 | }
| 1228 |
| 1229 | LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n" );
| 1230 | }
| 1231 | |
| 1232 | std::string PatternEmitter::getUniqueSymbol(const Operator *op) { |
| 1233 | return std::string( |
| 1234 | formatv(Fmt: "tblgen_{0}_{1}" , Vals: op->getCppClassName(), Vals: nextValueId++)); |
| 1235 | } |
| 1236 | |
| 1237 | std::string PatternEmitter::handleResultPattern(DagNode resultTree, |
| 1238 | int resultIndex, int depth) { |
| 1239 | LLVM_DEBUG(llvm::dbgs() << "handle result pattern: " ); |
| 1240 | LLVM_DEBUG(resultTree.print(llvm::dbgs())); |
| 1241 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
| 1242 | |
| 1243 | if (resultTree.isLocationDirective()) { |
| 1244 | PrintFatalError(ErrorLoc: loc, |
| 1245 | Msg: "location directive can only be used with op creation" ); |
| 1246 | } |
| 1247 | |
| 1248 | if (resultTree.isNativeCodeCall()) |
| 1249 | return handleReplaceWithNativeCodeCall(resultTree, depth); |
| 1250 | |
| 1251 | if (resultTree.isReplaceWithValue()) |
| 1252 | return handleReplaceWithValue(tree: resultTree).str(); |
| 1253 | |
| 1254 | if (resultTree.isVariadic()) |
| 1255 | return handleVariadic(tree: resultTree, depth); |
| 1256 | |
| 1257 | // Normal op creation. |
| 1258 | auto symbol = handleOpCreation(tree: resultTree, resultIndex, depth); |
| 1259 | if (resultTree.getSymbol().empty()) { |
| 1260 | // This is an op not explicitly bound to a symbol in the rewrite rule. |
| 1261 | // Register the auto-generated symbol for it. |
| 1262 | symbolInfoMap.bindOpResult(symbol, op: pattern.getDialectOp(node: resultTree)); |
| 1263 | } |
| 1264 | return symbol; |
| 1265 | } |
| 1266 | |
| 1267 | std::string PatternEmitter::handleVariadic(DagNode tree, int depth) { |
| 1268 | assert(tree.isVariadic()); |
| 1269 | |
| 1270 | std::string output; |
| 1271 | llvm::raw_string_ostream oss(output); |
| 1272 | auto name = std::string(formatv(Fmt: "tblgen_variadic_values_{0}" , Vals: nextValueId++)); |
| 1273 | symbolInfoMap.bindValue(symbol: name); |
| 1274 | oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n" ; |
| 1275 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
| 1276 | if (auto child = tree.getArgAsNestedDag(index: i)) { |
| 1277 | oss << name << ".push_back(" << handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1) |
| 1278 | << ");\n" ; |
| 1279 | } else { |
| 1280 | oss << name << ".push_back(" |
| 1281 | << handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i)) |
| 1282 | << ");\n" ; |
| 1283 | } |
| 1284 | } |
| 1285 | |
| 1286 | os << oss.str(); |
| 1287 | return name; |
| 1288 | } |
| 1289 | |
| 1290 | StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { |
| 1291 | assert(tree.isReplaceWithValue()); |
| 1292 | |
| 1293 | if (tree.getNumArgs() != 1) { |
| 1294 | PrintFatalError( |
| 1295 | ErrorLoc: loc, Msg: "replaceWithValue directive must take exactly one argument" ); |
| 1296 | } |
| 1297 | |
| 1298 | if (!tree.getSymbol().empty()) { |
| 1299 | PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to replaceWithValue" ); |
| 1300 | } |
| 1301 | |
| 1302 | return tree.getArgName(index: 0); |
| 1303 | } |
| 1304 | |
| 1305 | std::string PatternEmitter::handleLocationDirective(DagNode tree) { |
| 1306 | assert(tree.isLocationDirective()); |
| 1307 | auto lookUpArgLoc = [this, &tree](int idx) { |
| 1308 | const auto *const lookupFmt = "{0}.getLoc()" ; |
| 1309 | return symbolInfoMap.getValueAndRangeUse(symbol: tree.getArgName(index: idx), fmt: lookupFmt); |
| 1310 | }; |
| 1311 | |
| 1312 | if (tree.getNumArgs() == 0) |
| 1313 | llvm::PrintFatalError( |
| 1314 | Msg: "At least one argument to location directive required" ); |
| 1315 | |
| 1316 | if (!tree.getSymbol().empty()) |
| 1317 | PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to location" ); |
| 1318 | |
| 1319 | if (tree.getNumArgs() == 1) { |
| 1320 | DagLeaf leaf = tree.getArgAsLeaf(index: 0); |
| 1321 | if (leaf.isStringAttr()) |
| 1322 | return formatv(Fmt: "::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))" , |
| 1323 | Vals: leaf.getStringAttr()) |
| 1324 | .str(); |
| 1325 | return lookUpArgLoc(0); |
| 1326 | } |
| 1327 | |
| 1328 | std::string ret; |
| 1329 | llvm::raw_string_ostream os(ret); |
| 1330 | std::string strAttr; |
| 1331 | os << "rewriter.getFusedLoc({" ; |
| 1332 | bool first = true; |
| 1333 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
| 1334 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
| 1335 | // Handle the optional string value. |
| 1336 | if (leaf.isStringAttr()) { |
| 1337 | if (!strAttr.empty()) |
| 1338 | llvm::PrintFatalError(Msg: "Only one string attribute may be specified" ); |
| 1339 | strAttr = leaf.getStringAttr(); |
| 1340 | continue; |
| 1341 | } |
| 1342 | os << (first ? "" : ", " ) << lookUpArgLoc(i); |
| 1343 | first = false; |
| 1344 | } |
| 1345 | os << "}" ; |
| 1346 | if (!strAttr.empty()) { |
| 1347 | os << ", rewriter.getStringAttr(\"" << strAttr << "\")" ; |
| 1348 | } |
| 1349 | os << ")" ; |
| 1350 | return os.str(); |
| 1351 | } |
| 1352 | |
| 1353 | std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, |
| 1354 | int depth) { |
| 1355 | // Nested NativeCodeCall. |
| 1356 | if (auto dagNode = returnType.getArgAsNestedDag(index: i)) { |
| 1357 | if (!dagNode.isNativeCodeCall()) |
| 1358 | PrintFatalError(ErrorLoc: loc, Msg: "nested DAG in `returnType` must be a native code " |
| 1359 | "call" ); |
| 1360 | return handleReplaceWithNativeCodeCall(resultTree: dagNode, depth); |
| 1361 | } |
| 1362 | // String literal. |
| 1363 | auto dagLeaf = returnType.getArgAsLeaf(index: i); |
| 1364 | if (dagLeaf.isStringAttr()) |
| 1365 | return tgfmt(fmt: dagLeaf.getStringAttr(), ctx: &fmtCtx); |
| 1366 | return tgfmt( |
| 1367 | fmt: "$0.getType()" , ctx: &fmtCtx, |
| 1368 | vals: handleOpArgument(leaf: returnType.getArgAsLeaf(index: i), patArgName: returnType.getArgName(index: i))); |
| 1369 | } |
| 1370 | |
| 1371 | std::string PatternEmitter::handleOpArgument(DagLeaf leaf, |
| 1372 | StringRef patArgName) { |
| 1373 | if (leaf.isStringAttr()) |
| 1374 | PrintFatalError(ErrorLoc: loc, Msg: "raw string not supported as argument" ); |
| 1375 | if (leaf.isConstantAttr()) { |
| 1376 | auto constAttr = leaf.getAsConstantAttr(); |
| 1377 | return handleConstantAttr(attr: constAttr.getAttribute(), |
| 1378 | value: constAttr.getConstantValue()); |
| 1379 | } |
| 1380 | if (leaf.isEnumCase()) { |
| 1381 | auto enumCase = leaf.getAsEnumCase(); |
| 1382 | // This is an enum case backed by an IntegerAttr. We need to get its value |
| 1383 | // to build the constant. |
| 1384 | std::string val = std::to_string(val: enumCase.getValue()); |
| 1385 | return handleConstantAttr(attr: Attribute(&enumCase.getDef()), value: val); |
| 1386 | } |
| 1387 | |
| 1388 | LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n" ); |
| 1389 | auto argName = symbolInfoMap.getValueAndRangeUse(symbol: patArgName); |
| 1390 | if (leaf.isUnspecified() || leaf.isOperandMatcher()) { |
| 1391 | LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName |
| 1392 | << "' (via symbol ref)\n" ); |
| 1393 | return argName; |
| 1394 | } |
| 1395 | if (leaf.isNativeCodeCall()) { |
| 1396 | auto repl = tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: argName)); |
| 1397 | LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl |
| 1398 | << "' (via NativeCodeCall)\n" ); |
| 1399 | return std::string(repl); |
| 1400 | } |
| 1401 | PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when rewriting op" ); |
| 1402 | } |
| 1403 | |
| 1404 | std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, |
| 1405 | int depth) { |
| 1406 | LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: " ); |
| 1407 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
| 1408 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
| 1409 | |
| 1410 | auto fmt = tree.getNativeCodeTemplate(); |
| 1411 | |
| 1412 | SmallVector<std::string, 16> attrs; |
| 1413 | |
| 1414 | auto tail = getTrailingDirectives(tree); |
| 1415 | if (tail.returnType) |
| 1416 | PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier" ); |
| 1417 | auto locToUse = getLocation(tail); |
| 1418 | |
| 1419 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
| 1420 | if (tree.isNestedDagArg(index: i)) { |
| 1421 | attrs.push_back( |
| 1422 | Elt: handleResultPattern(resultTree: tree.getArgAsNestedDag(index: i), resultIndex: i, depth: depth + 1)); |
| 1423 | } else { |
| 1424 | attrs.push_back( |
| 1425 | Elt: handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i))); |
| 1426 | } |
| 1427 | LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i |
| 1428 | << " replacement: " << attrs[i] << "\n" ); |
| 1429 | } |
| 1430 | |
| 1431 | std::string symbol = tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc" , subst: locToUse), |
| 1432 | params: static_cast<ArrayRef<std::string>>(attrs)); |
| 1433 | |
| 1434 | // In general, NativeCodeCall without naming binding don't need this. To |
| 1435 | // ensure void helper function has been correctly labeled, i.e., use |
| 1436 | // NativeCodeCallVoid, we cache the result to a local variable so that we will |
| 1437 | // get a compilation error in the auto-generated file. |
| 1438 | // Example. |
| 1439 | // // In the td file |
| 1440 | // Pat<(...), (NativeCodeCall<Foo> ...)> |
| 1441 | // |
| 1442 | // --- |
| 1443 | // |
| 1444 | // // In the auto-generated .cpp |
| 1445 | // ... |
| 1446 | // // Causes compilation error if Foo() returns void. |
| 1447 | // auto nativeVar = Foo(); |
| 1448 | // ... |
| 1449 | if (tree.getNumReturnsOfNativeCode() != 0) { |
| 1450 | // Determine the local variable name for return value. |
| 1451 | std::string varName = |
| 1452 | SymbolInfoMap::getValuePackName(symbol: tree.getSymbol()).str(); |
| 1453 | if (varName.empty()) { |
| 1454 | varName = formatv(Fmt: "nativeVar_{0}" , Vals: nextValueId++); |
| 1455 | // Register the local variable for later uses. |
| 1456 | symbolInfoMap.bindValues(symbol: varName, numValues: tree.getNumReturnsOfNativeCode()); |
| 1457 | } |
| 1458 | |
| 1459 | // Catch the return value of helper function. |
| 1460 | os << formatv(Fmt: "auto {0} = {1}; (void){0};\n" , Vals&: varName, Vals&: symbol); |
| 1461 | |
| 1462 | if (!tree.getSymbol().empty()) |
| 1463 | symbol = tree.getSymbol().str(); |
| 1464 | else |
| 1465 | symbol = varName; |
| 1466 | } |
| 1467 | |
| 1468 | return symbol; |
| 1469 | } |
| 1470 | |
| 1471 | int PatternEmitter::getNodeValueCount(DagNode node) { |
| 1472 | if (node.isOperation()) { |
| 1473 | // If the op is bound to a symbol in the rewrite rule, query its result |
| 1474 | // count from the symbol info map. |
| 1475 | auto symbol = node.getSymbol(); |
| 1476 | if (!symbol.empty()) { |
| 1477 | return symbolInfoMap.getStaticValueCount(symbol); |
| 1478 | } |
| 1479 | // Otherwise this is an unbound op; we will use all its results. |
| 1480 | return pattern.getDialectOp(node).getNumResults(); |
| 1481 | } |
| 1482 | |
| 1483 | if (node.isNativeCodeCall()) |
| 1484 | return node.getNumReturnsOfNativeCode(); |
| 1485 | |
| 1486 | return 1; |
| 1487 | } |
| 1488 | |
| 1489 | PatternEmitter::TrailingDirectives |
| 1490 | PatternEmitter::getTrailingDirectives(DagNode tree) { |
| 1491 | TrailingDirectives tail = {.location: DagNode(nullptr), .returnType: DagNode(nullptr), .numDirectives: 0}; |
| 1492 | |
| 1493 | // Look backwards through the arguments. |
| 1494 | auto numPatArgs = tree.getNumArgs(); |
| 1495 | for (int i = numPatArgs - 1; i >= 0; --i) { |
| 1496 | auto dagArg = tree.getArgAsNestedDag(index: i); |
| 1497 | // A leaf is not a directive. Stop looking. |
| 1498 | if (!dagArg) |
| 1499 | break; |
| 1500 | |
| 1501 | auto isLocation = dagArg.isLocationDirective(); |
| 1502 | auto isReturnType = dagArg.isReturnTypeDirective(); |
| 1503 | // If encountered a DAG node that isn't a trailing directive, stop looking. |
| 1504 | if (!(isLocation || isReturnType)) |
| 1505 | break; |
| 1506 | // Save the directive, but error if one of the same type was already |
| 1507 | // found. |
| 1508 | ++tail.numDirectives; |
| 1509 | if (isLocation) { |
| 1510 | if (tail.location) |
| 1511 | PrintFatalError(ErrorLoc: loc, Msg: "`location` directive can only be specified " |
| 1512 | "once" ); |
| 1513 | tail.location = dagArg; |
| 1514 | } else if (isReturnType) { |
| 1515 | if (tail.returnType) |
| 1516 | PrintFatalError(ErrorLoc: loc, Msg: "`returnType` directive can only be specified " |
| 1517 | "once" ); |
| 1518 | tail.returnType = dagArg; |
| 1519 | } |
| 1520 | } |
| 1521 | |
| 1522 | return tail; |
| 1523 | } |
| 1524 | |
| 1525 | std::string |
| 1526 | PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { |
| 1527 | if (tail.location) |
| 1528 | return handleLocationDirective(tree: tail.location); |
| 1529 | |
| 1530 | // If no explicit location is given, use the default, all fused, location. |
| 1531 | return "odsLoc" ; |
| 1532 | } |
| 1533 | |
| 1534 | std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, |
| 1535 | int depth) { |
| 1536 | LLVM_DEBUG(llvm::dbgs() << "create op for pattern: " ); |
| 1537 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
| 1538 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
| 1539 | |
| 1540 | Operator &resultOp = tree.getDialectOp(mapper: opMap); |
| 1541 | bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); |
| 1542 | auto numOpArgs = resultOp.getNumArgs(); |
| 1543 | auto numPatArgs = tree.getNumArgs(); |
| 1544 | |
| 1545 | auto tail = getTrailingDirectives(tree); |
| 1546 | auto locToUse = getLocation(tail); |
| 1547 | |
| 1548 | auto inPattern = numPatArgs - tail.numDirectives; |
| 1549 | if (numOpArgs != inPattern) { |
| 1550 | PrintFatalError(ErrorLoc: loc, |
| 1551 | Msg: formatv(Fmt: "resultant op '{0}' argument number mismatch: " |
| 1552 | "{1} in pattern vs. {2} in definition" , |
| 1553 | Vals: resultOp.getOperationName(), Vals&: inPattern, Vals&: numOpArgs)); |
| 1554 | } |
| 1555 | |
| 1556 | // A map to collect all nested DAG child nodes' names, with operand index as |
| 1557 | // the key. This includes both bound and unbound child nodes. |
| 1558 | ChildNodeIndexNameMap childNodeNames; |
| 1559 | |
| 1560 | // If the argument is a type constraint, then its an operand. Check if the |
| 1561 | // op's argument is variadic that the argument in the pattern is too. |
| 1562 | auto checkIfMatchedVariadic = [&](int i) { |
| 1563 | // FIXME: This does not yet check for variable/leaf case. |
| 1564 | // FIXME: Change so that native code call can be handled. |
| 1565 | const auto *operand = |
| 1566 | llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: i)); |
| 1567 | if (!operand || !operand->isVariadic()) |
| 1568 | return; |
| 1569 | |
| 1570 | auto child = tree.getArgAsNestedDag(index: i); |
| 1571 | if (!child) |
| 1572 | return; |
| 1573 | |
| 1574 | // Skip over replaceWithValues. |
| 1575 | while (child.isReplaceWithValue()) { |
| 1576 | if (!(child = child.getArgAsNestedDag(index: 0))) |
| 1577 | return; |
| 1578 | } |
| 1579 | if (!child.isNativeCodeCall() && !child.isVariadic()) |
| 1580 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "op expects variadic operand `{0}`, while " |
| 1581 | "provided is non-variadic" , |
| 1582 | Vals: resultOp.getArgName(index: i))); |
| 1583 | }; |
| 1584 | |
| 1585 | // First go through all the child nodes who are nested DAG constructs to |
| 1586 | // create ops for them and remember the symbol names for them, so that we can |
| 1587 | // use the results in the current node. This happens in a recursive manner. |
| 1588 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
| 1589 | checkIfMatchedVariadic(i); |
| 1590 | if (auto child = tree.getArgAsNestedDag(index: i)) |
| 1591 | childNodeNames[i] = handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1); |
| 1592 | } |
| 1593 | |
| 1594 | // The name of the local variable holding this op. |
| 1595 | std::string valuePackName; |
| 1596 | // The symbol for holding the result of this pattern. Note that the result of |
| 1597 | // this pattern is not necessarily the same as the variable created by this |
| 1598 | // pattern because we can use `__N` suffix to refer only a specific result if |
| 1599 | // the generated op is a multi-result op. |
| 1600 | std::string resultValue; |
| 1601 | if (tree.getSymbol().empty()) { |
| 1602 | // No symbol is explicitly bound to this op in the pattern. Generate a |
| 1603 | // unique name. |
| 1604 | valuePackName = resultValue = getUniqueSymbol(op: &resultOp); |
| 1605 | } else { |
| 1606 | resultValue = std::string(tree.getSymbol()); |
| 1607 | // Strip the index to get the name for the value pack and use it to name the |
| 1608 | // local variable for the op. |
| 1609 | valuePackName = std::string(SymbolInfoMap::getValuePackName(symbol: resultValue)); |
| 1610 | } |
| 1611 | |
| 1612 | // Create the local variable for this op. |
| 1613 | os << formatv(Fmt: "{0} {1};\n{{\n" , Vals: resultOp.getQualCppClassName(), |
| 1614 | Vals&: valuePackName); |
| 1615 | |
| 1616 | // Right now ODS don't have general type inference support. Except a few |
| 1617 | // special cases listed below, DRR needs to supply types for all results |
| 1618 | // when building an op. |
| 1619 | bool isSameOperandsAndResultType = |
| 1620 | resultOp.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType" ); |
| 1621 | bool useFirstAttr = |
| 1622 | resultOp.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType" ); |
| 1623 | |
| 1624 | if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) { |
| 1625 | // We know how to deduce the result type for ops with these traits and we've |
| 1626 | // generated builders taking aggregate parameters. Use those builders to |
| 1627 | // create the ops. |
| 1628 | |
| 1629 | // First prepare local variables for op arguments used in builder call. |
| 1630 | createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth); |
| 1631 | |
| 1632 | // Then create the op. |
| 1633 | os.scope(open: "" , close: "\n}\n" ).os |
| 1634 | << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_values, {3});" , |
| 1635 | Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse, |
| 1636 | Vals: useProperties ? "tblgen_props" : "tblgen_attrs" ); |
| 1637 | return resultValue; |
| 1638 | } |
| 1639 | |
| 1640 | bool usePartialResults = valuePackName != resultValue; |
| 1641 | |
| 1642 | if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) { |
| 1643 | // For these cases (broadcastable ops, op results used both as auxiliary |
| 1644 | // values and replacement values, ops in nested patterns, auxiliary ops), we |
| 1645 | // still need to supply the result types when building the op. But because |
| 1646 | // we don't generate a builder automatically with ODS for them, it's the |
| 1647 | // developer's responsibility to make sure such a builder (with result type |
| 1648 | // deduction ability) exists. We go through the separate-parameter builder |
| 1649 | // here given that it's easier for developers to write compared to |
| 1650 | // aggregate-parameter builders. |
| 1651 | createSeparateLocalVarsForOpArgs(node: tree, childNodeNames); |
| 1652 | |
| 1653 | os.scope().os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}" , Vals&: valuePackName, |
| 1654 | Vals: resultOp.getQualCppClassName(), Vals&: locToUse); |
| 1655 | supplyValuesForOpArgs(node: tree, childNodeNames, depth); |
| 1656 | os << "\n );\n}\n" ; |
| 1657 | return resultValue; |
| 1658 | } |
| 1659 | |
| 1660 | // If we are provided explicit return types, use them to build the op. |
| 1661 | // However, if depth == 0 and resultIndex >= 0, it means we are replacing |
| 1662 | // the values generated from the source pattern root op. Then we must use the |
| 1663 | // source pattern's value types to determine the value type of the generated |
| 1664 | // op here. |
| 1665 | if (depth == 0 && resultIndex >= 0 && tail.returnType) |
| 1666 | PrintFatalError(ErrorLoc: loc, Msg: "Cannot specify explicit return types in an op whose " |
| 1667 | "return values replace the source pattern's root op" ); |
| 1668 | |
| 1669 | // First prepare local variables for op arguments used in builder call. |
| 1670 | createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth); |
| 1671 | |
| 1672 | // Then prepare the result types. We need to specify the types for all |
| 1673 | // results. |
| 1674 | os.indent() << formatv(Fmt: "::llvm::SmallVector<::mlir::Type, 4> tblgen_types; " |
| 1675 | "(void)tblgen_types;\n" ); |
| 1676 | int numResults = resultOp.getNumResults(); |
| 1677 | if (tail.returnType) { |
| 1678 | auto numRetTys = tail.returnType.getNumArgs(); |
| 1679 | for (int i = 0; i < numRetTys; ++i) { |
| 1680 | auto varName = handleReturnTypeArg(returnType: tail.returnType, i, depth: depth + 1); |
| 1681 | os << "tblgen_types.push_back(" << varName << ");\n" ; |
| 1682 | } |
| 1683 | } else { |
| 1684 | if (numResults != 0) { |
| 1685 | // Copy the result types from the source pattern. |
| 1686 | for (int i = 0; i < numResults; ++i) |
| 1687 | os << formatv(Fmt: "for (auto v: castedOp0.getODSResults({0})) {{\n" |
| 1688 | " tblgen_types.push_back(v.getType());\n}\n" , |
| 1689 | Vals: resultIndex + i); |
| 1690 | } |
| 1691 | } |
| 1692 | os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_types, " |
| 1693 | "tblgen_values, {3});\n" , |
| 1694 | Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse, |
| 1695 | Vals: useProperties ? "tblgen_props" : "tblgen_attrs" ); |
| 1696 | os.unindent() << "}\n" ; |
| 1697 | return resultValue; |
| 1698 | } |
| 1699 | |
| 1700 | void PatternEmitter::createSeparateLocalVarsForOpArgs( |
| 1701 | DagNode node, ChildNodeIndexNameMap &childNodeNames) { |
| 1702 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
| 1703 | |
| 1704 | // Now prepare operands used for building this op: |
| 1705 | // * If the operand is non-variadic, we create a `Value` local variable. |
| 1706 | // * If the operand is variadic, we create a `SmallVector<Value>` local |
| 1707 | // variable. |
| 1708 | |
| 1709 | int valueIndex = 0; // An index for uniquing local variable names. |
| 1710 | for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { |
| 1711 | const auto *operand = |
| 1712 | llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex)); |
| 1713 | // We do not need special handling for attributes. |
| 1714 | if (!operand) |
| 1715 | continue; |
| 1716 | |
| 1717 | raw_indented_ostream::DelimitedScope scope(os); |
| 1718 | std::string varName; |
| 1719 | if (operand->isVariadic()) { |
| 1720 | varName = std::string(formatv(Fmt: "tblgen_values_{0}" , Vals: valueIndex++)); |
| 1721 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> {0};\n" , Vals&: varName); |
| 1722 | std::string range; |
| 1723 | if (node.isNestedDagArg(index: argIndex)) { |
| 1724 | range = childNodeNames[argIndex]; |
| 1725 | } else { |
| 1726 | range = std::string(node.getArgName(index: argIndex)); |
| 1727 | } |
| 1728 | // Resolve the symbol for all range use so that we have a uniform way of |
| 1729 | // capturing the values. |
| 1730 | range = symbolInfoMap.getValueAndRangeUse(symbol: range); |
| 1731 | os << formatv(Fmt: "for (auto v: {0}) {{\n {1}.push_back(v);\n}\n" , Vals&: range, |
| 1732 | Vals&: varName); |
| 1733 | } else { |
| 1734 | varName = std::string(formatv(Fmt: "tblgen_value_{0}" , Vals: valueIndex++)); |
| 1735 | os << formatv(Fmt: "::mlir::Value {0} = " , Vals&: varName); |
| 1736 | if (node.isNestedDagArg(index: argIndex)) { |
| 1737 | os << symbolInfoMap.getValueAndRangeUse(symbol: childNodeNames[argIndex]); |
| 1738 | } else { |
| 1739 | DagLeaf leaf = node.getArgAsLeaf(index: argIndex); |
| 1740 | auto symbol = |
| 1741 | symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex)); |
| 1742 | if (leaf.isNativeCodeCall()) { |
| 1743 | os << std::string( |
| 1744 | tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol))); |
| 1745 | } else { |
| 1746 | os << symbol; |
| 1747 | } |
| 1748 | } |
| 1749 | os << ";\n" ; |
| 1750 | } |
| 1751 | |
| 1752 | // Update to use the newly created local variable for building the op later. |
| 1753 | childNodeNames[argIndex] = varName; |
| 1754 | } |
| 1755 | } |
| 1756 | |
| 1757 | void PatternEmitter::supplyValuesForOpArgs( |
| 1758 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { |
| 1759 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
| 1760 | for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); |
| 1761 | argIndex != numOpArgs; ++argIndex) { |
| 1762 | // Start each argument on its own line. |
| 1763 | os << ",\n " ; |
| 1764 | |
| 1765 | Argument opArg = resultOp.getArg(index: argIndex); |
| 1766 | // Handle the case of operand first. |
| 1767 | if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) { |
| 1768 | if (!operand->name.empty()) |
| 1769 | os << "/*" << operand->name << "=*/" ; |
| 1770 | os << childNodeNames.lookup(Val: argIndex); |
| 1771 | continue; |
| 1772 | } |
| 1773 | |
| 1774 | // The argument in the op definition. |
| 1775 | auto opArgName = resultOp.getArgName(index: argIndex); |
| 1776 | if (auto subTree = node.getArgAsNestedDag(index: argIndex)) { |
| 1777 | if (!subTree.isNativeCodeCall()) |
| 1778 | PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node " |
| 1779 | "for creating attribute" ); |
| 1780 | os << formatv(Fmt: "/*{0}=*/{1}" , Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex)); |
| 1781 | } else { |
| 1782 | auto leaf = node.getArgAsLeaf(index: argIndex); |
| 1783 | // The argument in the result DAG pattern. |
| 1784 | auto patArgName = node.getArgName(index: argIndex); |
| 1785 | if (leaf.isConstantAttr() || leaf.isEnumCase()) { |
| 1786 | // TODO: Refactor out into map to avoid recomputing these. |
| 1787 | if (!isa<NamedAttribute *>(Val: opArg)) |
| 1788 | PrintFatalError(ErrorLoc: loc, Msg: Twine("expected attribute " ) + Twine(argIndex)); |
| 1789 | if (!patArgName.empty()) |
| 1790 | os << "/*" << patArgName << "=*/" ; |
| 1791 | } else { |
| 1792 | os << "/*" << opArgName << "=*/" ; |
| 1793 | } |
| 1794 | os << handleOpArgument(leaf, patArgName); |
| 1795 | } |
| 1796 | } |
| 1797 | } |
| 1798 | |
| 1799 | void PatternEmitter::createAggregateLocalVarsForOpArgs( |
| 1800 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { |
| 1801 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
| 1802 | |
| 1803 | bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); |
| 1804 | auto scope = os.scope(); |
| 1805 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> " |
| 1806 | "tblgen_values; (void)tblgen_values;\n" ); |
| 1807 | if (useProperties) { |
| 1808 | os << formatv(Fmt: "{0}::Properties tblgen_props; (void)tblgen_props;\n" , |
| 1809 | Vals: resultOp.getQualCppClassName()); |
| 1810 | } else { |
| 1811 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::NamedAttribute, 4> " |
| 1812 | "tblgen_attrs; (void)tblgen_attrs;\n" ); |
| 1813 | } |
| 1814 | |
| 1815 | const char *setPropCmd = |
| 1816 | "tblgen_props.{0} = " |
| 1817 | "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n" ; |
| 1818 | const char *addAttrCmd = |
| 1819 | "if (auto tmpAttr = {1}) {\n" |
| 1820 | " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " |
| 1821 | "tmpAttr);\n}\n" ; |
| 1822 | const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd; |
| 1823 | |
| 1824 | int numVariadic = 0; |
| 1825 | bool hasOperandSegmentSizes = false; |
| 1826 | std::vector<std::string> sizes; |
| 1827 | for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { |
| 1828 | if (isa<NamedAttribute *>(Val: resultOp.getArg(index: argIndex))) { |
| 1829 | // The argument in the op definition. |
| 1830 | auto opArgName = resultOp.getArgName(index: argIndex); |
| 1831 | hasOperandSegmentSizes = |
| 1832 | hasOperandSegmentSizes || opArgName == "operandSegmentSizes" ; |
| 1833 | if (auto subTree = node.getArgAsNestedDag(index: argIndex)) { |
| 1834 | if (!subTree.isNativeCodeCall()) |
| 1835 | PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node " |
| 1836 | "for creating attribute" ); |
| 1837 | |
| 1838 | os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex)); |
| 1839 | } else { |
| 1840 | auto leaf = node.getArgAsLeaf(index: argIndex); |
| 1841 | // The argument in the result DAG pattern. |
| 1842 | auto patArgName = node.getArgName(index: argIndex); |
| 1843 | os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: handleOpArgument(leaf, patArgName)); |
| 1844 | } |
| 1845 | continue; |
| 1846 | } |
| 1847 | |
| 1848 | const auto *operand = |
| 1849 | cast<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex)); |
| 1850 | if (operand->isVariadic()) { |
| 1851 | ++numVariadic; |
| 1852 | std::string range; |
| 1853 | if (node.isNestedDagArg(index: argIndex)) { |
| 1854 | range = childNodeNames.lookup(Val: argIndex); |
| 1855 | } else { |
| 1856 | range = std::string(node.getArgName(index: argIndex)); |
| 1857 | } |
| 1858 | // Resolve the symbol for all range use so that we have a uniform way of |
| 1859 | // capturing the values. |
| 1860 | range = symbolInfoMap.getValueAndRangeUse(symbol: range); |
| 1861 | os << formatv(Fmt: "for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n" , |
| 1862 | Vals&: range); |
| 1863 | sizes.push_back(x: formatv(Fmt: "static_cast<int32_t>({0}.size())" , Vals&: range)); |
| 1864 | } else { |
| 1865 | sizes.emplace_back(args: "1" ); |
| 1866 | os << formatv(Fmt: "tblgen_values.push_back(" ); |
| 1867 | if (node.isNestedDagArg(index: argIndex)) { |
| 1868 | os << symbolInfoMap.getValueAndRangeUse( |
| 1869 | symbol: childNodeNames.lookup(Val: argIndex)); |
| 1870 | } else { |
| 1871 | DagLeaf leaf = node.getArgAsLeaf(index: argIndex); |
| 1872 | if (leaf.isConstantAttr()) |
| 1873 | // TODO: Use better location |
| 1874 | PrintFatalError( |
| 1875 | ErrorLoc: loc, |
| 1876 | Msg: "attribute found where value was expected, if attempting to use " |
| 1877 | "constant value, construct a constant op with given attribute " |
| 1878 | "instead" ); |
| 1879 | |
| 1880 | auto symbol = |
| 1881 | symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex)); |
| 1882 | if (leaf.isNativeCodeCall()) { |
| 1883 | os << std::string( |
| 1884 | tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol))); |
| 1885 | } else { |
| 1886 | os << symbol; |
| 1887 | } |
| 1888 | } |
| 1889 | os << ");\n" ; |
| 1890 | } |
| 1891 | } |
| 1892 | |
| 1893 | if (numVariadic > 1 && !hasOperandSegmentSizes) { |
| 1894 | // Only set size if it can't be computed. |
| 1895 | const auto *sameVariadicSize = |
| 1896 | resultOp.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize" ); |
| 1897 | if (!sameVariadicSize) { |
| 1898 | if (useProperties) { |
| 1899 | const char *setSizes = R"( |
| 1900 | tblgen_props.operandSegmentSizes = {{ {0} }; |
| 1901 | )" ; |
| 1902 | os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", " )).str()); |
| 1903 | } else { |
| 1904 | const char *setSizes = R"( |
| 1905 | tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), |
| 1906 | rewriter.getDenseI32ArrayAttr({{ {0} })); |
| 1907 | )" ; |
| 1908 | os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", " )).str()); |
| 1909 | } |
| 1910 | } |
| 1911 | } |
| 1912 | } |
| 1913 | |
| 1914 | StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, |
| 1915 | const RecordKeeper &records, |
| 1916 | RecordOperatorMap &mapper) |
| 1917 | : opMap(mapper), staticVerifierEmitter(os, records) {} |
| 1918 | |
| 1919 | void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { |
| 1920 | // PatternEmitter will use the static matcher if there's one generated. To |
| 1921 | // ensure that all the dependent static matchers are generated before emitting |
| 1922 | // the matching logic of the DagNode, we use topological order to achieve it. |
| 1923 | for (auto &dagInfo : topologicalOrder) { |
| 1924 | DagNode node = dagInfo.first; |
| 1925 | if (!useStaticMatcher(node)) |
| 1926 | continue; |
| 1927 | |
| 1928 | std::string funcName = |
| 1929 | formatv(Fmt: "static_dag_matcher_{0}" , Vals: staticMatcherCounter++); |
| 1930 | assert(!matcherNames.contains(node)); |
| 1931 | PatternEmitter(dagInfo.second, &opMap, os, *this) |
| 1932 | .emitStaticMatcher(tree: node, funcName); |
| 1933 | matcherNames[node] = funcName; |
| 1934 | } |
| 1935 | } |
| 1936 | |
| 1937 | void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { |
| 1938 | staticVerifierEmitter.emitPatternConstraints(constraints: constraints.getArrayRef()); |
| 1939 | } |
| 1940 | |
| 1941 | void StaticMatcherHelper::addPattern(const Record *record) { |
| 1942 | Pattern pat(record, &opMap); |
| 1943 | |
| 1944 | // While generating the function body of the DAG matcher, it may depends on |
| 1945 | // other DAG matchers. To ensure the dependent matchers are ready, we compute |
| 1946 | // the topological order for all the DAGs and emit the DAG matchers in this |
| 1947 | // order. |
| 1948 | llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) { |
| 1949 | ++refStats[node]; |
| 1950 | |
| 1951 | if (refStats[node] != 1) |
| 1952 | return; |
| 1953 | |
| 1954 | for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i) |
| 1955 | if (DagNode sibling = node.getArgAsNestedDag(index: i)) |
| 1956 | dfs(sibling); |
| 1957 | else { |
| 1958 | DagLeaf leaf = node.getArgAsLeaf(index: i); |
| 1959 | if (!leaf.isUnspecified()) |
| 1960 | constraints.insert(X: leaf); |
| 1961 | } |
| 1962 | |
| 1963 | topologicalOrder.push_back(Elt: std::make_pair(x&: node, y&: record)); |
| 1964 | }; |
| 1965 | |
| 1966 | dfs(pat.getSourcePattern()); |
| 1967 | } |
| 1968 | |
| 1969 | StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { |
| 1970 | if (leaf.isAttrMatcher()) { |
| 1971 | std::optional<StringRef> constraint = |
| 1972 | staticVerifierEmitter.getAttrConstraintFn(constraint: leaf.getAsConstraint()); |
| 1973 | assert(constraint && "attribute constraint was not uniqued" ); |
| 1974 | return *constraint; |
| 1975 | } |
| 1976 | assert(leaf.isOperandMatcher()); |
| 1977 | return staticVerifierEmitter.getTypeConstraintFn(constraint: leaf.getAsConstraint()); |
| 1978 | } |
| 1979 | |
| 1980 | static void emitRewriters(const RecordKeeper &records, raw_ostream &os) { |
| 1981 | emitSourceFileHeader(Desc: "Rewriters" , OS&: os, Record: records); |
| 1982 | |
| 1983 | auto patterns = records.getAllDerivedDefinitions(ClassName: "Pattern" ); |
| 1984 | |
| 1985 | // We put the map here because it can be shared among multiple patterns. |
| 1986 | RecordOperatorMap recordOpMap; |
| 1987 | |
| 1988 | // Exam all the patterns and generate static matcher for the duplicated |
| 1989 | // DagNode. |
| 1990 | StaticMatcherHelper staticMatcher(os, records, recordOpMap); |
| 1991 | for (const Record *p : patterns) |
| 1992 | staticMatcher.addPattern(record: p); |
| 1993 | staticMatcher.populateStaticConstraintFunctions(os); |
| 1994 | staticMatcher.populateStaticMatchers(os); |
| 1995 | |
| 1996 | std::vector<std::string> rewriterNames; |
| 1997 | rewriterNames.reserve(n: patterns.size()); |
| 1998 | |
| 1999 | std::string baseRewriterName = "GeneratedConvert" ; |
| 2000 | int rewriterIndex = 0; |
| 2001 | |
| 2002 | for (const Record *p : patterns) { |
| 2003 | std::string name; |
| 2004 | if (p->isAnonymous()) { |
| 2005 | // If no name is provided, ensure unique rewriter names simply by |
| 2006 | // appending unique suffix. |
| 2007 | name = baseRewriterName + llvm::utostr(X: rewriterIndex++); |
| 2008 | } else { |
| 2009 | name = std::string(p->getName()); |
| 2010 | } |
| 2011 | LLVM_DEBUG(llvm::dbgs() |
| 2012 | << "=== start generating pattern '" << name << "' ===\n" ); |
| 2013 | PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(rewriteName: name); |
| 2014 | LLVM_DEBUG(llvm::dbgs() |
| 2015 | << "=== done generating pattern '" << name << "' ===\n" ); |
| 2016 | rewriterNames.push_back(x: std::move(name)); |
| 2017 | } |
| 2018 | |
| 2019 | // Emit function to add the generated matchers to the pattern list. |
| 2020 | os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" |
| 2021 | "::mlir::RewritePatternSet &patterns) {\n" ; |
| 2022 | for (const auto &name : rewriterNames) { |
| 2023 | os << " patterns.add<" << name << ">(patterns.getContext());\n" ; |
| 2024 | } |
| 2025 | os << "}\n" ; |
| 2026 | } |
| 2027 | |
| 2028 | static mlir::GenRegistration |
| 2029 | genRewriters("gen-rewriters" , "Generate pattern rewriters" , |
| 2030 | [](const RecordKeeper &records, raw_ostream &os) { |
| 2031 | emitRewriters(records, os); |
| 2032 | return false; |
| 2033 | }); |
| 2034 | |