1 | //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Support/IndentedOstream.h" |
14 | #include "mlir/TableGen/Argument.h" |
15 | #include "mlir/TableGen/Attribute.h" |
16 | #include "mlir/TableGen/CodeGenHelpers.h" |
17 | #include "mlir/TableGen/Format.h" |
18 | #include "mlir/TableGen/GenInfo.h" |
19 | #include "mlir/TableGen/Operator.h" |
20 | #include "mlir/TableGen/Pattern.h" |
21 | #include "mlir/TableGen/Predicate.h" |
22 | #include "mlir/TableGen/Property.h" |
23 | #include "mlir/TableGen/Type.h" |
24 | #include "llvm/ADT/FunctionExtras.h" |
25 | #include "llvm/ADT/SetVector.h" |
26 | #include "llvm/ADT/StringExtras.h" |
27 | #include "llvm/ADT/StringSet.h" |
28 | #include "llvm/Support/CommandLine.h" |
29 | #include "llvm/Support/Debug.h" |
30 | #include "llvm/Support/FormatAdapters.h" |
31 | #include "llvm/Support/PrettyStackTrace.h" |
32 | #include "llvm/Support/Signals.h" |
33 | #include "llvm/TableGen/Error.h" |
34 | #include "llvm/TableGen/Main.h" |
35 | #include "llvm/TableGen/Record.h" |
36 | #include "llvm/TableGen/TableGenBackend.h" |
37 | |
38 | using namespace mlir; |
39 | using namespace mlir::tblgen; |
40 | |
41 | using llvm::formatv; |
42 | using llvm::Record; |
43 | using llvm::RecordKeeper; |
44 | |
45 | #define DEBUG_TYPE "mlir-tblgen-rewritergen" |
46 | |
47 | namespace llvm { |
48 | template <> |
49 | struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { |
50 | static void format(const mlir::tblgen::Pattern::IdentifierLine &v, |
51 | raw_ostream &os, StringRef style) { |
52 | os << v.first << ":" << v.second; |
53 | } |
54 | }; |
55 | } // namespace llvm |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | // PatternEmitter |
59 | //===----------------------------------------------------------------------===// |
60 | |
61 | namespace { |
62 | |
63 | class StaticMatcherHelper; |
64 | |
65 | class PatternEmitter { |
66 | public: |
67 | PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os, |
68 | StaticMatcherHelper &helper); |
69 | |
70 | // Emits the mlir::RewritePattern struct named `rewriteName`. |
71 | void emit(StringRef rewriteName); |
72 | |
73 | // Emits the static function of DAG matcher. |
74 | void emitStaticMatcher(DagNode tree, std::string funcName); |
75 | |
76 | private: |
77 | // Emits the code for matching ops. |
78 | void emitMatchLogic(DagNode tree, StringRef opName); |
79 | |
80 | // Emits the code for rewriting ops. |
81 | void emitRewriteLogic(); |
82 | |
83 | //===--------------------------------------------------------------------===// |
84 | // Match utilities |
85 | //===--------------------------------------------------------------------===// |
86 | |
87 | // Emits C++ statements for matching the DAG structure. |
88 | void emitMatch(DagNode tree, StringRef name, int depth); |
89 | |
90 | // Emit C++ function call to static DAG matcher. |
91 | void emitStaticMatchCall(DagNode tree, StringRef name); |
92 | |
93 | // Emit C++ function call to static type/attribute constraint function. |
94 | void emitStaticVerifierCall(StringRef funcName, StringRef opName, |
95 | StringRef arg, StringRef failureStr); |
96 | |
97 | // Emits C++ statements for matching using a native code call. |
98 | void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); |
99 | |
100 | // Emits C++ statements for matching the op constrained by the given DAG |
101 | // `tree` returning the op's variable name. |
102 | void emitOpMatch(DagNode tree, StringRef opName, int depth); |
103 | |
104 | // Emits C++ statements for matching the `argIndex`-th argument of the given |
105 | // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the |
106 | // bound name and the constraint of the operand respectively. |
107 | void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, |
108 | int operandIndex, DagLeaf operandMatcher, |
109 | StringRef argName, int argIndex, |
110 | std::optional<int> variadicSubIndex); |
111 | |
112 | // Emits C++ statements for matching the operands which can be matched in |
113 | // either order. |
114 | void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, |
115 | StringRef opName, int argIndex, int &operandIndex, |
116 | int depth); |
117 | |
118 | // Emits C++ statements for matching a variadic operand. |
119 | void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, |
120 | StringRef opName, int argIndex, |
121 | int &operandIndex, int depth); |
122 | |
123 | // Emits C++ statements for matching the `argIndex`-th argument of the given |
124 | // DAG `tree` as an attribute. |
125 | void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, |
126 | int depth); |
127 | |
128 | // Emits C++ for checking a match with a corresponding match failure |
129 | // diagnostic. |
130 | void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, |
131 | const llvm::formatv_object_base &failureFmt); |
132 | |
133 | // Emits C++ for checking a match with a corresponding match failure |
134 | // diagnostics. |
135 | void emitMatchCheck(StringRef opName, const std::string &matchStr, |
136 | const std::string &failureStr); |
137 | |
138 | //===--------------------------------------------------------------------===// |
139 | // Rewrite utilities |
140 | //===--------------------------------------------------------------------===// |
141 | |
142 | // The entry point for handling a result pattern rooted at `resultTree`. This |
143 | // method dispatches to concrete handlers according to `resultTree`'s kind and |
144 | // returns a symbol representing the whole value pack. Callers are expected to |
145 | // further resolve the symbol according to the specific use case. |
146 | // |
147 | // `depth` is the nesting level of `resultTree`; 0 means top-level result |
148 | // pattern. For top-level result pattern, `resultIndex` indicates which result |
149 | // of the matched root op this pattern is intended to replace, which can be |
150 | // used to deduce the result type of the op generated from this result |
151 | // pattern. |
152 | std::string handleResultPattern(DagNode resultTree, int resultIndex, |
153 | int depth); |
154 | |
155 | // Emits the C++ statement to replace the matched DAG with a value built via |
156 | // calling native C++ code. |
157 | std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); |
158 | |
159 | // Returns the symbol of the old value serving as the replacement. |
160 | StringRef handleReplaceWithValue(DagNode tree); |
161 | |
162 | // Trailing directives are used at the end of DAG node argument lists to |
163 | // specify additional behaviour for op matchers and creators, etc. |
164 | struct TrailingDirectives { |
165 | // DAG node containing the `location` directive. Null if there is none. |
166 | DagNode location; |
167 | |
168 | // DAG node containing the `returnType` directive. Null if there is none. |
169 | DagNode returnType; |
170 | |
171 | // Number of found trailing directives. |
172 | int numDirectives; |
173 | }; |
174 | |
175 | // Collect any trailing directives. |
176 | TrailingDirectives getTrailingDirectives(DagNode tree); |
177 | |
178 | // Returns the location value to use. |
179 | std::string getLocation(TrailingDirectives &tail); |
180 | |
181 | // Returns the location value to use. |
182 | std::string handleLocationDirective(DagNode tree); |
183 | |
184 | // Emit return type argument. |
185 | std::string handleReturnTypeArg(DagNode returnType, int i, int depth); |
186 | |
187 | // Emits the C++ statement to build a new op out of the given DAG `tree` and |
188 | // returns the variable name that this op is assigned to. If the root op in |
189 | // DAG `tree` has a specified name, the created op will be assigned to a |
190 | // variable of the given name. Otherwise, a unique name will be used as the |
191 | // result value name. |
192 | std::string handleOpCreation(DagNode tree, int resultIndex, int depth); |
193 | |
194 | using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; |
195 | |
196 | // Emits a local variable for each value and attribute to be used for creating |
197 | // an op. |
198 | void createSeparateLocalVarsForOpArgs(DagNode node, |
199 | ChildNodeIndexNameMap &childNodeNames); |
200 | |
201 | // Emits the concrete arguments used to call an op's builder. |
202 | void supplyValuesForOpArgs(DagNode node, |
203 | const ChildNodeIndexNameMap &childNodeNames, |
204 | int depth); |
205 | |
206 | // Emits the local variables for holding all values as a whole and all named |
207 | // attributes as a whole to be used for creating an op. |
208 | void createAggregateLocalVarsForOpArgs( |
209 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); |
210 | |
211 | // Returns the C++ expression to construct a constant attribute of the given |
212 | // `value` for the given attribute kind `attr`. |
213 | std::string handleConstantAttr(Attribute attr, const Twine &value); |
214 | |
215 | // Returns the C++ expression to build an argument from the given DAG `leaf`. |
216 | // `patArgName` is used to bound the argument to the source pattern. |
217 | std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); |
218 | |
219 | //===--------------------------------------------------------------------===// |
220 | // General utilities |
221 | //===--------------------------------------------------------------------===// |
222 | |
223 | // Collects all of the operations within the given dag tree. |
224 | void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); |
225 | |
226 | // Returns a unique symbol for a local variable of the given `op`. |
227 | std::string getUniqueSymbol(const Operator *op); |
228 | |
229 | //===--------------------------------------------------------------------===// |
230 | // Symbol utilities |
231 | //===--------------------------------------------------------------------===// |
232 | |
233 | // Returns how many static values the given DAG `node` correspond to. |
234 | int getNodeValueCount(DagNode node); |
235 | |
236 | private: |
237 | // Pattern instantiation location followed by the location of multiclass |
238 | // prototypes used. This is intended to be used as a whole to |
239 | // PrintFatalError() on errors. |
240 | ArrayRef<SMLoc> loc; |
241 | |
242 | // Op's TableGen Record to wrapper object. |
243 | RecordOperatorMap *opMap; |
244 | |
245 | // Handy wrapper for pattern being emitted. |
246 | Pattern pattern; |
247 | |
248 | // Map for all bound symbols' info. |
249 | SymbolInfoMap symbolInfoMap; |
250 | |
251 | StaticMatcherHelper &staticMatcherHelper; |
252 | |
253 | // The next unused ID for newly created values. |
254 | unsigned nextValueId = 0; |
255 | |
256 | raw_indented_ostream os; |
257 | |
258 | // Format contexts containing placeholder substitutions. |
259 | FmtContext fmtCtx; |
260 | }; |
261 | |
262 | // Tracks DagNode's reference multiple times across patterns. Enables generating |
263 | // static matcher functions for DagNode's referenced multiple times rather than |
264 | // inlining them. |
265 | class StaticMatcherHelper { |
266 | public: |
267 | StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper, |
268 | RecordOperatorMap &mapper); |
269 | |
270 | // Determine if we should inline the match logic or delegate to a static |
271 | // function. |
272 | bool useStaticMatcher(DagNode node) { |
273 | // either/variadic node must be associated to the parentOp, thus we can't |
274 | // emit a static matcher rooted at them. |
275 | if (node.isEither() || node.isVariadic()) |
276 | return false; |
277 | |
278 | return refStats[node] > kStaticMatcherThreshold; |
279 | } |
280 | |
281 | // Get the name of the static DAG matcher function corresponding to the node. |
282 | std::string getMatcherName(DagNode node) { |
283 | assert(useStaticMatcher(node)); |
284 | return matcherNames[node]; |
285 | } |
286 | |
287 | // Get the name of static type/attribute verification function. |
288 | StringRef getVerifierName(DagLeaf leaf); |
289 | |
290 | // Collect the `Record`s, i.e., the DRR, so that we can get the information of |
291 | // the duplicated DAGs. |
292 | void addPattern(Record *record); |
293 | |
294 | // Emit all static functions of DAG Matcher. |
295 | void populateStaticMatchers(raw_ostream &os); |
296 | |
297 | // Emit all static functions for Constraints. |
298 | void populateStaticConstraintFunctions(raw_ostream &os); |
299 | |
300 | private: |
301 | static constexpr unsigned kStaticMatcherThreshold = 1; |
302 | |
303 | // Consider two patterns as down below, |
304 | // DagNode_Root_A DagNode_Root_B |
305 | // \ \ |
306 | // DagNode_C DagNode_C |
307 | // \ \ |
308 | // DagNode_D DagNode_D |
309 | // |
310 | // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of |
311 | // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced |
312 | // multiple times so we'll have static matchers for both of them. When we're |
313 | // emitting the match logic for DagNode_C, we will check if DagNode_D has the |
314 | // static matcher generated. If so, then we'll generate a call to the |
315 | // function, inline otherwise. In this case, inlining is not what we want. As |
316 | // a result, generate the static matcher in topological order to ensure all |
317 | // the dependent static matchers are generated and we can avoid accidentally |
318 | // inlining. |
319 | // |
320 | // The topological order of all the DagNodes among all patterns. |
321 | SmallVector<std::pair<DagNode, Record *>> topologicalOrder; |
322 | |
323 | RecordOperatorMap &opMap; |
324 | |
325 | // Records of the static function name of each DagNode |
326 | DenseMap<DagNode, std::string> matcherNames; |
327 | |
328 | // After collecting all the DagNode in each pattern, `refStats` records the |
329 | // number of users for each DagNode. We will generate the static matcher for a |
330 | // DagNode while the number of users exceeds a certain threshold. |
331 | DenseMap<DagNode, unsigned> refStats; |
332 | |
333 | // Number of static matcher generated. This is used to generate a unique name |
334 | // for each DagNode. |
335 | int staticMatcherCounter = 0; |
336 | |
337 | // The DagLeaf which contains type or attr constraint. |
338 | SetVector<DagLeaf> constraints; |
339 | |
340 | // Static type/attribute verification function emitter. |
341 | StaticVerifierFunctionEmitter staticVerifierEmitter; |
342 | }; |
343 | |
344 | } // namespace |
345 | |
346 | PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, |
347 | raw_ostream &os, StaticMatcherHelper &helper) |
348 | : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), |
349 | symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { |
350 | fmtCtx.withBuilder(subst: "rewriter" ); |
351 | } |
352 | |
353 | std::string PatternEmitter::handleConstantAttr(Attribute attr, |
354 | const Twine &value) { |
355 | if (!attr.isConstBuildable()) |
356 | PrintFatalError(ErrorLoc: loc, Msg: "Attribute " + attr.getAttrDefName() + |
357 | " does not have the 'constBuilderCall' field" ); |
358 | |
359 | // TODO: Verify the constants here |
360 | return std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, vals: value)); |
361 | } |
362 | |
363 | void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { |
364 | os << formatv( |
365 | Fmt: "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " |
366 | "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " |
367 | "*, 4> &tblgen_ops" , |
368 | Vals&: funcName); |
369 | |
370 | // We pass the reference of the variables that need to be captured. Hence we |
371 | // need to collect all the symbols in the tree first. |
372 | pattern.collectBoundSymbols(tree, infoMap&: symbolInfoMap, /*isSrcPattern=*/true); |
373 | symbolInfoMap.assignUniqueAlternativeNames(); |
374 | for (const auto &info : symbolInfoMap) |
375 | os << formatv(Fmt: ", {0}" , Vals: info.second.getArgDecl(name: info.first)); |
376 | |
377 | os << ") {\n" ; |
378 | os.indent(); |
379 | os << "(void)tblgen_ops;\n" ; |
380 | |
381 | // Note that a static matcher is considered at least one step from the match |
382 | // entry. |
383 | emitMatch(tree, name: "op0" , /*depth=*/1); |
384 | |
385 | os << "return ::mlir::success();\n" ; |
386 | os.unindent(); |
387 | os << "}\n\n" ; |
388 | } |
389 | |
390 | // Helper function to match patterns. |
391 | void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { |
392 | if (tree.isNativeCodeCall()) { |
393 | emitNativeCodeMatch(tree, name, depth); |
394 | return; |
395 | } |
396 | |
397 | if (tree.isOperation()) { |
398 | emitOpMatch(tree, opName: name, depth); |
399 | return; |
400 | } |
401 | |
402 | PrintFatalError(ErrorLoc: loc, Msg: "encountered non-op, non-NativeCodeCall match." ); |
403 | } |
404 | |
405 | void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { |
406 | std::string funcName = staticMatcherHelper.getMatcherName(node: tree); |
407 | os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, tblgen_ops" , Vals&: funcName, |
408 | Vals&: opName); |
409 | |
410 | // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in |
411 | // one pass. |
412 | |
413 | // In general, bound symbol should have the unique name in the pattern but |
414 | // for the operand, binding same symbol to multiple operands imply a |
415 | // constraint at the same time. In this case, we will rename those operands |
416 | // with different names. As a result, we need to collect all the symbolInfos |
417 | // from the DagNode then get the updated name of the local variables from the |
418 | // global symbolInfoMap. |
419 | |
420 | // Collect all the bound symbols in the Dag |
421 | SymbolInfoMap localSymbolMap(loc); |
422 | pattern.collectBoundSymbols(tree, infoMap&: localSymbolMap, /*isSrcPattern=*/true); |
423 | |
424 | for (const auto &info : localSymbolMap) { |
425 | auto name = info.first; |
426 | auto symboInfo = info.second; |
427 | auto ret = symbolInfoMap.findBoundSymbol(key: name, symbolInfo: symboInfo); |
428 | os << formatv(Fmt: ", {0}" , Vals: ret->second.getVarName(name)); |
429 | } |
430 | |
431 | os << "))) {\n" ; |
432 | os.scope().os << "return ::mlir::failure();\n" ; |
433 | os << "}\n" ; |
434 | } |
435 | |
436 | void PatternEmitter::emitStaticVerifierCall(StringRef funcName, |
437 | StringRef opName, StringRef arg, |
438 | StringRef failureStr) { |
439 | os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n" , |
440 | Vals&: funcName, Vals&: opName, Vals&: arg, Vals&: failureStr); |
441 | os.scope().os << "return ::mlir::failure();\n" ; |
442 | os << "}\n" ; |
443 | } |
444 | |
445 | // Helper function to match patterns. |
446 | void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, |
447 | int depth) { |
448 | LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: " ); |
449 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
450 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
451 | |
452 | // The order of generating static matcher follows the topological order so |
453 | // that for every dependent DagNode already have their static matcher |
454 | // generated if needed. The reason we check if `getMatcherName(tree).empty()` |
455 | // is when we are generating the static matcher for a DagNode itself. In this |
456 | // case, we need to emit the function body rather than a function call. |
457 | if (staticMatcherHelper.useStaticMatcher(node: tree) && |
458 | !staticMatcherHelper.getMatcherName(node: tree).empty()) { |
459 | emitStaticMatchCall(tree, opName); |
460 | |
461 | // NativeCodeCall will never be at depth 0 so that we don't need to catch |
462 | // the root operation as emitOpMatch(); |
463 | |
464 | return; |
465 | } |
466 | |
467 | // TODO(suderman): iterate through arguments, determine their types, output |
468 | // names. |
469 | SmallVector<std::string, 8> capture; |
470 | |
471 | raw_indented_ostream::DelimitedScope scope(os); |
472 | |
473 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
474 | std::string argName = formatv(Fmt: "arg{0}_{1}" , Vals&: depth, Vals&: i); |
475 | if (DagNode argTree = tree.getArgAsNestedDag(index: i)) { |
476 | if (argTree.isEither()) |
477 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `either` operands" ); |
478 | if (argTree.isVariadic()) |
479 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `variadic` operands" ); |
480 | |
481 | os << "::mlir::Value " << argName << ";\n" ; |
482 | } else { |
483 | auto leaf = tree.getArgAsLeaf(index: i); |
484 | if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { |
485 | os << "::mlir::Attribute " << argName << ";\n" ; |
486 | } else { |
487 | os << "::mlir::Value " << argName << ";\n" ; |
488 | } |
489 | } |
490 | |
491 | capture.push_back(Elt: std::move(argName)); |
492 | } |
493 | |
494 | auto tail = getTrailingDirectives(tree); |
495 | if (tail.returnType) |
496 | PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier" ); |
497 | auto locToUse = getLocation(tail); |
498 | |
499 | auto fmt = tree.getNativeCodeTemplate(); |
500 | if (fmt.count(Str: "$_self" ) != 1) |
501 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall must have $_self as argument for " |
502 | "passing the defining Operation" ); |
503 | |
504 | auto nativeCodeCall = std::string( |
505 | tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc" , subst: locToUse).withSelf(subst: opName.str()), |
506 | params: static_cast<ArrayRef<std::string>>(capture))); |
507 | |
508 | emitMatchCheck(opName, matchStr: formatv(Fmt: "!::mlir::failed({0})" , Vals&: nativeCodeCall), |
509 | failureStr: formatv(Fmt: "\"{0} return ::mlir::failure\"" , Vals&: nativeCodeCall)); |
510 | |
511 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
512 | auto name = tree.getArgName(index: i); |
513 | if (!name.empty() && name != "_" ) { |
514 | os << formatv(Fmt: "{0} = {1};\n" , Vals&: name, Vals&: capture[i]); |
515 | } |
516 | } |
517 | |
518 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
519 | std::string argName = capture[i]; |
520 | |
521 | // Handle nested DAG construct first |
522 | if (tree.getArgAsNestedDag(index: i)) { |
523 | PrintFatalError( |
524 | ErrorLoc: loc, Msg: formatv(Fmt: "Matching nested tree in NativeCodecall not support for " |
525 | "{0} as arg {1}" , |
526 | Vals&: argName, Vals&: i)); |
527 | } |
528 | |
529 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
530 | |
531 | // The parameter for native function doesn't bind any constraints. |
532 | if (leaf.isUnspecified()) |
533 | continue; |
534 | |
535 | auto constraint = leaf.getAsConstraint(); |
536 | |
537 | std::string self; |
538 | if (leaf.isAttrMatcher() || leaf.isConstantAttr()) |
539 | self = argName; |
540 | else |
541 | self = formatv(Fmt: "{0}.getType()" , Vals&: argName); |
542 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf); |
543 | emitStaticVerifierCall( |
544 | funcName: verifier, opName, arg: self, |
545 | failureStr: formatv(Fmt: "\"operand {0} of native code call '{1}' failed to satisfy " |
546 | "constraint: " |
547 | "'{2}'\"" , |
548 | Vals&: i, Vals: tree.getNativeCodeTemplate(), |
549 | Vals: escapeString(value: constraint.getSummary())) |
550 | .str()); |
551 | } |
552 | |
553 | LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n" ); |
554 | } |
555 | |
556 | // Helper function to match patterns. |
557 | void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { |
558 | Operator &op = tree.getDialectOp(mapper: opMap); |
559 | LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" |
560 | << op.getOperationName() << "' at depth " << depth |
561 | << '\n'); |
562 | |
563 | auto getCastedName = [depth]() -> std::string { |
564 | return formatv(Fmt: "castedOp{0}" , Vals: depth); |
565 | }; |
566 | |
567 | // The order of generating static matcher follows the topological order so |
568 | // that for every dependent DagNode already have their static matcher |
569 | // generated if needed. The reason we check if `getMatcherName(tree).empty()` |
570 | // is when we are generating the static matcher for a DagNode itself. In this |
571 | // case, we need to emit the function body rather than a function call. |
572 | if (staticMatcherHelper.useStaticMatcher(node: tree) && |
573 | !staticMatcherHelper.getMatcherName(node: tree).empty()) { |
574 | emitStaticMatchCall(tree, opName); |
575 | // In the codegen of rewriter, we suppose that castedOp0 will capture the |
576 | // root operation. Manually add it if the root DagNode is a static matcher. |
577 | if (depth == 0) |
578 | os << formatv(Fmt: "auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " |
579 | "(void){2};\n" , |
580 | Vals&: opName, Vals: op.getQualCppClassName(), Vals: getCastedName()); |
581 | return; |
582 | } |
583 | |
584 | std::string castedName = getCastedName(); |
585 | os << formatv(Fmt: "auto {0} = ::llvm::dyn_cast<{2}>({1}); " |
586 | "(void){0};\n" , |
587 | Vals&: castedName, Vals&: opName, Vals: op.getQualCppClassName()); |
588 | |
589 | // Skip the operand matching at depth 0 as the pattern rewriter already does. |
590 | if (depth != 0) |
591 | emitMatchCheck(opName, /*matchStr=*/castedName, |
592 | failureStr: formatv(Fmt: "\"{0} is not {1} type\"" , Vals&: castedName, |
593 | Vals: op.getQualCppClassName())); |
594 | |
595 | // If the operand's name is set, set to that variable. |
596 | auto name = tree.getSymbol(); |
597 | if (!name.empty()) |
598 | os << formatv(Fmt: "{0} = {1};\n" , Vals&: name, Vals&: castedName); |
599 | |
600 | for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; |
601 | ++i, ++opArgIdx) { |
602 | auto opArg = op.getArg(index: opArgIdx); |
603 | std::string argName = formatv(Fmt: "op{0}" , Vals: depth + 1); |
604 | |
605 | // Handle nested DAG construct first |
606 | if (DagNode argTree = tree.getArgAsNestedDag(index: i)) { |
607 | if (argTree.isEither()) { |
608 | emitEitherOperandMatch(tree, eitherArgTree: argTree, opName: castedName, argIndex: opArgIdx, operandIndex&: nextOperand, |
609 | depth); |
610 | ++opArgIdx; |
611 | continue; |
612 | } |
613 | if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) { |
614 | if (argTree.isVariadic()) { |
615 | if (!operand->isVariadic()) { |
616 | auto error = formatv(Fmt: "variadic DAG construct can't match op {0}'s " |
617 | "non-variadic operand #{1}" , |
618 | Vals: op.getOperationName(), Vals&: opArgIdx); |
619 | PrintFatalError(ErrorLoc: loc, Msg: error); |
620 | } |
621 | emitVariadicOperandMatch(tree, variadicArgTree: argTree, opName: castedName, argIndex: opArgIdx, |
622 | operandIndex&: nextOperand, depth); |
623 | ++nextOperand; |
624 | continue; |
625 | } |
626 | if (operand->isVariableLength()) { |
627 | auto error = formatv(Fmt: "use nested DAG construct to match op {0}'s " |
628 | "variadic operand #{1} unsupported now" , |
629 | Vals: op.getOperationName(), Vals&: opArgIdx); |
630 | PrintFatalError(ErrorLoc: loc, Msg: error); |
631 | } |
632 | } |
633 | |
634 | os << "{\n" ; |
635 | |
636 | // Attributes don't count for getODSOperands. |
637 | // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. |
638 | os.indent() << formatv( |
639 | Fmt: "auto *{0} = " |
640 | "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n" , |
641 | Vals&: argName, Vals&: castedName, Vals&: nextOperand); |
642 | // Null check of operand's definingOp |
643 | emitMatchCheck( |
644 | opName: castedName, /*matchStr=*/argName, |
645 | failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"" , |
646 | Vals: nextOperand++, Vals&: castedName)); |
647 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
648 | os << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
649 | os.unindent() << "}\n" ; |
650 | continue; |
651 | } |
652 | |
653 | // Next handle DAG leaf: operand or attribute |
654 | if (opArg.is<NamedTypeConstraint *>()) { |
655 | auto operandName = |
656 | formatv(Fmt: "{0}.getODSOperands({1})" , Vals&: castedName, Vals&: nextOperand); |
657 | emitOperandMatch(tree, opName: castedName, operandName: operandName.str(), operandIndex: opArgIdx, |
658 | /*operandMatcher=*/tree.getArgAsLeaf(index: i), |
659 | /*argName=*/tree.getArgName(index: i), argIndex: opArgIdx, |
660 | /*variadicSubIndex=*/std::nullopt); |
661 | ++nextOperand; |
662 | } else if (opArg.is<NamedAttribute *>()) { |
663 | emitAttributeMatch(tree, opName, argIndex: opArgIdx, depth); |
664 | } else { |
665 | PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when matching op" ); |
666 | } |
667 | } |
668 | LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" |
669 | << op.getOperationName() << "' at depth " << depth |
670 | << '\n'); |
671 | } |
672 | |
673 | void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, |
674 | StringRef operandName, int operandIndex, |
675 | DagLeaf operandMatcher, StringRef argName, |
676 | int argIndex, |
677 | std::optional<int> variadicSubIndex) { |
678 | Operator &op = tree.getDialectOp(mapper: opMap); |
679 | auto *operand = op.getArg(index: operandIndex).get<NamedTypeConstraint *>(); |
680 | |
681 | // If a constraint is specified, we need to generate C++ statements to |
682 | // check the constraint. |
683 | if (!operandMatcher.isUnspecified()) { |
684 | if (!operandMatcher.isOperandMatcher()) |
685 | PrintFatalError( |
686 | ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an operand" , |
687 | Vals: op.getOperationName(), Vals: argIndex + 1)); |
688 | |
689 | // Only need to verify if the matcher's type is different from the one |
690 | // of op definition. |
691 | Constraint constraint = operandMatcher.getAsConstraint(); |
692 | if (operand->constraint != constraint) { |
693 | if (operand->isVariableLength()) { |
694 | auto error = formatv( |
695 | Fmt: "further constrain op {0}'s variadic operand #{1} unsupported now" , |
696 | Vals: op.getOperationName(), Vals&: argIndex); |
697 | PrintFatalError(ErrorLoc: loc, Msg: error); |
698 | } |
699 | auto self = formatv(Fmt: "(*{0}.begin()).getType()" , Vals&: operandName); |
700 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf: operandMatcher); |
701 | emitStaticVerifierCall( |
702 | funcName: verifier, opName, arg: self.str(), |
703 | failureStr: formatv( |
704 | Fmt: "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"" , |
705 | Vals: operand - op.operand_begin(), Vals: op.getOperationName(), |
706 | Vals: escapeString(value: constraint.getSummary())) |
707 | .str()); |
708 | } |
709 | } |
710 | |
711 | // Capture the value |
712 | // `$_` is a special symbol to ignore op argument matching. |
713 | if (!argName.empty() && argName != "_" ) { |
714 | auto res = symbolInfoMap.findBoundSymbol(key: argName, node: tree, op, argIndex: operandIndex, |
715 | variadicSubIndex); |
716 | if (res == symbolInfoMap.end()) |
717 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}" , Vals&: argName)); |
718 | |
719 | os << formatv(Fmt: "{0} = {1};\n" , Vals: res->second.getVarName(name: argName), Vals&: operandName); |
720 | } |
721 | } |
722 | |
723 | void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, |
724 | StringRef opName, int argIndex, |
725 | int &operandIndex, int depth) { |
726 | constexpr int numEitherArgs = 2; |
727 | if (eitherArgTree.getNumArgs() != numEitherArgs) |
728 | PrintFatalError(ErrorLoc: loc, Msg: "`either` only supports grouping two operands" ); |
729 | |
730 | Operator &op = tree.getDialectOp(mapper: opMap); |
731 | |
732 | std::string codeBuffer; |
733 | llvm::raw_string_ostream tblgenOps(codeBuffer); |
734 | |
735 | std::string lambda = formatv(Fmt: "eitherLambda{0}" , Vals&: depth); |
736 | os << formatv( |
737 | Fmt: "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n" , |
738 | Vals&: lambda); |
739 | |
740 | os.indent(); |
741 | |
742 | for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { |
743 | if (DagNode argTree = eitherArgTree.getArgAsNestedDag(index: i)) { |
744 | if (argTree.isEither()) |
745 | PrintFatalError(ErrorLoc: loc, Msg: "either cannot be nested" ); |
746 | |
747 | std::string argName = formatv(Fmt: "local_op_{0}" , Vals&: i).str(); |
748 | |
749 | os << formatv(Fmt: "auto {0} = (*v{1}.begin()).getDefiningOp();\n" , Vals&: argName, |
750 | Vals&: i); |
751 | |
752 | // Indent emitMatchCheck and emitMatch because they declare local |
753 | // variables. |
754 | os << "{\n" ; |
755 | os.indent(); |
756 | |
757 | emitMatchCheck( |
758 | opName, /*matchStr=*/argName, |
759 | failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"" , |
760 | Vals: operandIndex++, Vals&: opName)); |
761 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
762 | |
763 | os.unindent() << "}\n" ; |
764 | |
765 | // `tblgen_ops` is used to collect the matched operations. In either, we |
766 | // need to queue the operation only if the matching success. Thus we emit |
767 | // the code at the end. |
768 | tblgenOps << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
769 | } else if (op.getArg(index: argIndex).is<NamedTypeConstraint *>()) { |
770 | emitOperandMatch(tree, opName, /*operandName=*/formatv(Fmt: "v{0}" , Vals&: i).str(), |
771 | operandIndex, |
772 | /*operandMatcher=*/eitherArgTree.getArgAsLeaf(index: i), |
773 | /*argName=*/eitherArgTree.getArgName(index: i), argIndex, |
774 | /*variadicSubIndex=*/std::nullopt); |
775 | ++operandIndex; |
776 | } else { |
777 | PrintFatalError(ErrorLoc: loc, Msg: "either can only be applied on operand" ); |
778 | } |
779 | } |
780 | |
781 | os << tblgenOps.str(); |
782 | os << "return ::mlir::success();\n" ; |
783 | os.unindent() << "};\n" ; |
784 | |
785 | os << "{\n" ; |
786 | os.indent(); |
787 | |
788 | os << formatv(Fmt: "auto eitherOperand0 = {0}.getODSOperands({1});\n" , Vals&: opName, |
789 | Vals: operandIndex - 2); |
790 | os << formatv(Fmt: "auto eitherOperand1 = {0}.getODSOperands({1});\n" , Vals&: opName, |
791 | Vals: operandIndex - 1); |
792 | |
793 | os << formatv(Fmt: "if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " |
794 | "::mlir::failed({0}(eitherOperand1, " |
795 | "eitherOperand0)))\n" , |
796 | Vals&: lambda); |
797 | os.indent() << "return ::mlir::failure();\n" ; |
798 | |
799 | os.unindent().unindent() << "}\n" ; |
800 | } |
801 | |
802 | void PatternEmitter::emitVariadicOperandMatch(DagNode tree, |
803 | DagNode variadicArgTree, |
804 | StringRef opName, int argIndex, |
805 | int &operandIndex, int depth) { |
806 | Operator &op = tree.getDialectOp(mapper: opMap); |
807 | |
808 | os << "{\n" ; |
809 | os.indent(); |
810 | |
811 | os << formatv(Fmt: "auto variadic_operand_range = {0}.getODSOperands({1});\n" , |
812 | Vals&: opName, Vals&: operandIndex); |
813 | os << formatv(Fmt: "if (variadic_operand_range.size() != {0}) " |
814 | "return ::mlir::failure();\n" , |
815 | Vals: variadicArgTree.getNumArgs()); |
816 | |
817 | StringRef variadicTreeName = variadicArgTree.getSymbol(); |
818 | if (!variadicTreeName.empty()) { |
819 | auto res = |
820 | symbolInfoMap.findBoundSymbol(key: variadicTreeName, node: tree, op, argIndex: operandIndex, |
821 | /*variadicSubIndex=*/std::nullopt); |
822 | if (res == symbolInfoMap.end()) |
823 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}" , Vals&: variadicTreeName)); |
824 | |
825 | os << formatv(Fmt: "{0} = variadic_operand_range;\n" , |
826 | Vals: res->second.getVarName(name: variadicTreeName)); |
827 | } |
828 | |
829 | for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) { |
830 | if (DagNode argTree = variadicArgTree.getArgAsNestedDag(index: i)) { |
831 | if (!argTree.isOperation()) |
832 | PrintFatalError(ErrorLoc: loc, Msg: "variadic only accepts operation sub-dags" ); |
833 | |
834 | os << "{\n" ; |
835 | os.indent(); |
836 | |
837 | std::string argName = formatv(Fmt: "local_op_{0}" , Vals&: i).str(); |
838 | os << formatv(Fmt: "auto *{0} = " |
839 | "variadic_operand_range[{1}].getDefiningOp();\n" , |
840 | Vals&: argName, Vals&: i); |
841 | emitMatchCheck( |
842 | opName, /*matchStr=*/argName, |
843 | failureStr: formatv(Fmt: "\"There's no operation that defines variadic operand " |
844 | "{0} (variadic sub-opearnd #{1}) of {2}\"" , |
845 | Vals&: operandIndex, Vals&: i, Vals&: opName)); |
846 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
847 | os << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
848 | |
849 | os.unindent() << "}\n" ; |
850 | } else if (op.getArg(index: argIndex).is<NamedTypeConstraint *>()) { |
851 | auto operandName = formatv(Fmt: "variadic_operand_range.slice({0}, 1)" , Vals&: i); |
852 | emitOperandMatch(tree, opName, operandName: operandName.str(), operandIndex, |
853 | /*operandMatcher=*/variadicArgTree.getArgAsLeaf(index: i), |
854 | /*argName=*/variadicArgTree.getArgName(index: i), argIndex, variadicSubIndex: i); |
855 | } else { |
856 | PrintFatalError(ErrorLoc: loc, Msg: "variadic can only be applied on operand" ); |
857 | } |
858 | } |
859 | |
860 | os.unindent() << "}\n" ; |
861 | } |
862 | |
863 | void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, |
864 | int argIndex, int depth) { |
865 | Operator &op = tree.getDialectOp(mapper: opMap); |
866 | auto *namedAttr = op.getArg(index: argIndex).get<NamedAttribute *>(); |
867 | const auto &attr = namedAttr->attr; |
868 | |
869 | os << "{\n" ; |
870 | os.indent() << formatv(Fmt: "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" |
871 | "(void)tblgen_attr;\n" , |
872 | Vals&: opName, Vals: attr.getStorageType(), Vals&: namedAttr->name); |
873 | |
874 | // TODO: This should use getter method to avoid duplication. |
875 | if (attr.hasDefaultValue()) { |
876 | os << "if (!tblgen_attr) tblgen_attr = " |
877 | << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, |
878 | vals: attr.getDefaultValue())) |
879 | << ";\n" ; |
880 | } else if (attr.isOptional()) { |
881 | // For a missing attribute that is optional according to definition, we |
882 | // should just capture a mlir::Attribute() to signal the missing state. |
883 | // That is precisely what getDiscardableAttr() returns on missing |
884 | // attributes. |
885 | } else { |
886 | emitMatchCheck(opName, matchFmt: tgfmt(fmt: "tblgen_attr" , ctx: &fmtCtx), |
887 | failureFmt: formatv(Fmt: "\"expected op '{0}' to have attribute '{1}' " |
888 | "of type '{2}'\"" , |
889 | Vals: op.getOperationName(), Vals&: namedAttr->name, |
890 | Vals: attr.getStorageType())); |
891 | } |
892 | |
893 | auto matcher = tree.getArgAsLeaf(index: argIndex); |
894 | if (!matcher.isUnspecified()) { |
895 | if (!matcher.isAttrMatcher()) { |
896 | PrintFatalError( |
897 | ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an attribute" , |
898 | Vals: op.getOperationName(), Vals: argIndex + 1)); |
899 | } |
900 | |
901 | // If a constraint is specified, we need to generate function call to its |
902 | // static verifier. |
903 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf: matcher); |
904 | if (attr.isOptional()) { |
905 | // Avoid dereferencing null attribute. This is using a simple heuristic to |
906 | // avoid common cases of attempting to dereference null attribute. This |
907 | // will return where there is no check if attribute is null unless the |
908 | // attribute's value is not used. |
909 | // FIXME: This could be improved as some null dereferences could slip |
910 | // through. |
911 | if (!StringRef(matcher.getConditionTemplate()).contains(Other: "!$_self" ) && |
912 | StringRef(matcher.getConditionTemplate()).contains(Other: "$_self" )) { |
913 | os << "if (!tblgen_attr) return ::mlir::failure();\n" ; |
914 | } |
915 | } |
916 | emitStaticVerifierCall( |
917 | funcName: verifier, opName, arg: "tblgen_attr" , |
918 | failureStr: formatv(Fmt: "\"op '{0}' attribute '{1}' failed to satisfy constraint: " |
919 | "'{2}'\"" , |
920 | Vals: op.getOperationName(), Vals&: namedAttr->name, |
921 | Vals: escapeString(value: matcher.getAsConstraint().getSummary())) |
922 | .str()); |
923 | } |
924 | |
925 | // Capture the value |
926 | auto name = tree.getArgName(index: argIndex); |
927 | // `$_` is a special symbol to ignore op argument matching. |
928 | if (!name.empty() && name != "_" ) { |
929 | os << formatv(Fmt: "{0} = tblgen_attr;\n" , Vals&: name); |
930 | } |
931 | |
932 | os.unindent() << "}\n" ; |
933 | } |
934 | |
935 | void PatternEmitter::emitMatchCheck( |
936 | StringRef opName, const FmtObjectBase &matchFmt, |
937 | const llvm::formatv_object_base &failureFmt) { |
938 | emitMatchCheck(opName, matchStr: matchFmt.str(), failureStr: failureFmt.str()); |
939 | } |
940 | |
941 | void PatternEmitter::emitMatchCheck(StringRef opName, |
942 | const std::string &matchStr, |
943 | const std::string &failureStr) { |
944 | |
945 | os << "if (!(" << matchStr << "))" ; |
946 | os.scope(open: "{\n" , close: "\n}\n" ).os << "return rewriter.notifyMatchFailure(" << opName |
947 | << ", [&](::mlir::Diagnostic &diag) {\n diag << " |
948 | << failureStr << ";\n});" ; |
949 | } |
950 | |
951 | void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { |
952 | LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n" ); |
953 | int depth = 0; |
954 | emitMatch(tree, name: opName, depth); |
955 | |
956 | for (auto &appliedConstraint : pattern.getConstraints()) { |
957 | auto &constraint = appliedConstraint.constraint; |
958 | auto &entities = appliedConstraint.entities; |
959 | |
960 | auto condition = constraint.getConditionTemplate(); |
961 | if (isa<TypeConstraint>(Val: constraint)) { |
962 | if (entities.size() != 1) |
963 | PrintFatalError(ErrorLoc: loc, Msg: "type constraint requires exactly one argument" ); |
964 | |
965 | auto self = formatv(Fmt: "({0}.getType())" , |
966 | Vals: symbolInfoMap.getValueAndRangeUse(symbol: entities.front())); |
967 | emitMatchCheck( |
968 | opName, matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self.str())), |
969 | failureFmt: formatv(Fmt: "\"value entity '{0}' failed to satisfy constraint: '{1}'\"" , |
970 | Vals&: entities.front(), Vals: escapeString(value: constraint.getSummary()))); |
971 | |
972 | } else if (isa<AttrConstraint>(Val: constraint)) { |
973 | PrintFatalError( |
974 | ErrorLoc: loc, Msg: "cannot use AttrConstraint in Pattern multi-entity constraints" ); |
975 | } else { |
976 | // TODO: replace formatv arguments with the exact specified |
977 | // args. |
978 | if (entities.size() > 4) { |
979 | PrintFatalError(ErrorLoc: loc, Msg: "only support up to 4-entity constraints now" ); |
980 | } |
981 | SmallVector<std::string, 4> names; |
982 | int i = 0; |
983 | for (int e = entities.size(); i < e; ++i) |
984 | names.push_back(Elt: symbolInfoMap.getValueAndRangeUse(symbol: entities[i])); |
985 | std::string self = appliedConstraint.self; |
986 | if (!self.empty()) |
987 | self = symbolInfoMap.getValueAndRangeUse(symbol: self); |
988 | for (; i < 4; ++i) |
989 | names.push_back(Elt: "<unused>" ); |
990 | emitMatchCheck(opName, |
991 | matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self), vals&: names[0], |
992 | vals&: names[1], vals&: names[2], vals&: names[3]), |
993 | failureFmt: formatv(Fmt: "\"entities '{0}' failed to satisfy constraint: " |
994 | "'{1}'\"" , |
995 | Vals: llvm::join(R&: entities, Separator: ", " ), |
996 | Vals: escapeString(value: constraint.getSummary()))); |
997 | } |
998 | } |
999 | |
1000 | // Some of the operands could be bound to the same symbol name, we need |
1001 | // to enforce equality constraint on those. |
1002 | // TODO: we should be able to emit equality checks early |
1003 | // and short circuit unnecessary work if vars are not equal. |
1004 | for (auto symbolInfoIt = symbolInfoMap.begin(); |
1005 | symbolInfoIt != symbolInfoMap.end();) { |
1006 | auto range = symbolInfoMap.getRangeOfEqualElements(key: symbolInfoIt->first); |
1007 | auto startRange = range.first; |
1008 | auto endRange = range.second; |
1009 | |
1010 | auto firstOperand = symbolInfoIt->second.getVarName(name: symbolInfoIt->first); |
1011 | for (++startRange; startRange != endRange; ++startRange) { |
1012 | auto secondOperand = startRange->second.getVarName(name: symbolInfoIt->first); |
1013 | emitMatchCheck( |
1014 | opName, |
1015 | matchStr: formatv(Fmt: "*{0}.begin() == *{1}.begin()" , Vals&: firstOperand, Vals&: secondOperand), |
1016 | failureStr: formatv(Fmt: "\"Operands '{0}' and '{1}' must be equal\"" , Vals&: firstOperand, |
1017 | Vals&: secondOperand)); |
1018 | } |
1019 | |
1020 | symbolInfoIt = endRange; |
1021 | } |
1022 | |
1023 | LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n" ); |
1024 | } |
1025 | |
1026 | void PatternEmitter::collectOps(DagNode tree, |
1027 | llvm::SmallPtrSetImpl<const Operator *> &ops) { |
1028 | // Check if this tree is an operation. |
1029 | if (tree.isOperation()) { |
1030 | const Operator &op = tree.getDialectOp(mapper: opMap); |
1031 | LLVM_DEBUG(llvm::dbgs() |
1032 | << "found operation " << op.getOperationName() << '\n'); |
1033 | ops.insert(Ptr: &op); |
1034 | } |
1035 | |
1036 | // Recurse the arguments of the tree. |
1037 | for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) |
1038 | if (auto child = tree.getArgAsNestedDag(index: i)) |
1039 | collectOps(tree: child, ops); |
1040 | } |
1041 | |
1042 | void PatternEmitter::emit(StringRef rewriteName) { |
1043 | // Get the DAG tree for the source pattern. |
1044 | DagNode sourceTree = pattern.getSourcePattern(); |
1045 | |
1046 | const Operator &rootOp = pattern.getSourceRootOp(); |
1047 | auto rootName = rootOp.getOperationName(); |
1048 | |
1049 | // Collect the set of result operations. |
1050 | llvm::SmallPtrSet<const Operator *, 4> resultOps; |
1051 | LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n" ); |
1052 | for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { |
1053 | collectOps(tree: pattern.getResultPattern(index: i), ops&: resultOps); |
1054 | } |
1055 | LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n" ); |
1056 | |
1057 | // Emit RewritePattern for Pattern. |
1058 | auto locs = pattern.getLocation(); |
1059 | os << formatv(Fmt: "/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n" , |
1060 | Vals: make_range(x: locs.rbegin(), y: locs.rend())); |
1061 | os << formatv(Fmt: R"(struct {0} : public ::mlir::RewritePattern { |
1062 | {0}(::mlir::MLIRContext *context) |
1063 | : ::mlir::RewritePattern("{1}", {2}, context, {{)" , |
1064 | Vals&: rewriteName, Vals&: rootName, Vals: pattern.getBenefit()); |
1065 | // Sort result operators by name. |
1066 | llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), |
1067 | resultOps.end()); |
1068 | llvm::sort(C&: sortedResultOps, Comp: [&](const Operator *lhs, const Operator *rhs) { |
1069 | return lhs->getOperationName() < rhs->getOperationName(); |
1070 | }); |
1071 | llvm::interleaveComma(c: sortedResultOps, os, each_fn: [&](const Operator *op) { |
1072 | os << '"' << op->getOperationName() << '"'; |
1073 | }); |
1074 | os << "}) {}\n" ; |
1075 | |
1076 | // Emit matchAndRewrite() function. |
1077 | { |
1078 | auto classScope = os.scope(); |
1079 | os.printReindented(str: R"( |
1080 | ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, |
1081 | ::mlir::PatternRewriter &rewriter) const override {)" ) |
1082 | << '\n'; |
1083 | { |
1084 | auto functionScope = os.scope(); |
1085 | |
1086 | // Register all symbols bound in the source pattern. |
1087 | pattern.collectSourcePatternBoundSymbols(infoMap&: symbolInfoMap); |
1088 | |
1089 | LLVM_DEBUG(llvm::dbgs() |
1090 | << "start creating local variables for capturing matches\n" ); |
1091 | os << "// Variables for capturing values and attributes used while " |
1092 | "creating ops\n" ; |
1093 | // Create local variables for storing the arguments and results bound |
1094 | // to symbols. |
1095 | for (const auto &symbolInfoPair : symbolInfoMap) { |
1096 | const auto &symbol = symbolInfoPair.first; |
1097 | const auto &info = symbolInfoPair.second; |
1098 | |
1099 | os << info.getVarDecl(name: symbol); |
1100 | } |
1101 | // TODO: capture ops with consistent numbering so that it can be |
1102 | // reused for fused loc. |
1103 | os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n" ; |
1104 | LLVM_DEBUG(llvm::dbgs() |
1105 | << "done creating local variables for capturing matches\n" ); |
1106 | |
1107 | os << "// Match\n" ; |
1108 | os << "tblgen_ops.push_back(op0);\n" ; |
1109 | emitMatchLogic(tree: sourceTree, opName: "op0" ); |
1110 | |
1111 | os << "\n// Rewrite\n" ; |
1112 | emitRewriteLogic(); |
1113 | |
1114 | os << "return ::mlir::success();\n" ; |
1115 | } |
1116 | os << "};\n" ; |
1117 | } |
1118 | os << "};\n\n" ; |
1119 | } |
1120 | |
1121 | void PatternEmitter::emitRewriteLogic() { |
1122 | LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n" ); |
1123 | const Operator &rootOp = pattern.getSourceRootOp(); |
1124 | int numExpectedResults = rootOp.getNumResults(); |
1125 | int numResultPatterns = pattern.getNumResultPatterns(); |
1126 | |
1127 | // First register all symbols bound to ops generated in result patterns. |
1128 | pattern.collectResultPatternBoundSymbols(infoMap&: symbolInfoMap); |
1129 | |
1130 | // Only the last N static values generated are used to replace the matched |
1131 | // root N-result op. We need to calculate the starting index (of the results |
1132 | // of the matched op) each result pattern is to replace. |
1133 | SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); |
1134 | // If we don't need to replace any value at all, set the replacement starting |
1135 | // index as the number of result patterns so we skip all of them when trying |
1136 | // to replace the matched op's results. |
1137 | int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; |
1138 | for (int i = numResultPatterns - 1; i >= 0; --i) { |
1139 | auto numValues = getNodeValueCount(node: pattern.getResultPattern(index: i)); |
1140 | offsets[i] = offsets[i + 1] - numValues; |
1141 | if (offsets[i] == 0) { |
1142 | if (replStartIndex == -1) |
1143 | replStartIndex = i; |
1144 | } else if (offsets[i] < 0 && offsets[i + 1] > 0) { |
1145 | auto error = formatv( |
1146 | Fmt: "cannot use the same multi-result op '{0}' to generate both " |
1147 | "auxiliary values and values to be used for replacing the matched op" , |
1148 | Vals: pattern.getResultPattern(index: i).getSymbol()); |
1149 | PrintFatalError(ErrorLoc: loc, Msg: error); |
1150 | } |
1151 | } |
1152 | |
1153 | if (offsets.front() > 0) { |
1154 | const char error[] = |
1155 | "not enough values generated to replace the matched op" ; |
1156 | PrintFatalError(ErrorLoc: loc, Msg: error); |
1157 | } |
1158 | |
1159 | os << "auto odsLoc = rewriter.getFusedLoc({" ; |
1160 | for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { |
1161 | os << (i ? ", " : "" ) << "tblgen_ops[" << i << "]->getLoc()" ; |
1162 | } |
1163 | os << "}); (void)odsLoc;\n" ; |
1164 | |
1165 | // Process auxiliary result patterns. |
1166 | for (int i = 0; i < replStartIndex; ++i) { |
1167 | DagNode resultTree = pattern.getResultPattern(index: i); |
1168 | auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0); |
1169 | // Normal op creation will be streamed to `os` by the above call; but |
1170 | // NativeCodeCall will only be materialized to `os` if it is used. Here |
1171 | // we are handling auxiliary patterns so we want the side effect even if |
1172 | // NativeCodeCall is not replacing matched root op's results. |
1173 | if (resultTree.isNativeCodeCall() && |
1174 | resultTree.getNumReturnsOfNativeCode() == 0) |
1175 | os << val << ";\n" ; |
1176 | } |
1177 | |
1178 | auto processSupplementalPatterns = [&]() { |
1179 | int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); |
1180 | for (int i = 0, offset = -numSupplementalPatterns; |
1181 | i < numSupplementalPatterns; ++i) { |
1182 | DagNode resultTree = pattern.getSupplementalPattern(index: i); |
1183 | auto val = handleResultPattern(resultTree, resultIndex: offset++, depth: 0); |
1184 | if (resultTree.isNativeCodeCall() && |
1185 | resultTree.getNumReturnsOfNativeCode() == 0) |
1186 | os << val << ";\n" ; |
1187 | } |
1188 | }; |
1189 | |
1190 | if (numExpectedResults == 0) { |
1191 | assert(replStartIndex >= numResultPatterns && |
1192 | "invalid auxiliary vs. replacement pattern division!" ); |
1193 | processSupplementalPatterns(); |
1194 | // No result to replace. Just erase the op. |
1195 | os << "rewriter.eraseOp(op0);\n" ; |
1196 | } else { |
1197 | // Process replacement result patterns. |
1198 | os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n" ; |
1199 | for (int i = replStartIndex; i < numResultPatterns; ++i) { |
1200 | DagNode resultTree = pattern.getResultPattern(index: i); |
1201 | auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0); |
1202 | os << "\n" ; |
1203 | // Resolve each symbol for all range use so that we can loop over them. |
1204 | // We need an explicit cast to `SmallVector` to capture the cases where |
1205 | // `{0}` resolves to an `Operation::result_range` as well as cases that |
1206 | // are not iterable (e.g. vector that gets wrapped in additional braces by |
1207 | // RewriterGen). |
1208 | // TODO: Revisit the need for materializing a vector. |
1209 | os << symbolInfoMap.getAllRangeUse( |
1210 | symbol: val, |
1211 | fmt: "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" |
1212 | " tblgen_repl_values.push_back(v);\n}\n" , |
1213 | separator: "\n" ); |
1214 | } |
1215 | processSupplementalPatterns(); |
1216 | os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n" ; |
1217 | } |
1218 | |
1219 | LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n" ); |
1220 | } |
1221 | |
1222 | std::string PatternEmitter::getUniqueSymbol(const Operator *op) { |
1223 | return std::string( |
1224 | formatv(Fmt: "tblgen_{0}_{1}" , Vals: op->getCppClassName(), Vals: nextValueId++)); |
1225 | } |
1226 | |
1227 | std::string PatternEmitter::handleResultPattern(DagNode resultTree, |
1228 | int resultIndex, int depth) { |
1229 | LLVM_DEBUG(llvm::dbgs() << "handle result pattern: " ); |
1230 | LLVM_DEBUG(resultTree.print(llvm::dbgs())); |
1231 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
1232 | |
1233 | if (resultTree.isLocationDirective()) { |
1234 | PrintFatalError(ErrorLoc: loc, |
1235 | Msg: "location directive can only be used with op creation" ); |
1236 | } |
1237 | |
1238 | if (resultTree.isNativeCodeCall()) |
1239 | return handleReplaceWithNativeCodeCall(resultTree, depth); |
1240 | |
1241 | if (resultTree.isReplaceWithValue()) |
1242 | return handleReplaceWithValue(tree: resultTree).str(); |
1243 | |
1244 | // Normal op creation. |
1245 | auto symbol = handleOpCreation(tree: resultTree, resultIndex, depth); |
1246 | if (resultTree.getSymbol().empty()) { |
1247 | // This is an op not explicitly bound to a symbol in the rewrite rule. |
1248 | // Register the auto-generated symbol for it. |
1249 | symbolInfoMap.bindOpResult(symbol, op: pattern.getDialectOp(node: resultTree)); |
1250 | } |
1251 | return symbol; |
1252 | } |
1253 | |
1254 | StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { |
1255 | assert(tree.isReplaceWithValue()); |
1256 | |
1257 | if (tree.getNumArgs() != 1) { |
1258 | PrintFatalError( |
1259 | ErrorLoc: loc, Msg: "replaceWithValue directive must take exactly one argument" ); |
1260 | } |
1261 | |
1262 | if (!tree.getSymbol().empty()) { |
1263 | PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to replaceWithValue" ); |
1264 | } |
1265 | |
1266 | return tree.getArgName(index: 0); |
1267 | } |
1268 | |
1269 | std::string PatternEmitter::handleLocationDirective(DagNode tree) { |
1270 | assert(tree.isLocationDirective()); |
1271 | auto lookUpArgLoc = [this, &tree](int idx) { |
1272 | const auto *const lookupFmt = "{0}.getLoc()" ; |
1273 | return symbolInfoMap.getValueAndRangeUse(symbol: tree.getArgName(index: idx), fmt: lookupFmt); |
1274 | }; |
1275 | |
1276 | if (tree.getNumArgs() == 0) |
1277 | llvm::PrintFatalError( |
1278 | Msg: "At least one argument to location directive required" ); |
1279 | |
1280 | if (!tree.getSymbol().empty()) |
1281 | PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to location" ); |
1282 | |
1283 | if (tree.getNumArgs() == 1) { |
1284 | DagLeaf leaf = tree.getArgAsLeaf(index: 0); |
1285 | if (leaf.isStringAttr()) |
1286 | return formatv(Fmt: "::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))" , |
1287 | Vals: leaf.getStringAttr()) |
1288 | .str(); |
1289 | return lookUpArgLoc(0); |
1290 | } |
1291 | |
1292 | std::string ret; |
1293 | llvm::raw_string_ostream os(ret); |
1294 | std::string strAttr; |
1295 | os << "rewriter.getFusedLoc({" ; |
1296 | bool first = true; |
1297 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
1298 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
1299 | // Handle the optional string value. |
1300 | if (leaf.isStringAttr()) { |
1301 | if (!strAttr.empty()) |
1302 | llvm::PrintFatalError(Msg: "Only one string attribute may be specified" ); |
1303 | strAttr = leaf.getStringAttr(); |
1304 | continue; |
1305 | } |
1306 | os << (first ? "" : ", " ) << lookUpArgLoc(i); |
1307 | first = false; |
1308 | } |
1309 | os << "}" ; |
1310 | if (!strAttr.empty()) { |
1311 | os << ", rewriter.getStringAttr(\"" << strAttr << "\")" ; |
1312 | } |
1313 | os << ")" ; |
1314 | return os.str(); |
1315 | } |
1316 | |
1317 | std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, |
1318 | int depth) { |
1319 | // Nested NativeCodeCall. |
1320 | if (auto dagNode = returnType.getArgAsNestedDag(index: i)) { |
1321 | if (!dagNode.isNativeCodeCall()) |
1322 | PrintFatalError(ErrorLoc: loc, Msg: "nested DAG in `returnType` must be a native code " |
1323 | "call" ); |
1324 | return handleReplaceWithNativeCodeCall(resultTree: dagNode, depth); |
1325 | } |
1326 | // String literal. |
1327 | auto dagLeaf = returnType.getArgAsLeaf(index: i); |
1328 | if (dagLeaf.isStringAttr()) |
1329 | return tgfmt(fmt: dagLeaf.getStringAttr(), ctx: &fmtCtx); |
1330 | return tgfmt( |
1331 | fmt: "$0.getType()" , ctx: &fmtCtx, |
1332 | vals: handleOpArgument(leaf: returnType.getArgAsLeaf(index: i), patArgName: returnType.getArgName(index: i))); |
1333 | } |
1334 | |
1335 | std::string PatternEmitter::handleOpArgument(DagLeaf leaf, |
1336 | StringRef patArgName) { |
1337 | if (leaf.isStringAttr()) |
1338 | PrintFatalError(ErrorLoc: loc, Msg: "raw string not supported as argument" ); |
1339 | if (leaf.isConstantAttr()) { |
1340 | auto constAttr = leaf.getAsConstantAttr(); |
1341 | return handleConstantAttr(attr: constAttr.getAttribute(), |
1342 | value: constAttr.getConstantValue()); |
1343 | } |
1344 | if (leaf.isEnumAttrCase()) { |
1345 | auto enumCase = leaf.getAsEnumAttrCase(); |
1346 | // This is an enum case backed by an IntegerAttr. We need to get its value |
1347 | // to build the constant. |
1348 | std::string val = std::to_string(val: enumCase.getValue()); |
1349 | return handleConstantAttr(attr: enumCase, value: val); |
1350 | } |
1351 | |
1352 | LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n" ); |
1353 | auto argName = symbolInfoMap.getValueAndRangeUse(symbol: patArgName); |
1354 | if (leaf.isUnspecified() || leaf.isOperandMatcher()) { |
1355 | LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName |
1356 | << "' (via symbol ref)\n" ); |
1357 | return argName; |
1358 | } |
1359 | if (leaf.isNativeCodeCall()) { |
1360 | auto repl = tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: argName)); |
1361 | LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl |
1362 | << "' (via NativeCodeCall)\n" ); |
1363 | return std::string(repl); |
1364 | } |
1365 | PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when rewriting op" ); |
1366 | } |
1367 | |
1368 | std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, |
1369 | int depth) { |
1370 | LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: " ); |
1371 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
1372 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
1373 | |
1374 | auto fmt = tree.getNativeCodeTemplate(); |
1375 | |
1376 | SmallVector<std::string, 16> attrs; |
1377 | |
1378 | auto tail = getTrailingDirectives(tree); |
1379 | if (tail.returnType) |
1380 | PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier" ); |
1381 | auto locToUse = getLocation(tail); |
1382 | |
1383 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
1384 | if (tree.isNestedDagArg(index: i)) { |
1385 | attrs.push_back( |
1386 | Elt: handleResultPattern(resultTree: tree.getArgAsNestedDag(index: i), resultIndex: i, depth: depth + 1)); |
1387 | } else { |
1388 | attrs.push_back( |
1389 | Elt: handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i))); |
1390 | } |
1391 | LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i |
1392 | << " replacement: " << attrs[i] << "\n" ); |
1393 | } |
1394 | |
1395 | std::string symbol = tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc" , subst: locToUse), |
1396 | params: static_cast<ArrayRef<std::string>>(attrs)); |
1397 | |
1398 | // In general, NativeCodeCall without naming binding don't need this. To |
1399 | // ensure void helper function has been correctly labeled, i.e., use |
1400 | // NativeCodeCallVoid, we cache the result to a local variable so that we will |
1401 | // get a compilation error in the auto-generated file. |
1402 | // Example. |
1403 | // // In the td file |
1404 | // Pat<(...), (NativeCodeCall<Foo> ...)> |
1405 | // |
1406 | // --- |
1407 | // |
1408 | // // In the auto-generated .cpp |
1409 | // ... |
1410 | // // Causes compilation error if Foo() returns void. |
1411 | // auto nativeVar = Foo(); |
1412 | // ... |
1413 | if (tree.getNumReturnsOfNativeCode() != 0) { |
1414 | // Determine the local variable name for return value. |
1415 | std::string varName = |
1416 | SymbolInfoMap::getValuePackName(symbol: tree.getSymbol()).str(); |
1417 | if (varName.empty()) { |
1418 | varName = formatv(Fmt: "nativeVar_{0}" , Vals: nextValueId++); |
1419 | // Register the local variable for later uses. |
1420 | symbolInfoMap.bindValues(symbol: varName, numValues: tree.getNumReturnsOfNativeCode()); |
1421 | } |
1422 | |
1423 | // Catch the return value of helper function. |
1424 | os << formatv(Fmt: "auto {0} = {1}; (void){0};\n" , Vals&: varName, Vals&: symbol); |
1425 | |
1426 | if (!tree.getSymbol().empty()) |
1427 | symbol = tree.getSymbol().str(); |
1428 | else |
1429 | symbol = varName; |
1430 | } |
1431 | |
1432 | return symbol; |
1433 | } |
1434 | |
1435 | int PatternEmitter::getNodeValueCount(DagNode node) { |
1436 | if (node.isOperation()) { |
1437 | // If the op is bound to a symbol in the rewrite rule, query its result |
1438 | // count from the symbol info map. |
1439 | auto symbol = node.getSymbol(); |
1440 | if (!symbol.empty()) { |
1441 | return symbolInfoMap.getStaticValueCount(symbol); |
1442 | } |
1443 | // Otherwise this is an unbound op; we will use all its results. |
1444 | return pattern.getDialectOp(node).getNumResults(); |
1445 | } |
1446 | |
1447 | if (node.isNativeCodeCall()) |
1448 | return node.getNumReturnsOfNativeCode(); |
1449 | |
1450 | return 1; |
1451 | } |
1452 | |
1453 | PatternEmitter::TrailingDirectives |
1454 | PatternEmitter::getTrailingDirectives(DagNode tree) { |
1455 | TrailingDirectives tail = {.location: DagNode(nullptr), .returnType: DagNode(nullptr), .numDirectives: 0}; |
1456 | |
1457 | // Look backwards through the arguments. |
1458 | auto numPatArgs = tree.getNumArgs(); |
1459 | for (int i = numPatArgs - 1; i >= 0; --i) { |
1460 | auto dagArg = tree.getArgAsNestedDag(index: i); |
1461 | // A leaf is not a directive. Stop looking. |
1462 | if (!dagArg) |
1463 | break; |
1464 | |
1465 | auto isLocation = dagArg.isLocationDirective(); |
1466 | auto isReturnType = dagArg.isReturnTypeDirective(); |
1467 | // If encountered a DAG node that isn't a trailing directive, stop looking. |
1468 | if (!(isLocation || isReturnType)) |
1469 | break; |
1470 | // Save the directive, but error if one of the same type was already |
1471 | // found. |
1472 | ++tail.numDirectives; |
1473 | if (isLocation) { |
1474 | if (tail.location) |
1475 | PrintFatalError(ErrorLoc: loc, Msg: "`location` directive can only be specified " |
1476 | "once" ); |
1477 | tail.location = dagArg; |
1478 | } else if (isReturnType) { |
1479 | if (tail.returnType) |
1480 | PrintFatalError(ErrorLoc: loc, Msg: "`returnType` directive can only be specified " |
1481 | "once" ); |
1482 | tail.returnType = dagArg; |
1483 | } |
1484 | } |
1485 | |
1486 | return tail; |
1487 | } |
1488 | |
1489 | std::string |
1490 | PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { |
1491 | if (tail.location) |
1492 | return handleLocationDirective(tree: tail.location); |
1493 | |
1494 | // If no explicit location is given, use the default, all fused, location. |
1495 | return "odsLoc" ; |
1496 | } |
1497 | |
/// Emits C++ code that creates the op described by the result-pattern node
/// `tree`, and returns the symbol that refers to the created value(s).
///
/// `resultIndex` is the index of the source root op's result that this node
/// replaces. A negative value marks an auxiliary op (see the second builder
/// path below). `depth` is the nesting level of `tree` in the result pattern:
/// nested op arguments are handled recursively with `depth + 1`.
///
/// The emitted code declares a local variable of the op's C++ class and then
/// opens a `{` block. Inside that block, the builder arguments are prepared
/// and the op is created. Every return path below closes that block. Which
/// builder is called depends on:
/// * whether the result types can be deduced from op traits,
/// * whether an explicit `returnType` directive was given,
/// * where the op sits in the result pattern.
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();
  auto numPatArgs = tree.getNumArgs();

  // Trailing `location` / `returnType` directives are not op arguments; peel
  // them off before checking the argument count.
  auto tail = getTrailingDirectives(tree);
  auto locToUse = getLocation(tail);

  // The remaining pattern arguments must match the op's ODS arguments
  // one-to-one.
  auto inPattern = numPatArgs - tail.numDirectives;
  if (numOpArgs != inPattern) {
    PrintFatalError(loc,
                    formatv("resultant op '{0}' argument number mismatch: "
                            "{1} in pattern vs. {2} in definition",
                            resultOp.getOperationName(), inPattern, numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  ChildNodeIndexNameMap childNodeNames;

  // If the op's argument is a type constraint, then it's an operand. If that
  // operand is variadic, check that the argument in the pattern is variadic
  // too.
  auto checkIfMatchedVariadic = [&](int i) {
    // FIXME: This does not yet check for variable/leaf case.
    // FIXME: Change so that native code call can be handled.
    const auto *operand =
        llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i));
    if (!operand || !operand->isVariadic())
      return;

    auto child = tree.getArgAsNestedDag(i);
    if (!child)
      return;

    // Skip over replaceWithValues.
    while (child.isReplaceWithValue()) {
      if (!(child = child.getArgAsNestedDag(0)))
        return;
    }
    if (!child.isNativeCodeCall() && !child.isVariadic())
      PrintFatalError(loc, formatv("op expects variadic operand `{0}`, while "
                                   "provided is non-variadic",
                                   resultOp.getArgName(i)));
  };

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    checkIfMatchedVariadic(i);
    if (auto child = tree.getArgAsNestedDag(i))
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(&resultOp);
  } else {
    resultValue = std::string(tree.getSymbol());
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
  }

  // Create the local variable for this op. `{{` is formatv's escape for a
  // literal `{`, which opens the block that each return path below closes.
  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
                valuePackName);

  // Right now ODS doesn't have general type inference support. Except for a
  // few special cases listed below, DRR needs to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
  bool useFirstAttr =
      resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");

  if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
    // We know how to deduce the result type for ops with these traits and we've
    // generated builders taking aggregate parameters. Use those builders to
    // create the ops.

    // First prepare local variables for op arguments used in builder call.
    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

    // Then create the op. The temporary scope's close string "\n}\n"
    // presumably gets written when the scope is destroyed at the end of this
    // statement, closing the block opened above.
    os.scope("", "\n}\n").os << formatv(
        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
        valuePackName, resultOp.getQualCppClassName(), locToUse);
    return resultValue;
  }

  // True when the pattern binds only some results via a `__N` suffix.
  bool usePartialResults = valuePackName != resultValue;

  if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
    // For these cases (broadcastable ops, op results used both as auxiliary
    // values and replacement values, ops in nested patterns, auxiliary ops), we
    // still need to supply the result types when building the op. But because
    // we don't generate a builder automatically with ODS for them, it's the
    // developer's responsibility to make sure such a builder (with result type
    // deduction ability) exists. We go through the separate-parameter builder
    // here given that it's easier for developers to write compared to
    // aggregate-parameter builders.
    createSeparateLocalVarsForOpArgs(tree, childNodeNames);

    os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
                             resultOp.getQualCppClassName(), locToUse);
    supplyValuesForOpArgs(tree, childNodeNames, depth);
    // Close the builder call and the block opened above.
    os << "\n );\n}\n";
    return resultValue;
  }

  // If we are provided explicit return types, use them to build the op.
  // However, if depth == 0 and resultIndex >= 0, it means we are replacing
  // the values generated from the source pattern root op. Then we must use the
  // source pattern's value types to determine the value type of the generated
  // op here.
  if (depth == 0 && resultIndex >= 0 && tail.returnType)
    PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
                         "return values replace the source pattern's root op");

  // First prepare local variables for op arguments used in builder call.
  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

  // Then prepare the result types. We need to specify the types for all
  // results.
  os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
                         "(void)tblgen_types;\n");
  int numResults = resultOp.getNumResults();
  if (tail.returnType) {
    // Each `returnType` directive argument contributes one result type.
    auto numRetTys = tail.returnType.getNumArgs();
    for (int i = 0; i < numRetTys; ++i) {
      auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
      os << "tblgen_types.push_back(" << varName << ");\n";
    }
  } else {
    if (numResults != 0) {
      // Copy the result types from the source pattern (the generated code reads
      // them from `castedOp0`, starting at result group `resultIndex`).
      for (int i = 0; i < numResults; ++i)
        os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
                      " tblgen_types.push_back(v.getType());\n}\n",
                      resultIndex + i);
    }
  }
  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                "tblgen_values, tblgen_attrs);\n",
                valuePackName, resultOp.getQualCppClassName(), locToUse);
  // Restore the indentation and close the block opened above.
  os.unindent() << "}\n";
  return resultValue;
}
1660 | |
1661 | void PatternEmitter::createSeparateLocalVarsForOpArgs( |
1662 | DagNode node, ChildNodeIndexNameMap &childNodeNames) { |
1663 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
1664 | |
1665 | // Now prepare operands used for building this op: |
1666 | // * If the operand is non-variadic, we create a `Value` local variable. |
1667 | // * If the operand is variadic, we create a `SmallVector<Value>` local |
1668 | // variable. |
1669 | |
1670 | int valueIndex = 0; // An index for uniquing local variable names. |
1671 | for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { |
1672 | const auto *operand = |
1673 | llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex)); |
1674 | // We do not need special handling for attributes. |
1675 | if (!operand) |
1676 | continue; |
1677 | |
1678 | raw_indented_ostream::DelimitedScope scope(os); |
1679 | std::string varName; |
1680 | if (operand->isVariadic()) { |
1681 | varName = std::string(formatv(Fmt: "tblgen_values_{0}" , Vals: valueIndex++)); |
1682 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> {0};\n" , Vals&: varName); |
1683 | std::string range; |
1684 | if (node.isNestedDagArg(index: argIndex)) { |
1685 | range = childNodeNames[argIndex]; |
1686 | } else { |
1687 | range = std::string(node.getArgName(index: argIndex)); |
1688 | } |
1689 | // Resolve the symbol for all range use so that we have a uniform way of |
1690 | // capturing the values. |
1691 | range = symbolInfoMap.getValueAndRangeUse(symbol: range); |
1692 | os << formatv(Fmt: "for (auto v: {0}) {{\n {1}.push_back(v);\n}\n" , Vals&: range, |
1693 | Vals&: varName); |
1694 | } else { |
1695 | varName = std::string(formatv(Fmt: "tblgen_value_{0}" , Vals: valueIndex++)); |
1696 | os << formatv(Fmt: "::mlir::Value {0} = " , Vals&: varName); |
1697 | if (node.isNestedDagArg(index: argIndex)) { |
1698 | os << symbolInfoMap.getValueAndRangeUse(symbol: childNodeNames[argIndex]); |
1699 | } else { |
1700 | DagLeaf leaf = node.getArgAsLeaf(index: argIndex); |
1701 | auto symbol = |
1702 | symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex)); |
1703 | if (leaf.isNativeCodeCall()) { |
1704 | os << std::string( |
1705 | tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol))); |
1706 | } else { |
1707 | os << symbol; |
1708 | } |
1709 | } |
1710 | os << ";\n" ; |
1711 | } |
1712 | |
1713 | // Update to use the newly created local variable for building the op later. |
1714 | childNodeNames[argIndex] = varName; |
1715 | } |
1716 | } |
1717 | |
1718 | void PatternEmitter::supplyValuesForOpArgs( |
1719 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { |
1720 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
1721 | for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); |
1722 | argIndex != numOpArgs; ++argIndex) { |
1723 | // Start each argument on its own line. |
1724 | os << ",\n " ; |
1725 | |
1726 | Argument opArg = resultOp.getArg(index: argIndex); |
1727 | // Handle the case of operand first. |
1728 | if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) { |
1729 | if (!operand->name.empty()) |
1730 | os << "/*" << operand->name << "=*/" ; |
1731 | os << childNodeNames.lookup(Val: argIndex); |
1732 | continue; |
1733 | } |
1734 | |
1735 | // The argument in the op definition. |
1736 | auto opArgName = resultOp.getArgName(index: argIndex); |
1737 | if (auto subTree = node.getArgAsNestedDag(index: argIndex)) { |
1738 | if (!subTree.isNativeCodeCall()) |
1739 | PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node " |
1740 | "for creating attribute" ); |
1741 | os << formatv(Fmt: "/*{0}=*/{1}" , Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex)); |
1742 | } else { |
1743 | auto leaf = node.getArgAsLeaf(index: argIndex); |
1744 | // The argument in the result DAG pattern. |
1745 | auto patArgName = node.getArgName(index: argIndex); |
1746 | if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { |
1747 | // TODO: Refactor out into map to avoid recomputing these. |
1748 | if (!opArg.is<NamedAttribute *>()) |
1749 | PrintFatalError(ErrorLoc: loc, Msg: Twine("expected attribute " ) + Twine(argIndex)); |
1750 | if (!patArgName.empty()) |
1751 | os << "/*" << patArgName << "=*/" ; |
1752 | } else { |
1753 | os << "/*" << opArgName << "=*/" ; |
1754 | } |
1755 | os << handleOpArgument(leaf, patArgName); |
1756 | } |
1757 | } |
1758 | } |
1759 | |
1760 | void PatternEmitter::createAggregateLocalVarsForOpArgs( |
1761 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { |
1762 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
1763 | |
1764 | auto scope = os.scope(); |
1765 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> " |
1766 | "tblgen_values; (void)tblgen_values;\n" ); |
1767 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::NamedAttribute, 4> " |
1768 | "tblgen_attrs; (void)tblgen_attrs;\n" ); |
1769 | |
1770 | const char *addAttrCmd = |
1771 | "if (auto tmpAttr = {1}) {\n" |
1772 | " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " |
1773 | "tmpAttr);\n}\n" ; |
1774 | int numVariadic = 0; |
1775 | bool hasOperandSegmentSizes = false; |
1776 | std::vector<std::string> sizes; |
1777 | for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { |
1778 | if (resultOp.getArg(index: argIndex).is<NamedAttribute *>()) { |
1779 | // The argument in the op definition. |
1780 | auto opArgName = resultOp.getArgName(index: argIndex); |
1781 | hasOperandSegmentSizes = |
1782 | hasOperandSegmentSizes || opArgName == "operandSegmentSizes" ; |
1783 | if (auto subTree = node.getArgAsNestedDag(index: argIndex)) { |
1784 | if (!subTree.isNativeCodeCall()) |
1785 | PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node " |
1786 | "for creating attribute" ); |
1787 | os << formatv(Fmt: addAttrCmd, Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex)); |
1788 | } else { |
1789 | auto leaf = node.getArgAsLeaf(index: argIndex); |
1790 | // The argument in the result DAG pattern. |
1791 | auto patArgName = node.getArgName(index: argIndex); |
1792 | os << formatv(Fmt: addAttrCmd, Vals&: opArgName, |
1793 | Vals: handleOpArgument(leaf, patArgName)); |
1794 | } |
1795 | continue; |
1796 | } |
1797 | |
1798 | const auto *operand = |
1799 | resultOp.getArg(index: argIndex).get<NamedTypeConstraint *>(); |
1800 | std::string varName; |
1801 | if (operand->isVariadic()) { |
1802 | ++numVariadic; |
1803 | std::string range; |
1804 | if (node.isNestedDagArg(index: argIndex)) { |
1805 | range = childNodeNames.lookup(Val: argIndex); |
1806 | } else { |
1807 | range = std::string(node.getArgName(index: argIndex)); |
1808 | } |
1809 | // Resolve the symbol for all range use so that we have a uniform way of |
1810 | // capturing the values. |
1811 | range = symbolInfoMap.getValueAndRangeUse(symbol: range); |
1812 | os << formatv(Fmt: "for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n" , |
1813 | Vals&: range); |
1814 | sizes.push_back(x: formatv(Fmt: "static_cast<int32_t>({0}.size())" , Vals&: range)); |
1815 | } else { |
1816 | sizes.emplace_back(args: "1" ); |
1817 | os << formatv(Fmt: "tblgen_values.push_back(" ); |
1818 | if (node.isNestedDagArg(index: argIndex)) { |
1819 | os << symbolInfoMap.getValueAndRangeUse( |
1820 | symbol: childNodeNames.lookup(Val: argIndex)); |
1821 | } else { |
1822 | DagLeaf leaf = node.getArgAsLeaf(index: argIndex); |
1823 | if (leaf.isConstantAttr()) |
1824 | // TODO: Use better location |
1825 | PrintFatalError( |
1826 | ErrorLoc: loc, |
1827 | Msg: "attribute found where value was expected, if attempting to use " |
1828 | "constant value, construct a constant op with given attribute " |
1829 | "instead" ); |
1830 | |
1831 | auto symbol = |
1832 | symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex)); |
1833 | if (leaf.isNativeCodeCall()) { |
1834 | os << std::string( |
1835 | tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol))); |
1836 | } else { |
1837 | os << symbol; |
1838 | } |
1839 | } |
1840 | os << ");\n" ; |
1841 | } |
1842 | } |
1843 | |
1844 | if (numVariadic > 1 && !hasOperandSegmentSizes) { |
1845 | // Only set size if it can't be computed. |
1846 | const auto *sameVariadicSize = |
1847 | resultOp.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize" ); |
1848 | if (!sameVariadicSize) { |
1849 | const char *setSizes = R"( |
1850 | tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), |
1851 | rewriter.getDenseI32ArrayAttr({{ {0} })); |
1852 | )" ; |
1853 | os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", " )).str()); |
1854 | } |
1855 | } |
1856 | } |
1857 | |
1858 | StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, |
1859 | const RecordKeeper &recordKeeper, |
1860 | RecordOperatorMap &mapper) |
1861 | : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {} |
1862 | |
1863 | void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { |
1864 | // PatternEmitter will use the static matcher if there's one generated. To |
1865 | // ensure that all the dependent static matchers are generated before emitting |
1866 | // the matching logic of the DagNode, we use topological order to achieve it. |
1867 | for (auto &dagInfo : topologicalOrder) { |
1868 | DagNode node = dagInfo.first; |
1869 | if (!useStaticMatcher(node)) |
1870 | continue; |
1871 | |
1872 | std::string funcName = |
1873 | formatv(Fmt: "static_dag_matcher_{0}" , Vals: staticMatcherCounter++); |
1874 | assert(!matcherNames.contains(node)); |
1875 | PatternEmitter(dagInfo.second, &opMap, os, *this) |
1876 | .emitStaticMatcher(tree: node, funcName); |
1877 | matcherNames[node] = funcName; |
1878 | } |
1879 | } |
1880 | |
1881 | void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { |
1882 | staticVerifierEmitter.emitPatternConstraints(constraints: constraints.getArrayRef()); |
1883 | } |
1884 | |
1885 | void StaticMatcherHelper::addPattern(Record *record) { |
1886 | Pattern pat(record, &opMap); |
1887 | |
1888 | // While generating the function body of the DAG matcher, it may depends on |
1889 | // other DAG matchers. To ensure the dependent matchers are ready, we compute |
1890 | // the topological order for all the DAGs and emit the DAG matchers in this |
1891 | // order. |
1892 | llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) { |
1893 | ++refStats[node]; |
1894 | |
1895 | if (refStats[node] != 1) |
1896 | return; |
1897 | |
1898 | for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i) |
1899 | if (DagNode sibling = node.getArgAsNestedDag(index: i)) |
1900 | dfs(sibling); |
1901 | else { |
1902 | DagLeaf leaf = node.getArgAsLeaf(index: i); |
1903 | if (!leaf.isUnspecified()) |
1904 | constraints.insert(X: leaf); |
1905 | } |
1906 | |
1907 | topologicalOrder.push_back(Elt: std::make_pair(x&: node, y&: record)); |
1908 | }; |
1909 | |
1910 | dfs(pat.getSourcePattern()); |
1911 | } |
1912 | |
1913 | StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { |
1914 | if (leaf.isAttrMatcher()) { |
1915 | std::optional<StringRef> constraint = |
1916 | staticVerifierEmitter.getAttrConstraintFn(constraint: leaf.getAsConstraint()); |
1917 | assert(constraint && "attribute constraint was not uniqued" ); |
1918 | return *constraint; |
1919 | } |
1920 | assert(leaf.isOperandMatcher()); |
1921 | return staticVerifierEmitter.getTypeConstraintFn(constraint: leaf.getAsConstraint()); |
1922 | } |
1923 | |
1924 | static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { |
1925 | emitSourceFileHeader(Desc: "Rewriters" , OS&: os, Record: recordKeeper); |
1926 | |
1927 | const auto &patterns = recordKeeper.getAllDerivedDefinitions(ClassName: "Pattern" ); |
1928 | |
1929 | // We put the map here because it can be shared among multiple patterns. |
1930 | RecordOperatorMap recordOpMap; |
1931 | |
1932 | // Exam all the patterns and generate static matcher for the duplicated |
1933 | // DagNode. |
1934 | StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap); |
1935 | for (Record *p : patterns) |
1936 | staticMatcher.addPattern(record: p); |
1937 | staticMatcher.populateStaticConstraintFunctions(os); |
1938 | staticMatcher.populateStaticMatchers(os); |
1939 | |
1940 | std::vector<std::string> rewriterNames; |
1941 | rewriterNames.reserve(n: patterns.size()); |
1942 | |
1943 | std::string baseRewriterName = "GeneratedConvert" ; |
1944 | int rewriterIndex = 0; |
1945 | |
1946 | for (Record *p : patterns) { |
1947 | std::string name; |
1948 | if (p->isAnonymous()) { |
1949 | // If no name is provided, ensure unique rewriter names simply by |
1950 | // appending unique suffix. |
1951 | name = baseRewriterName + llvm::utostr(X: rewriterIndex++); |
1952 | } else { |
1953 | name = std::string(p->getName()); |
1954 | } |
1955 | LLVM_DEBUG(llvm::dbgs() |
1956 | << "=== start generating pattern '" << name << "' ===\n" ); |
1957 | PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(rewriteName: name); |
1958 | LLVM_DEBUG(llvm::dbgs() |
1959 | << "=== done generating pattern '" << name << "' ===\n" ); |
1960 | rewriterNames.push_back(x: std::move(name)); |
1961 | } |
1962 | |
1963 | // Emit function to add the generated matchers to the pattern list. |
1964 | os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" |
1965 | "::mlir::RewritePatternSet &patterns) {\n" ; |
1966 | for (const auto &name : rewriterNames) { |
1967 | os << " patterns.add<" << name << ">(patterns.getContext());\n" ; |
1968 | } |
1969 | os << "}\n" ; |
1970 | } |
1971 | |
1972 | static mlir::GenRegistration |
1973 | genRewriters("gen-rewriters" , "Generate pattern rewriters" , |
1974 | [](const RecordKeeper &records, raw_ostream &os) { |
1975 | emitRewriters(recordKeeper: records, os); |
1976 | return false; |
1977 | }); |
1978 | |