1 | //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Support/IndentedOstream.h" |
14 | #include "mlir/TableGen/Argument.h" |
15 | #include "mlir/TableGen/Attribute.h" |
16 | #include "mlir/TableGen/CodeGenHelpers.h" |
17 | #include "mlir/TableGen/Format.h" |
18 | #include "mlir/TableGen/GenInfo.h" |
19 | #include "mlir/TableGen/Operator.h" |
20 | #include "mlir/TableGen/Pattern.h" |
21 | #include "mlir/TableGen/Predicate.h" |
22 | #include "mlir/TableGen/Property.h" |
23 | #include "mlir/TableGen/Type.h" |
24 | #include "llvm/ADT/FunctionExtras.h" |
25 | #include "llvm/ADT/SetVector.h" |
26 | #include "llvm/ADT/StringExtras.h" |
27 | #include "llvm/ADT/StringSet.h" |
28 | #include "llvm/Support/CommandLine.h" |
29 | #include "llvm/Support/Debug.h" |
30 | #include "llvm/Support/FormatAdapters.h" |
31 | #include "llvm/Support/PrettyStackTrace.h" |
32 | #include "llvm/Support/Signals.h" |
33 | #include "llvm/TableGen/Error.h" |
34 | #include "llvm/TableGen/Main.h" |
35 | #include "llvm/TableGen/Record.h" |
36 | #include "llvm/TableGen/TableGenBackend.h" |
37 | |
38 | using namespace mlir; |
39 | using namespace mlir::tblgen; |
40 | |
41 | using llvm::formatv; |
42 | using llvm::Record; |
43 | using llvm::RecordKeeper; |
44 | |
45 | #define DEBUG_TYPE "mlir-tblgen-rewritergen" |
46 | |
47 | namespace llvm { |
48 | template <> |
49 | struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { |
50 | static void format(const mlir::tblgen::Pattern::IdentifierLine &v, |
51 | raw_ostream &os, StringRef style) { |
52 | os << v.first << ":" << v.second; |
53 | } |
54 | }; |
55 | } // namespace llvm |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | // PatternEmitter |
59 | //===----------------------------------------------------------------------===// |
60 | |
61 | namespace { |
62 | |
63 | class StaticMatcherHelper; |
64 | |
65 | class PatternEmitter { |
66 | public: |
67 | PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os, |
68 | StaticMatcherHelper &helper); |
69 | |
70 | // Emits the mlir::RewritePattern struct named `rewriteName`. |
71 | void emit(StringRef rewriteName); |
72 | |
73 | // Emits the static function of DAG matcher. |
74 | void emitStaticMatcher(DagNode tree, std::string funcName); |
75 | |
76 | private: |
77 | // Emits the code for matching ops. |
78 | void emitMatchLogic(DagNode tree, StringRef opName); |
79 | |
80 | // Emits the code for rewriting ops. |
81 | void emitRewriteLogic(); |
82 | |
83 | //===--------------------------------------------------------------------===// |
84 | // Match utilities |
85 | //===--------------------------------------------------------------------===// |
86 | |
87 | // Emits C++ statements for matching the DAG structure. |
88 | void emitMatch(DagNode tree, StringRef name, int depth); |
89 | |
90 | // Emit C++ function call to static DAG matcher. |
91 | void emitStaticMatchCall(DagNode tree, StringRef name); |
92 | |
93 | // Emit C++ function call to static type/attribute constraint function. |
94 | void emitStaticVerifierCall(StringRef funcName, StringRef opName, |
95 | StringRef arg, StringRef failureStr); |
96 | |
97 | // Emits C++ statements for matching using a native code call. |
98 | void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); |
99 | |
100 | // Emits C++ statements for matching the op constrained by the given DAG |
101 | // `tree` returning the op's variable name. |
102 | void emitOpMatch(DagNode tree, StringRef opName, int depth); |
103 | |
104 | // Emits C++ statements for matching the `argIndex`-th argument of the given |
105 | // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the |
106 | // bound name and the constraint of the operand respectively. |
107 | void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, |
108 | int operandIndex, DagLeaf operandMatcher, |
109 | StringRef argName, int argIndex, |
110 | std::optional<int> variadicSubIndex); |
111 | |
112 | // Emits C++ statements for matching the operands which can be matched in |
113 | // either order. |
114 | void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, |
115 | StringRef opName, int argIndex, int &operandIndex, |
116 | int depth); |
117 | |
118 | // Emits C++ statements for matching a variadic operand. |
119 | void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, |
120 | StringRef opName, int argIndex, |
121 | int &operandIndex, int depth); |
122 | |
123 | // Emits C++ statements for matching the `argIndex`-th argument of the given |
124 | // DAG `tree` as an attribute. |
125 | void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, |
126 | int depth); |
127 | |
128 | // Emits C++ for checking a match with a corresponding match failure |
129 | // diagnostic. |
130 | void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, |
131 | const llvm::formatv_object_base &failureFmt); |
132 | |
133 | // Emits C++ for checking a match with a corresponding match failure |
134 | // diagnostics. |
135 | void emitMatchCheck(StringRef opName, const std::string &matchStr, |
136 | const std::string &failureStr); |
137 | |
138 | //===--------------------------------------------------------------------===// |
139 | // Rewrite utilities |
140 | //===--------------------------------------------------------------------===// |
141 | |
142 | // The entry point for handling a result pattern rooted at `resultTree`. This |
143 | // method dispatches to concrete handlers according to `resultTree`'s kind and |
144 | // returns a symbol representing the whole value pack. Callers are expected to |
145 | // further resolve the symbol according to the specific use case. |
146 | // |
147 | // `depth` is the nesting level of `resultTree`; 0 means top-level result |
148 | // pattern. For top-level result pattern, `resultIndex` indicates which result |
149 | // of the matched root op this pattern is intended to replace, which can be |
150 | // used to deduce the result type of the op generated from this result |
151 | // pattern. |
152 | std::string handleResultPattern(DagNode resultTree, int resultIndex, |
153 | int depth); |
154 | |
155 | // Emits the C++ statement to replace the matched DAG with a value built via |
156 | // calling native C++ code. |
157 | std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); |
158 | |
159 | // Returns the symbol of the old value serving as the replacement. |
160 | StringRef handleReplaceWithValue(DagNode tree); |
161 | |
162 | // Emits the C++ statement to replace the matched DAG with an array of |
163 | // matched values. |
164 | std::string handleVariadic(DagNode tree, int depth); |
165 | |
166 | // Trailing directives are used at the end of DAG node argument lists to |
167 | // specify additional behaviour for op matchers and creators, etc. |
168 | struct TrailingDirectives { |
169 | // DAG node containing the `location` directive. Null if there is none. |
170 | DagNode location; |
171 | |
172 | // DAG node containing the `returnType` directive. Null if there is none. |
173 | DagNode returnType; |
174 | |
175 | // Number of found trailing directives. |
176 | int numDirectives; |
177 | }; |
178 | |
179 | // Collect any trailing directives. |
180 | TrailingDirectives getTrailingDirectives(DagNode tree); |
181 | |
182 | // Returns the location value to use. |
183 | std::string getLocation(TrailingDirectives &tail); |
184 | |
185 | // Returns the location value to use. |
186 | std::string handleLocationDirective(DagNode tree); |
187 | |
188 | // Emit return type argument. |
189 | std::string handleReturnTypeArg(DagNode returnType, int i, int depth); |
190 | |
191 | // Emits the C++ statement to build a new op out of the given DAG `tree` and |
192 | // returns the variable name that this op is assigned to. If the root op in |
193 | // DAG `tree` has a specified name, the created op will be assigned to a |
194 | // variable of the given name. Otherwise, a unique name will be used as the |
195 | // result value name. |
196 | std::string handleOpCreation(DagNode tree, int resultIndex, int depth); |
197 | |
198 | using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; |
199 | |
200 | // Emits a local variable for each value and attribute to be used for creating |
201 | // an op. |
202 | void createSeparateLocalVarsForOpArgs(DagNode node, |
203 | ChildNodeIndexNameMap &childNodeNames); |
204 | |
205 | // Emits the concrete arguments used to call an op's builder. |
206 | void supplyValuesForOpArgs(DagNode node, |
207 | const ChildNodeIndexNameMap &childNodeNames, |
208 | int depth); |
209 | |
210 | // Emits the local variables for holding all values as a whole and all named |
211 | // attributes as a whole to be used for creating an op. |
212 | void createAggregateLocalVarsForOpArgs( |
213 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); |
214 | |
215 | // Returns the C++ expression to construct a constant attribute of the given |
216 | // `value` for the given attribute kind `attr`. |
217 | std::string handleConstantAttr(Attribute attr, const Twine &value); |
218 | |
219 | // Returns the C++ expression to build an argument from the given DAG `leaf`. |
220 | // `patArgName` is used to bound the argument to the source pattern. |
221 | std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); |
222 | |
223 | //===--------------------------------------------------------------------===// |
224 | // General utilities |
225 | //===--------------------------------------------------------------------===// |
226 | |
227 | // Collects all of the operations within the given dag tree. |
228 | void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); |
229 | |
230 | // Returns a unique symbol for a local variable of the given `op`. |
231 | std::string getUniqueSymbol(const Operator *op); |
232 | |
233 | //===--------------------------------------------------------------------===// |
234 | // Symbol utilities |
235 | //===--------------------------------------------------------------------===// |
236 | |
237 | // Returns how many static values the given DAG `node` correspond to. |
238 | int getNodeValueCount(DagNode node); |
239 | |
240 | private: |
241 | // Pattern instantiation location followed by the location of multiclass |
242 | // prototypes used. This is intended to be used as a whole to |
243 | // PrintFatalError() on errors. |
244 | ArrayRef<SMLoc> loc; |
245 | |
246 | // Op's TableGen Record to wrapper object. |
247 | RecordOperatorMap *opMap; |
248 | |
249 | // Handy wrapper for pattern being emitted. |
250 | Pattern pattern; |
251 | |
252 | // Map for all bound symbols' info. |
253 | SymbolInfoMap symbolInfoMap; |
254 | |
255 | StaticMatcherHelper &staticMatcherHelper; |
256 | |
257 | // The next unused ID for newly created values. |
258 | unsigned nextValueId = 0; |
259 | |
260 | raw_indented_ostream os; |
261 | |
262 | // Format contexts containing placeholder substitutions. |
263 | FmtContext fmtCtx; |
264 | }; |
265 | |
266 | // Tracks DagNode's reference multiple times across patterns. Enables generating |
267 | // static matcher functions for DagNode's referenced multiple times rather than |
268 | // inlining them. |
269 | class StaticMatcherHelper { |
270 | public: |
271 | StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records, |
272 | RecordOperatorMap &mapper); |
273 | |
274 | // Determine if we should inline the match logic or delegate to a static |
275 | // function. |
276 | bool useStaticMatcher(DagNode node) { |
277 | // either/variadic node must be associated to the parentOp, thus we can't |
278 | // emit a static matcher rooted at them. |
279 | if (node.isEither() || node.isVariadic()) |
280 | return false; |
281 | |
282 | return refStats[node] > kStaticMatcherThreshold; |
283 | } |
284 | |
285 | // Get the name of the static DAG matcher function corresponding to the node. |
286 | std::string getMatcherName(DagNode node) { |
287 | assert(useStaticMatcher(node)); |
288 | return matcherNames[node]; |
289 | } |
290 | |
291 | // Get the name of static type/attribute verification function. |
292 | StringRef getVerifierName(DagLeaf leaf); |
293 | |
294 | // Collect the `Record`s, i.e., the DRR, so that we can get the information of |
295 | // the duplicated DAGs. |
296 | void addPattern(const Record *record); |
297 | |
298 | // Emit all static functions of DAG Matcher. |
299 | void populateStaticMatchers(raw_ostream &os); |
300 | |
301 | // Emit all static functions for Constraints. |
302 | void populateStaticConstraintFunctions(raw_ostream &os); |
303 | |
304 | private: |
305 | static constexpr unsigned kStaticMatcherThreshold = 1; |
306 | |
307 | // Consider two patterns as down below, |
308 | // DagNode_Root_A DagNode_Root_B |
309 | // \ \ |
310 | // DagNode_C DagNode_C |
311 | // \ \ |
312 | // DagNode_D DagNode_D |
313 | // |
314 | // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of |
315 | // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced |
316 | // multiple times so we'll have static matchers for both of them. When we're |
317 | // emitting the match logic for DagNode_C, we will check if DagNode_D has the |
318 | // static matcher generated. If so, then we'll generate a call to the |
319 | // function, inline otherwise. In this case, inlining is not what we want. As |
320 | // a result, generate the static matcher in topological order to ensure all |
321 | // the dependent static matchers are generated and we can avoid accidentally |
322 | // inlining. |
323 | // |
324 | // The topological order of all the DagNodes among all patterns. |
325 | SmallVector<std::pair<DagNode, const Record *>> topologicalOrder; |
326 | |
327 | RecordOperatorMap &opMap; |
328 | |
329 | // Records of the static function name of each DagNode |
330 | DenseMap<DagNode, std::string> matcherNames; |
331 | |
332 | // After collecting all the DagNode in each pattern, `refStats` records the |
333 | // number of users for each DagNode. We will generate the static matcher for a |
334 | // DagNode while the number of users exceeds a certain threshold. |
335 | DenseMap<DagNode, unsigned> refStats; |
336 | |
337 | // Number of static matcher generated. This is used to generate a unique name |
338 | // for each DagNode. |
339 | int staticMatcherCounter = 0; |
340 | |
341 | // The DagLeaf which contains type or attr constraint. |
342 | SetVector<DagLeaf> constraints; |
343 | |
344 | // Static type/attribute verification function emitter. |
345 | StaticVerifierFunctionEmitter staticVerifierEmitter; |
346 | }; |
347 | |
348 | } // namespace |
349 | |
350 | PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper, |
351 | raw_ostream &os, StaticMatcherHelper &helper) |
352 | : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), |
353 | symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { |
354 | fmtCtx.withBuilder(subst: "rewriter" ); |
355 | } |
356 | |
357 | std::string PatternEmitter::handleConstantAttr(Attribute attr, |
358 | const Twine &value) { |
359 | if (!attr.isConstBuildable()) |
360 | PrintFatalError(ErrorLoc: loc, Msg: "Attribute " + attr.getAttrDefName() + |
361 | " does not have the 'constBuilderCall' field" ); |
362 | |
363 | // TODO: Verify the constants here |
364 | return std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, vals: value)); |
365 | } |
366 | |
367 | void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { |
368 | os << formatv( |
369 | Fmt: "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " |
370 | "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " |
371 | "*, 4> &tblgen_ops" , |
372 | Vals&: funcName); |
373 | |
374 | // We pass the reference of the variables that need to be captured. Hence we |
375 | // need to collect all the symbols in the tree first. |
376 | pattern.collectBoundSymbols(tree, infoMap&: symbolInfoMap, /*isSrcPattern=*/true); |
377 | symbolInfoMap.assignUniqueAlternativeNames(); |
378 | for (const auto &info : symbolInfoMap) |
379 | os << formatv(Fmt: ", {0}" , Vals: info.second.getArgDecl(name: info.first)); |
380 | |
381 | os << ") {\n" ; |
382 | os.indent(); |
383 | os << "(void)tblgen_ops;\n" ; |
384 | |
385 | // Note that a static matcher is considered at least one step from the match |
386 | // entry. |
387 | emitMatch(tree, name: "op0" , /*depth=*/1); |
388 | |
389 | os << "return ::mlir::success();\n" ; |
390 | os.unindent(); |
391 | os << "}\n\n" ; |
392 | } |
393 | |
394 | // Helper function to match patterns. |
395 | void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { |
396 | if (tree.isNativeCodeCall()) { |
397 | emitNativeCodeMatch(tree, name, depth); |
398 | return; |
399 | } |
400 | |
401 | if (tree.isOperation()) { |
402 | emitOpMatch(tree, opName: name, depth); |
403 | return; |
404 | } |
405 | |
406 | PrintFatalError(ErrorLoc: loc, Msg: "encountered non-op, non-NativeCodeCall match." ); |
407 | } |
408 | |
409 | void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { |
410 | std::string funcName = staticMatcherHelper.getMatcherName(node: tree); |
411 | os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, tblgen_ops" , Vals&: funcName, |
412 | Vals&: opName); |
413 | |
414 | // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in |
415 | // one pass. |
416 | |
417 | // In general, bound symbol should have the unique name in the pattern but |
418 | // for the operand, binding same symbol to multiple operands imply a |
419 | // constraint at the same time. In this case, we will rename those operands |
420 | // with different names. As a result, we need to collect all the symbolInfos |
421 | // from the DagNode then get the updated name of the local variables from the |
422 | // global symbolInfoMap. |
423 | |
424 | // Collect all the bound symbols in the Dag |
425 | SymbolInfoMap localSymbolMap(loc); |
426 | pattern.collectBoundSymbols(tree, infoMap&: localSymbolMap, /*isSrcPattern=*/true); |
427 | |
428 | for (const auto &info : localSymbolMap) { |
429 | auto name = info.first; |
430 | auto symboInfo = info.second; |
431 | auto ret = symbolInfoMap.findBoundSymbol(key: name, symbolInfo: symboInfo); |
432 | os << formatv(Fmt: ", {0}" , Vals: ret->second.getVarName(name)); |
433 | } |
434 | |
435 | os << "))) {\n" ; |
436 | os.scope().os << "return ::mlir::failure();\n" ; |
437 | os << "}\n" ; |
438 | } |
439 | |
440 | void PatternEmitter::emitStaticVerifierCall(StringRef funcName, |
441 | StringRef opName, StringRef arg, |
442 | StringRef failureStr) { |
443 | os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n" , |
444 | Vals&: funcName, Vals&: opName, Vals&: arg, Vals&: failureStr); |
445 | os.scope().os << "return ::mlir::failure();\n" ; |
446 | os << "}\n" ; |
447 | } |
448 | |
449 | // Helper function to match patterns. |
450 | void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, |
451 | int depth) { |
452 | LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: " ); |
453 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
454 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
455 | |
456 | // The order of generating static matcher follows the topological order so |
457 | // that for every dependent DagNode already have their static matcher |
458 | // generated if needed. The reason we check if `getMatcherName(tree).empty()` |
459 | // is when we are generating the static matcher for a DagNode itself. In this |
460 | // case, we need to emit the function body rather than a function call. |
461 | if (staticMatcherHelper.useStaticMatcher(node: tree) && |
462 | !staticMatcherHelper.getMatcherName(node: tree).empty()) { |
463 | emitStaticMatchCall(tree, opName); |
464 | |
465 | // NativeCodeCall will never be at depth 0 so that we don't need to catch |
466 | // the root operation as emitOpMatch(); |
467 | |
468 | return; |
469 | } |
470 | |
471 | // TODO(suderman): iterate through arguments, determine their types, output |
472 | // names. |
473 | SmallVector<std::string, 8> capture; |
474 | |
475 | raw_indented_ostream::DelimitedScope scope(os); |
476 | |
477 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
478 | std::string argName = formatv(Fmt: "arg{0}_{1}" , Vals&: depth, Vals&: i); |
479 | if (DagNode argTree = tree.getArgAsNestedDag(index: i)) { |
480 | if (argTree.isEither()) |
481 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `either` operands" ); |
482 | if (argTree.isVariadic()) |
483 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `variadic` operands" ); |
484 | |
485 | os << "::mlir::Value " << argName << ";\n" ; |
486 | } else { |
487 | auto leaf = tree.getArgAsLeaf(index: i); |
488 | if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { |
489 | os << "::mlir::Attribute " << argName << ";\n" ; |
490 | } else { |
491 | os << "::mlir::Value " << argName << ";\n" ; |
492 | } |
493 | } |
494 | |
495 | capture.push_back(Elt: std::move(argName)); |
496 | } |
497 | |
498 | auto tail = getTrailingDirectives(tree); |
499 | if (tail.returnType) |
500 | PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier" ); |
501 | auto locToUse = getLocation(tail); |
502 | |
503 | auto fmt = tree.getNativeCodeTemplate(); |
504 | if (fmt.count(Str: "$_self" ) != 1) |
505 | PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall must have $_self as argument for " |
506 | "passing the defining Operation" ); |
507 | |
508 | auto nativeCodeCall = std::string( |
509 | tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc" , subst: locToUse).withSelf(subst: opName.str()), |
510 | params: static_cast<ArrayRef<std::string>>(capture))); |
511 | |
512 | emitMatchCheck(opName, matchStr: formatv(Fmt: "!::mlir::failed({0})" , Vals&: nativeCodeCall), |
513 | failureStr: formatv(Fmt: "\"{0} return ::mlir::failure\"" , Vals&: nativeCodeCall)); |
514 | |
515 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
516 | auto name = tree.getArgName(index: i); |
517 | if (!name.empty() && name != "_" ) { |
518 | os << formatv(Fmt: "{0} = {1};\n" , Vals&: name, Vals&: capture[i]); |
519 | } |
520 | } |
521 | |
522 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
523 | std::string argName = capture[i]; |
524 | |
525 | // Handle nested DAG construct first |
526 | if (tree.getArgAsNestedDag(index: i)) { |
527 | PrintFatalError( |
528 | ErrorLoc: loc, Msg: formatv(Fmt: "Matching nested tree in NativeCodecall not support for " |
529 | "{0} as arg {1}" , |
530 | Vals&: argName, Vals&: i)); |
531 | } |
532 | |
533 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
534 | |
535 | // The parameter for native function doesn't bind any constraints. |
536 | if (leaf.isUnspecified()) |
537 | continue; |
538 | |
539 | auto constraint = leaf.getAsConstraint(); |
540 | |
541 | std::string self; |
542 | if (leaf.isAttrMatcher() || leaf.isConstantAttr()) |
543 | self = argName; |
544 | else |
545 | self = formatv(Fmt: "{0}.getType()" , Vals&: argName); |
546 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf); |
547 | emitStaticVerifierCall( |
548 | funcName: verifier, opName, arg: self, |
549 | failureStr: formatv(Fmt: "\"operand {0} of native code call '{1}' failed to satisfy " |
550 | "constraint: " |
551 | "'{2}'\"" , |
552 | Vals&: i, Vals: tree.getNativeCodeTemplate(), |
553 | Vals: escapeString(value: constraint.getSummary())) |
554 | .str()); |
555 | } |
556 | |
557 | LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n" ); |
558 | } |
559 | |
560 | // Helper function to match patterns. |
561 | void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { |
562 | Operator &op = tree.getDialectOp(mapper: opMap); |
563 | LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" |
564 | << op.getOperationName() << "' at depth " << depth |
565 | << '\n'); |
566 | |
567 | auto getCastedName = [depth]() -> std::string { |
568 | return formatv(Fmt: "castedOp{0}" , Vals: depth); |
569 | }; |
570 | |
571 | // The order of generating static matcher follows the topological order so |
572 | // that for every dependent DagNode already have their static matcher |
573 | // generated if needed. The reason we check if `getMatcherName(tree).empty()` |
574 | // is when we are generating the static matcher for a DagNode itself. In this |
575 | // case, we need to emit the function body rather than a function call. |
576 | if (staticMatcherHelper.useStaticMatcher(node: tree) && |
577 | !staticMatcherHelper.getMatcherName(node: tree).empty()) { |
578 | emitStaticMatchCall(tree, opName); |
579 | // In the codegen of rewriter, we suppose that castedOp0 will capture the |
580 | // root operation. Manually add it if the root DagNode is a static matcher. |
581 | if (depth == 0) |
582 | os << formatv(Fmt: "auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " |
583 | "(void){2};\n" , |
584 | Vals&: opName, Vals: op.getQualCppClassName(), Vals: getCastedName()); |
585 | return; |
586 | } |
587 | |
588 | std::string castedName = getCastedName(); |
589 | os << formatv(Fmt: "auto {0} = ::llvm::dyn_cast<{2}>({1}); " |
590 | "(void){0};\n" , |
591 | Vals&: castedName, Vals&: opName, Vals: op.getQualCppClassName()); |
592 | |
593 | // Skip the operand matching at depth 0 as the pattern rewriter already does. |
594 | if (depth != 0) |
595 | emitMatchCheck(opName, /*matchStr=*/castedName, |
596 | failureStr: formatv(Fmt: "\"{0} is not {1} type\"" , Vals&: castedName, |
597 | Vals: op.getQualCppClassName())); |
598 | |
599 | // If the operand's name is set, set to that variable. |
600 | auto name = tree.getSymbol(); |
601 | if (!name.empty()) |
602 | os << formatv(Fmt: "{0} = {1};\n" , Vals&: name, Vals&: castedName); |
603 | |
604 | for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; |
605 | ++i, ++opArgIdx) { |
606 | auto opArg = op.getArg(index: opArgIdx); |
607 | std::string argName = formatv(Fmt: "op{0}" , Vals: depth + 1); |
608 | |
609 | // Handle nested DAG construct first |
610 | if (DagNode argTree = tree.getArgAsNestedDag(index: i)) { |
611 | if (argTree.isEither()) { |
612 | emitEitherOperandMatch(tree, eitherArgTree: argTree, opName: castedName, argIndex: opArgIdx, operandIndex&: nextOperand, |
613 | depth); |
614 | ++opArgIdx; |
615 | continue; |
616 | } |
617 | if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) { |
618 | if (argTree.isVariadic()) { |
619 | if (!operand->isVariadic()) { |
620 | auto error = formatv(Fmt: "variadic DAG construct can't match op {0}'s " |
621 | "non-variadic operand #{1}" , |
622 | Vals: op.getOperationName(), Vals&: opArgIdx); |
623 | PrintFatalError(ErrorLoc: loc, Msg: error); |
624 | } |
625 | emitVariadicOperandMatch(tree, variadicArgTree: argTree, opName: castedName, argIndex: opArgIdx, |
626 | operandIndex&: nextOperand, depth); |
627 | ++nextOperand; |
628 | continue; |
629 | } |
630 | if (operand->isVariableLength()) { |
631 | auto error = formatv(Fmt: "use nested DAG construct to match op {0}'s " |
632 | "variadic operand #{1} unsupported now" , |
633 | Vals: op.getOperationName(), Vals&: opArgIdx); |
634 | PrintFatalError(ErrorLoc: loc, Msg: error); |
635 | } |
636 | } |
637 | |
638 | os << "{\n" ; |
639 | |
640 | // Attributes don't count for getODSOperands. |
641 | // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. |
642 | os.indent() << formatv( |
643 | Fmt: "auto *{0} = " |
644 | "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n" , |
645 | Vals&: argName, Vals&: castedName, Vals&: nextOperand); |
646 | // Null check of operand's definingOp |
647 | emitMatchCheck( |
648 | opName: castedName, /*matchStr=*/argName, |
649 | failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"" , |
650 | Vals: nextOperand++, Vals&: castedName)); |
651 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
652 | os << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
653 | os.unindent() << "}\n" ; |
654 | continue; |
655 | } |
656 | |
657 | // Next handle DAG leaf: operand or attribute |
658 | if (isa<NamedTypeConstraint *>(Val: opArg)) { |
659 | auto operandName = |
660 | formatv(Fmt: "{0}.getODSOperands({1})" , Vals&: castedName, Vals&: nextOperand); |
661 | emitOperandMatch(tree, opName: castedName, operandName: operandName.str(), operandIndex: nextOperand, |
662 | /*operandMatcher=*/tree.getArgAsLeaf(index: i), |
663 | /*argName=*/tree.getArgName(index: i), argIndex: opArgIdx, |
664 | /*variadicSubIndex=*/std::nullopt); |
665 | ++nextOperand; |
666 | } else if (isa<NamedAttribute *>(Val: opArg)) { |
667 | emitAttributeMatch(tree, castedName, argIndex: opArgIdx, depth); |
668 | } else { |
669 | PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when matching op" ); |
670 | } |
671 | } |
672 | LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" |
673 | << op.getOperationName() << "' at depth " << depth |
674 | << '\n'); |
675 | } |
676 | |
677 | void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, |
678 | StringRef operandName, int operandIndex, |
679 | DagLeaf operandMatcher, StringRef argName, |
680 | int argIndex, |
681 | std::optional<int> variadicSubIndex) { |
682 | Operator &op = tree.getDialectOp(mapper: opMap); |
683 | NamedTypeConstraint operand = op.getOperand(index: operandIndex); |
684 | |
685 | // If a constraint is specified, we need to generate C++ statements to |
686 | // check the constraint. |
687 | if (!operandMatcher.isUnspecified()) { |
688 | if (!operandMatcher.isOperandMatcher()) |
689 | PrintFatalError( |
690 | ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an operand" , |
691 | Vals: op.getOperationName(), Vals: argIndex + 1)); |
692 | |
693 | // Only need to verify if the matcher's type is different from the one |
694 | // of op definition. |
695 | Constraint constraint = operandMatcher.getAsConstraint(); |
696 | if (operand.constraint != constraint) { |
697 | if (operand.isVariableLength()) { |
698 | auto error = formatv( |
699 | Fmt: "further constrain op {0}'s variadic operand #{1} unsupported now" , |
700 | Vals: op.getOperationName(), Vals&: argIndex); |
701 | PrintFatalError(ErrorLoc: loc, Msg: error); |
702 | } |
703 | auto self = formatv(Fmt: "(*{0}.begin()).getType()" , Vals&: operandName); |
704 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf: operandMatcher); |
705 | emitStaticVerifierCall( |
706 | funcName: verifier, opName, arg: self.str(), |
707 | failureStr: formatv( |
708 | Fmt: "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"" , |
709 | Vals&: operandIndex, Vals: op.getOperationName(), |
710 | Vals: escapeString(value: constraint.getSummary())) |
711 | .str()); |
712 | } |
713 | } |
714 | |
715 | // Capture the value |
716 | // `$_` is a special symbol to ignore op argument matching. |
717 | if (!argName.empty() && argName != "_" ) { |
718 | auto res = symbolInfoMap.findBoundSymbol(key: argName, node: tree, op, argIndex, |
719 | variadicSubIndex); |
720 | if (res == symbolInfoMap.end()) |
721 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}" , Vals&: argName)); |
722 | |
723 | os << formatv(Fmt: "{0} = {1};\n" , Vals: res->second.getVarName(name: argName), Vals&: operandName); |
724 | } |
725 | } |
726 | |
727 | void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, |
728 | StringRef opName, int argIndex, |
729 | int &operandIndex, int depth) { |
730 | constexpr int numEitherArgs = 2; |
731 | if (eitherArgTree.getNumArgs() != numEitherArgs) |
732 | PrintFatalError(ErrorLoc: loc, Msg: "`either` only supports grouping two operands" ); |
733 | |
734 | Operator &op = tree.getDialectOp(mapper: opMap); |
735 | |
736 | std::string codeBuffer; |
737 | llvm::raw_string_ostream tblgenOps(codeBuffer); |
738 | |
739 | std::string lambda = formatv(Fmt: "eitherLambda{0}" , Vals&: depth); |
740 | os << formatv( |
741 | Fmt: "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n" , |
742 | Vals&: lambda); |
743 | |
744 | os.indent(); |
745 | |
746 | for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { |
747 | if (DagNode argTree = eitherArgTree.getArgAsNestedDag(index: i)) { |
748 | if (argTree.isEither()) |
749 | PrintFatalError(ErrorLoc: loc, Msg: "either cannot be nested" ); |
750 | |
751 | std::string argName = formatv(Fmt: "local_op_{0}" , Vals&: i).str(); |
752 | |
753 | os << formatv(Fmt: "auto {0} = (*v{1}.begin()).getDefiningOp();\n" , Vals&: argName, |
754 | Vals&: i); |
755 | |
756 | // Indent emitMatchCheck and emitMatch because they declare local |
757 | // variables. |
758 | os << "{\n" ; |
759 | os.indent(); |
760 | |
761 | emitMatchCheck( |
762 | opName, /*matchStr=*/argName, |
763 | failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"" , |
764 | Vals: operandIndex++, Vals&: opName)); |
765 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
766 | |
767 | os.unindent() << "}\n" ; |
768 | |
769 | // `tblgen_ops` is used to collect the matched operations. In either, we |
770 | // need to queue the operation only if the matching success. Thus we emit |
771 | // the code at the end. |
772 | tblgenOps << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
773 | } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) { |
774 | emitOperandMatch(tree, opName, /*operandName=*/formatv(Fmt: "v{0}" , Vals&: i).str(), |
775 | operandIndex, |
776 | /*operandMatcher=*/eitherArgTree.getArgAsLeaf(index: i), |
777 | /*argName=*/eitherArgTree.getArgName(index: i), argIndex, |
778 | /*variadicSubIndex=*/std::nullopt); |
779 | ++operandIndex; |
780 | } else { |
781 | PrintFatalError(ErrorLoc: loc, Msg: "either can only be applied on operand" ); |
782 | } |
783 | } |
784 | |
785 | os << tblgenOps.str(); |
786 | os << "return ::mlir::success();\n" ; |
787 | os.unindent() << "};\n" ; |
788 | |
789 | os << "{\n" ; |
790 | os.indent(); |
791 | |
792 | os << formatv(Fmt: "auto eitherOperand0 = {0}.getODSOperands({1});\n" , Vals&: opName, |
793 | Vals: operandIndex - 2); |
794 | os << formatv(Fmt: "auto eitherOperand1 = {0}.getODSOperands({1});\n" , Vals&: opName, |
795 | Vals: operandIndex - 1); |
796 | |
797 | os << formatv(Fmt: "if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " |
798 | "::mlir::failed({0}(eitherOperand1, " |
799 | "eitherOperand0)))\n" , |
800 | Vals&: lambda); |
801 | os.indent() << "return ::mlir::failure();\n" ; |
802 | |
803 | os.unindent().unindent() << "}\n" ; |
804 | } |
805 | |
806 | void PatternEmitter::emitVariadicOperandMatch(DagNode tree, |
807 | DagNode variadicArgTree, |
808 | StringRef opName, int argIndex, |
809 | int &operandIndex, int depth) { |
810 | Operator &op = tree.getDialectOp(mapper: opMap); |
811 | |
812 | os << "{\n" ; |
813 | os.indent(); |
814 | |
815 | os << formatv(Fmt: "auto variadic_operand_range = {0}.getODSOperands({1});\n" , |
816 | Vals&: opName, Vals&: operandIndex); |
817 | os << formatv(Fmt: "if (variadic_operand_range.size() != {0}) " |
818 | "return ::mlir::failure();\n" , |
819 | Vals: variadicArgTree.getNumArgs()); |
820 | |
821 | StringRef variadicTreeName = variadicArgTree.getSymbol(); |
822 | if (!variadicTreeName.empty()) { |
823 | auto res = |
824 | symbolInfoMap.findBoundSymbol(key: variadicTreeName, node: tree, op, argIndex, |
825 | /*variadicSubIndex=*/std::nullopt); |
826 | if (res == symbolInfoMap.end()) |
827 | PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}" , Vals&: variadicTreeName)); |
828 | |
829 | os << formatv(Fmt: "{0} = variadic_operand_range;\n" , |
830 | Vals: res->second.getVarName(name: variadicTreeName)); |
831 | } |
832 | |
833 | for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) { |
834 | if (DagNode argTree = variadicArgTree.getArgAsNestedDag(index: i)) { |
835 | if (!argTree.isOperation()) |
836 | PrintFatalError(ErrorLoc: loc, Msg: "variadic only accepts operation sub-dags" ); |
837 | |
838 | os << "{\n" ; |
839 | os.indent(); |
840 | |
841 | std::string argName = formatv(Fmt: "local_op_{0}" , Vals&: i).str(); |
842 | os << formatv(Fmt: "auto *{0} = " |
843 | "variadic_operand_range[{1}].getDefiningOp();\n" , |
844 | Vals&: argName, Vals&: i); |
845 | emitMatchCheck( |
846 | opName, /*matchStr=*/argName, |
847 | failureStr: formatv(Fmt: "\"There's no operation that defines variadic operand " |
848 | "{0} (variadic sub-opearnd #{1}) of {2}\"" , |
849 | Vals&: operandIndex, Vals&: i, Vals&: opName)); |
850 | emitMatch(tree: argTree, name: argName, depth: depth + 1); |
851 | os << formatv(Fmt: "tblgen_ops.push_back({0});\n" , Vals&: argName); |
852 | |
853 | os.unindent() << "}\n" ; |
854 | } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) { |
855 | auto operandName = formatv(Fmt: "variadic_operand_range.slice({0}, 1)" , Vals&: i); |
856 | emitOperandMatch(tree, opName, operandName: operandName.str(), operandIndex, |
857 | /*operandMatcher=*/variadicArgTree.getArgAsLeaf(index: i), |
858 | /*argName=*/variadicArgTree.getArgName(index: i), argIndex, variadicSubIndex: i); |
859 | } else { |
860 | PrintFatalError(ErrorLoc: loc, Msg: "variadic can only be applied on operand" ); |
861 | } |
862 | } |
863 | |
864 | os.unindent() << "}\n" ; |
865 | } |
866 | |
867 | void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName, |
868 | int argIndex, int depth) { |
869 | Operator &op = tree.getDialectOp(mapper: opMap); |
870 | auto *namedAttr = cast<NamedAttribute *>(Val: op.getArg(index: argIndex)); |
871 | const auto &attr = namedAttr->attr; |
872 | |
873 | os << "{\n" ; |
874 | if (op.getDialect().usePropertiesForAttributes()) { |
875 | os.indent() << formatv( |
876 | Fmt: "[[maybe_unused]] auto tblgen_attr = {0}.getProperties().{1}();\n" , |
877 | Vals&: castedName, Vals: op.getGetterName(name: namedAttr->name)); |
878 | } else { |
879 | os.indent() << formatv(Fmt: "[[maybe_unused]] auto tblgen_attr = " |
880 | "{0}->getAttrOfType<{1}>(\"{2}\");\n" , |
881 | Vals&: castedName, Vals: attr.getStorageType(), Vals&: namedAttr->name); |
882 | } |
883 | |
884 | // TODO: This should use getter method to avoid duplication. |
885 | if (attr.hasDefaultValue()) { |
886 | os << "if (!tblgen_attr) tblgen_attr = " |
887 | << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, |
888 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fmtCtx))) |
889 | << ";\n" ; |
890 | } else if (attr.isOptional()) { |
891 | // For a missing attribute that is optional according to definition, we |
892 | // should just capture a mlir::Attribute() to signal the missing state. |
893 | // That is precisely what getDiscardableAttr() returns on missing |
894 | // attributes. |
895 | } else { |
896 | emitMatchCheck(opName: castedName, matchFmt: tgfmt(fmt: "tblgen_attr" , ctx: &fmtCtx), |
897 | failureFmt: formatv(Fmt: "\"expected op '{0}' to have attribute '{1}' " |
898 | "of type '{2}'\"" , |
899 | Vals: op.getOperationName(), Vals&: namedAttr->name, |
900 | Vals: attr.getStorageType())); |
901 | } |
902 | |
903 | auto matcher = tree.getArgAsLeaf(index: argIndex); |
904 | if (!matcher.isUnspecified()) { |
905 | if (!matcher.isAttrMatcher()) { |
906 | PrintFatalError( |
907 | ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an attribute" , |
908 | Vals: op.getOperationName(), Vals: argIndex + 1)); |
909 | } |
910 | |
911 | // If a constraint is specified, we need to generate function call to its |
912 | // static verifier. |
913 | StringRef verifier = staticMatcherHelper.getVerifierName(leaf: matcher); |
914 | if (attr.isOptional()) { |
915 | // Avoid dereferencing null attribute. This is using a simple heuristic to |
916 | // avoid common cases of attempting to dereference null attribute. This |
917 | // will return where there is no check if attribute is null unless the |
918 | // attribute's value is not used. |
919 | // FIXME: This could be improved as some null dereferences could slip |
920 | // through. |
921 | if (!StringRef(matcher.getConditionTemplate()).contains(Other: "!$_self" ) && |
922 | StringRef(matcher.getConditionTemplate()).contains(Other: "$_self" )) { |
923 | os << "if (!tblgen_attr) return ::mlir::failure();\n" ; |
924 | } |
925 | } |
926 | emitStaticVerifierCall( |
927 | funcName: verifier, opName: castedName, arg: "tblgen_attr" , |
928 | failureStr: formatv(Fmt: "\"op '{0}' attribute '{1}' failed to satisfy constraint: " |
929 | "'{2}'\"" , |
930 | Vals: op.getOperationName(), Vals&: namedAttr->name, |
931 | Vals: escapeString(value: matcher.getAsConstraint().getSummary())) |
932 | .str()); |
933 | } |
934 | |
935 | // Capture the value |
936 | auto name = tree.getArgName(index: argIndex); |
937 | // `$_` is a special symbol to ignore op argument matching. |
938 | if (!name.empty() && name != "_" ) { |
939 | os << formatv(Fmt: "{0} = tblgen_attr;\n" , Vals&: name); |
940 | } |
941 | |
942 | os.unindent() << "}\n" ; |
943 | } |
944 | |
945 | void PatternEmitter::emitMatchCheck( |
946 | StringRef opName, const FmtObjectBase &matchFmt, |
947 | const llvm::formatv_object_base &failureFmt) { |
948 | emitMatchCheck(opName, matchStr: matchFmt.str(), failureStr: failureFmt.str()); |
949 | } |
950 | |
951 | void PatternEmitter::emitMatchCheck(StringRef opName, |
952 | const std::string &matchStr, |
953 | const std::string &failureStr) { |
954 | |
955 | os << "if (!(" << matchStr << "))" ; |
956 | os.scope(open: "{\n" , close: "\n}\n" ).os << "return rewriter.notifyMatchFailure(" << opName |
957 | << ", [&](::mlir::Diagnostic &diag) {\n diag << " |
958 | << failureStr << ";\n});" ; |
959 | } |
960 | |
961 | void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { |
962 | LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n" ); |
963 | int depth = 0; |
964 | emitMatch(tree, name: opName, depth); |
965 | |
966 | for (auto &appliedConstraint : pattern.getConstraints()) { |
967 | auto &constraint = appliedConstraint.constraint; |
968 | auto &entities = appliedConstraint.entities; |
969 | |
970 | auto condition = constraint.getConditionTemplate(); |
971 | if (isa<TypeConstraint>(Val: constraint)) { |
972 | if (entities.size() != 1) |
973 | PrintFatalError(ErrorLoc: loc, Msg: "type constraint requires exactly one argument" ); |
974 | |
975 | auto self = formatv(Fmt: "({0}.getType())" , |
976 | Vals: symbolInfoMap.getValueAndRangeUse(symbol: entities.front())); |
977 | emitMatchCheck( |
978 | opName, matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self.str())), |
979 | failureFmt: formatv(Fmt: "\"value entity '{0}' failed to satisfy constraint: '{1}'\"" , |
980 | Vals&: entities.front(), Vals: escapeString(value: constraint.getSummary()))); |
981 | |
982 | } else if (isa<AttrConstraint>(Val: constraint)) { |
983 | PrintFatalError( |
984 | ErrorLoc: loc, Msg: "cannot use AttrConstraint in Pattern multi-entity constraints" ); |
985 | } else { |
986 | // TODO: replace formatv arguments with the exact specified |
987 | // args. |
988 | if (entities.size() > 4) { |
989 | PrintFatalError(ErrorLoc: loc, Msg: "only support up to 4-entity constraints now" ); |
990 | } |
991 | SmallVector<std::string, 4> names; |
992 | int i = 0; |
993 | for (int e = entities.size(); i < e; ++i) |
994 | names.push_back(Elt: symbolInfoMap.getValueAndRangeUse(symbol: entities[i])); |
995 | std::string self = appliedConstraint.self; |
996 | if (!self.empty()) |
997 | self = symbolInfoMap.getValueAndRangeUse(symbol: self); |
998 | for (; i < 4; ++i) |
999 | names.push_back(Elt: "<unused>" ); |
1000 | emitMatchCheck(opName, |
1001 | matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self), vals&: names[0], |
1002 | vals&: names[1], vals&: names[2], vals&: names[3]), |
1003 | failureFmt: formatv(Fmt: "\"entities '{0}' failed to satisfy constraint: " |
1004 | "'{1}'\"" , |
1005 | Vals: llvm::join(R&: entities, Separator: ", " ), |
1006 | Vals: escapeString(value: constraint.getSummary()))); |
1007 | } |
1008 | } |
1009 | |
1010 | // Some of the operands could be bound to the same symbol name, we need |
1011 | // to enforce equality constraint on those. |
1012 | // TODO: we should be able to emit equality checks early |
1013 | // and short circuit unnecessary work if vars are not equal. |
1014 | for (auto symbolInfoIt = symbolInfoMap.begin(); |
1015 | symbolInfoIt != symbolInfoMap.end();) { |
1016 | auto range = symbolInfoMap.getRangeOfEqualElements(key: symbolInfoIt->first); |
1017 | auto startRange = range.first; |
1018 | auto endRange = range.second; |
1019 | |
1020 | auto firstOperand = symbolInfoIt->second.getVarName(name: symbolInfoIt->first); |
1021 | for (++startRange; startRange != endRange; ++startRange) { |
1022 | auto secondOperand = startRange->second.getVarName(name: symbolInfoIt->first); |
1023 | emitMatchCheck( |
1024 | opName, |
1025 | matchStr: formatv(Fmt: "*{0}.begin() == *{1}.begin()" , Vals&: firstOperand, Vals&: secondOperand), |
1026 | failureStr: formatv(Fmt: "\"Operands '{0}' and '{1}' must be equal\"" , Vals&: firstOperand, |
1027 | Vals&: secondOperand)); |
1028 | } |
1029 | |
1030 | symbolInfoIt = endRange; |
1031 | } |
1032 | |
1033 | LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n" ); |
1034 | } |
1035 | |
1036 | void PatternEmitter::collectOps(DagNode tree, |
1037 | llvm::SmallPtrSetImpl<const Operator *> &ops) { |
1038 | // Check if this tree is an operation. |
1039 | if (tree.isOperation()) { |
1040 | const Operator &op = tree.getDialectOp(mapper: opMap); |
1041 | LLVM_DEBUG(llvm::dbgs() |
1042 | << "found operation " << op.getOperationName() << '\n'); |
1043 | ops.insert(Ptr: &op); |
1044 | } |
1045 | |
1046 | // Recurse the arguments of the tree. |
1047 | for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) |
1048 | if (auto child = tree.getArgAsNestedDag(index: i)) |
1049 | collectOps(tree: child, ops); |
1050 | } |
1051 | |
1052 | void PatternEmitter::emit(StringRef rewriteName) { |
1053 | // Get the DAG tree for the source pattern. |
1054 | DagNode sourceTree = pattern.getSourcePattern(); |
1055 | |
1056 | const Operator &rootOp = pattern.getSourceRootOp(); |
1057 | auto rootName = rootOp.getOperationName(); |
1058 | |
1059 | // Collect the set of result operations. |
1060 | llvm::SmallPtrSet<const Operator *, 4> resultOps; |
1061 | LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n" ); |
1062 | for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { |
1063 | collectOps(tree: pattern.getResultPattern(index: i), ops&: resultOps); |
1064 | } |
1065 | LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n" ); |
1066 | |
1067 | // Emit RewritePattern for Pattern. |
1068 | auto locs = pattern.getLocation(); |
1069 | os << formatv(Fmt: "/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n" , |
1070 | Vals: llvm::reverse(C&: locs)); |
1071 | os << formatv(Fmt: R"(struct {0} : public ::mlir::RewritePattern { |
1072 | {0}(::mlir::MLIRContext *context) |
1073 | : ::mlir::RewritePattern("{1}", {2}, context, {{)" , |
1074 | Vals&: rewriteName, Vals&: rootName, Vals: pattern.getBenefit()); |
1075 | // Sort result operators by name. |
1076 | llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), |
1077 | resultOps.end()); |
1078 | llvm::sort(C&: sortedResultOps, Comp: [&](const Operator *lhs, const Operator *rhs) { |
1079 | return lhs->getOperationName() < rhs->getOperationName(); |
1080 | }); |
1081 | llvm::interleaveComma(c: sortedResultOps, os, each_fn: [&](const Operator *op) { |
1082 | os << '"' << op->getOperationName() << '"'; |
1083 | }); |
1084 | os << "}) {}\n" ; |
1085 | |
1086 | // Emit matchAndRewrite() function. |
1087 | { |
1088 | auto classScope = os.scope(); |
1089 | os.printReindented(str: R"( |
1090 | ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0, |
1091 | ::mlir::PatternRewriter &rewriter) const override {)" ) |
1092 | << '\n'; |
1093 | { |
1094 | auto functionScope = os.scope(); |
1095 | |
1096 | // Register all symbols bound in the source pattern. |
1097 | pattern.collectSourcePatternBoundSymbols(infoMap&: symbolInfoMap); |
1098 | |
1099 | LLVM_DEBUG(llvm::dbgs() |
1100 | << "start creating local variables for capturing matches\n" ); |
1101 | os << "// Variables for capturing values and attributes used while " |
1102 | "creating ops\n" ; |
1103 | // Create local variables for storing the arguments and results bound |
1104 | // to symbols. |
1105 | for (const auto &symbolInfoPair : symbolInfoMap) { |
1106 | const auto &symbol = symbolInfoPair.first; |
1107 | const auto &info = symbolInfoPair.second; |
1108 | |
1109 | os << info.getVarDecl(name: symbol); |
1110 | } |
1111 | // TODO: capture ops with consistent numbering so that it can be |
1112 | // reused for fused loc. |
1113 | os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n" ; |
1114 | LLVM_DEBUG(llvm::dbgs() |
1115 | << "done creating local variables for capturing matches\n" ); |
1116 | |
1117 | os << "// Match\n" ; |
1118 | os << "tblgen_ops.push_back(op0);\n" ; |
1119 | emitMatchLogic(tree: sourceTree, opName: "op0" ); |
1120 | |
1121 | os << "\n// Rewrite\n" ; |
1122 | emitRewriteLogic(); |
1123 | |
1124 | os << "return ::mlir::success();\n" ; |
1125 | } |
1126 | os << "}\n" ; |
1127 | } |
1128 | os << "};\n\n" ; |
1129 | } |
1130 | |
1131 | void PatternEmitter::emitRewriteLogic() { |
1132 | LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n" ); |
1133 | const Operator &rootOp = pattern.getSourceRootOp(); |
1134 | int numExpectedResults = rootOp.getNumResults(); |
1135 | int numResultPatterns = pattern.getNumResultPatterns(); |
1136 | |
1137 | // First register all symbols bound to ops generated in result patterns. |
1138 | pattern.collectResultPatternBoundSymbols(infoMap&: symbolInfoMap); |
1139 | |
1140 | // Only the last N static values generated are used to replace the matched |
1141 | // root N-result op. We need to calculate the starting index (of the results |
1142 | // of the matched op) each result pattern is to replace. |
1143 | SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); |
1144 | // If we don't need to replace any value at all, set the replacement starting |
1145 | // index as the number of result patterns so we skip all of them when trying |
1146 | // to replace the matched op's results. |
1147 | int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; |
1148 | for (int i = numResultPatterns - 1; i >= 0; --i) { |
1149 | auto numValues = getNodeValueCount(node: pattern.getResultPattern(index: i)); |
1150 | offsets[i] = offsets[i + 1] - numValues; |
1151 | if (offsets[i] == 0) { |
1152 | if (replStartIndex == -1) |
1153 | replStartIndex = i; |
1154 | } else if (offsets[i] < 0 && offsets[i + 1] > 0) { |
1155 | auto error = formatv( |
1156 | Fmt: "cannot use the same multi-result op '{0}' to generate both " |
1157 | "auxiliary values and values to be used for replacing the matched op" , |
1158 | Vals: pattern.getResultPattern(index: i).getSymbol()); |
1159 | PrintFatalError(ErrorLoc: loc, Msg: error); |
1160 | } |
1161 | } |
1162 | |
1163 | if (offsets.front() > 0) { |
1164 | const char error[] = |
1165 | "not enough values generated to replace the matched op" ; |
1166 | PrintFatalError(ErrorLoc: loc, Msg: error); |
1167 | } |
1168 | |
1169 | os << "auto odsLoc = rewriter.getFusedLoc({" ; |
1170 | for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { |
1171 | os << (i ? ", " : "" ) << "tblgen_ops[" << i << "]->getLoc()" ; |
1172 | } |
1173 | os << "}); (void)odsLoc;\n" ; |
1174 | |
1175 | // Process auxiliary result patterns. |
1176 | for (int i = 0; i < replStartIndex; ++i) { |
1177 | DagNode resultTree = pattern.getResultPattern(index: i); |
1178 | auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0); |
1179 | // Normal op creation will be streamed to `os` by the above call; but |
1180 | // NativeCodeCall will only be materialized to `os` if it is used. Here |
1181 | // we are handling auxiliary patterns so we want the side effect even if |
1182 | // NativeCodeCall is not replacing matched root op's results. |
1183 | if (resultTree.isNativeCodeCall() && |
1184 | resultTree.getNumReturnsOfNativeCode() == 0) |
1185 | os << val << ";\n" ; |
1186 | } |
1187 | |
1188 | auto processSupplementalPatterns = [&]() { |
1189 | int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); |
1190 | for (int i = 0, offset = -numSupplementalPatterns; |
1191 | i < numSupplementalPatterns; ++i) { |
1192 | DagNode resultTree = pattern.getSupplementalPattern(index: i); |
1193 | auto val = handleResultPattern(resultTree, resultIndex: offset++, depth: 0); |
1194 | if (resultTree.isNativeCodeCall() && |
1195 | resultTree.getNumReturnsOfNativeCode() == 0) |
1196 | os << val << ";\n" ; |
1197 | } |
1198 | }; |
1199 | |
1200 | if (numExpectedResults == 0) { |
1201 | assert(replStartIndex >= numResultPatterns && |
1202 | "invalid auxiliary vs. replacement pattern division!" ); |
1203 | processSupplementalPatterns(); |
1204 | // No result to replace. Just erase the op. |
1205 | os << "rewriter.eraseOp(op0);\n" ; |
1206 | } else { |
1207 | // Process replacement result patterns. |
1208 | os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n" ; |
1209 | for (int i = replStartIndex; i < numResultPatterns; ++i) { |
1210 | DagNode resultTree = pattern.getResultPattern(index: i); |
1211 | auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0); |
1212 | os << "\n" ; |
1213 | // Resolve each symbol for all range use so that we can loop over them. |
1214 | // We need an explicit cast to `SmallVector` to capture the cases where |
1215 | // `{0}` resolves to an `Operation::result_range` as well as cases that |
1216 | // are not iterable (e.g. vector that gets wrapped in additional braces by |
1217 | // RewriterGen). |
1218 | // TODO: Revisit the need for materializing a vector. |
1219 | os << symbolInfoMap.getAllRangeUse( |
1220 | symbol: val, |
1221 | fmt: "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" |
1222 | " tblgen_repl_values.push_back(v);\n}\n" , |
1223 | separator: "\n" ); |
1224 | } |
1225 | processSupplementalPatterns(); |
1226 | os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n" ; |
1227 | } |
1228 | |
1229 | LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n" ); |
1230 | } |
1231 | |
1232 | std::string PatternEmitter::getUniqueSymbol(const Operator *op) { |
1233 | return std::string( |
1234 | formatv(Fmt: "tblgen_{0}_{1}" , Vals: op->getCppClassName(), Vals: nextValueId++)); |
1235 | } |
1236 | |
1237 | std::string PatternEmitter::handleResultPattern(DagNode resultTree, |
1238 | int resultIndex, int depth) { |
1239 | LLVM_DEBUG(llvm::dbgs() << "handle result pattern: " ); |
1240 | LLVM_DEBUG(resultTree.print(llvm::dbgs())); |
1241 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
1242 | |
1243 | if (resultTree.isLocationDirective()) { |
1244 | PrintFatalError(ErrorLoc: loc, |
1245 | Msg: "location directive can only be used with op creation" ); |
1246 | } |
1247 | |
1248 | if (resultTree.isNativeCodeCall()) |
1249 | return handleReplaceWithNativeCodeCall(resultTree, depth); |
1250 | |
1251 | if (resultTree.isReplaceWithValue()) |
1252 | return handleReplaceWithValue(tree: resultTree).str(); |
1253 | |
1254 | if (resultTree.isVariadic()) |
1255 | return handleVariadic(tree: resultTree, depth); |
1256 | |
1257 | // Normal op creation. |
1258 | auto symbol = handleOpCreation(tree: resultTree, resultIndex, depth); |
1259 | if (resultTree.getSymbol().empty()) { |
1260 | // This is an op not explicitly bound to a symbol in the rewrite rule. |
1261 | // Register the auto-generated symbol for it. |
1262 | symbolInfoMap.bindOpResult(symbol, op: pattern.getDialectOp(node: resultTree)); |
1263 | } |
1264 | return symbol; |
1265 | } |
1266 | |
1267 | std::string PatternEmitter::handleVariadic(DagNode tree, int depth) { |
1268 | assert(tree.isVariadic()); |
1269 | |
1270 | std::string output; |
1271 | llvm::raw_string_ostream oss(output); |
1272 | auto name = std::string(formatv(Fmt: "tblgen_variadic_values_{0}" , Vals: nextValueId++)); |
1273 | symbolInfoMap.bindValue(symbol: name); |
1274 | oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n" ; |
1275 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
1276 | if (auto child = tree.getArgAsNestedDag(index: i)) { |
1277 | oss << name << ".push_back(" << handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1) |
1278 | << ");\n" ; |
1279 | } else { |
1280 | oss << name << ".push_back(" |
1281 | << handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i)) |
1282 | << ");\n" ; |
1283 | } |
1284 | } |
1285 | |
1286 | os << oss.str(); |
1287 | return name; |
1288 | } |
1289 | |
1290 | StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { |
1291 | assert(tree.isReplaceWithValue()); |
1292 | |
1293 | if (tree.getNumArgs() != 1) { |
1294 | PrintFatalError( |
1295 | ErrorLoc: loc, Msg: "replaceWithValue directive must take exactly one argument" ); |
1296 | } |
1297 | |
1298 | if (!tree.getSymbol().empty()) { |
1299 | PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to replaceWithValue" ); |
1300 | } |
1301 | |
1302 | return tree.getArgName(index: 0); |
1303 | } |
1304 | |
1305 | std::string PatternEmitter::handleLocationDirective(DagNode tree) { |
1306 | assert(tree.isLocationDirective()); |
1307 | auto lookUpArgLoc = [this, &tree](int idx) { |
1308 | const auto *const lookupFmt = "{0}.getLoc()" ; |
1309 | return symbolInfoMap.getValueAndRangeUse(symbol: tree.getArgName(index: idx), fmt: lookupFmt); |
1310 | }; |
1311 | |
1312 | if (tree.getNumArgs() == 0) |
1313 | llvm::PrintFatalError( |
1314 | Msg: "At least one argument to location directive required" ); |
1315 | |
1316 | if (!tree.getSymbol().empty()) |
1317 | PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to location" ); |
1318 | |
1319 | if (tree.getNumArgs() == 1) { |
1320 | DagLeaf leaf = tree.getArgAsLeaf(index: 0); |
1321 | if (leaf.isStringAttr()) |
1322 | return formatv(Fmt: "::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))" , |
1323 | Vals: leaf.getStringAttr()) |
1324 | .str(); |
1325 | return lookUpArgLoc(0); |
1326 | } |
1327 | |
1328 | std::string ret; |
1329 | llvm::raw_string_ostream os(ret); |
1330 | std::string strAttr; |
1331 | os << "rewriter.getFusedLoc({" ; |
1332 | bool first = true; |
1333 | for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
1334 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
1335 | // Handle the optional string value. |
1336 | if (leaf.isStringAttr()) { |
1337 | if (!strAttr.empty()) |
1338 | llvm::PrintFatalError(Msg: "Only one string attribute may be specified" ); |
1339 | strAttr = leaf.getStringAttr(); |
1340 | continue; |
1341 | } |
1342 | os << (first ? "" : ", " ) << lookUpArgLoc(i); |
1343 | first = false; |
1344 | } |
1345 | os << "}" ; |
1346 | if (!strAttr.empty()) { |
1347 | os << ", rewriter.getStringAttr(\"" << strAttr << "\")" ; |
1348 | } |
1349 | os << ")" ; |
1350 | return os.str(); |
1351 | } |
1352 | |
1353 | std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, |
1354 | int depth) { |
1355 | // Nested NativeCodeCall. |
1356 | if (auto dagNode = returnType.getArgAsNestedDag(index: i)) { |
1357 | if (!dagNode.isNativeCodeCall()) |
1358 | PrintFatalError(ErrorLoc: loc, Msg: "nested DAG in `returnType` must be a native code " |
1359 | "call" ); |
1360 | return handleReplaceWithNativeCodeCall(resultTree: dagNode, depth); |
1361 | } |
1362 | // String literal. |
1363 | auto dagLeaf = returnType.getArgAsLeaf(index: i); |
1364 | if (dagLeaf.isStringAttr()) |
1365 | return tgfmt(fmt: dagLeaf.getStringAttr(), ctx: &fmtCtx); |
1366 | return tgfmt( |
1367 | fmt: "$0.getType()" , ctx: &fmtCtx, |
1368 | vals: handleOpArgument(leaf: returnType.getArgAsLeaf(index: i), patArgName: returnType.getArgName(index: i))); |
1369 | } |
1370 | |
1371 | std::string PatternEmitter::handleOpArgument(DagLeaf leaf, |
1372 | StringRef patArgName) { |
1373 | if (leaf.isStringAttr()) |
1374 | PrintFatalError(ErrorLoc: loc, Msg: "raw string not supported as argument" ); |
1375 | if (leaf.isConstantAttr()) { |
1376 | auto constAttr = leaf.getAsConstantAttr(); |
1377 | return handleConstantAttr(attr: constAttr.getAttribute(), |
1378 | value: constAttr.getConstantValue()); |
1379 | } |
1380 | if (leaf.isEnumCase()) { |
1381 | auto enumCase = leaf.getAsEnumCase(); |
1382 | // This is an enum case backed by an IntegerAttr. We need to get its value |
1383 | // to build the constant. |
1384 | std::string val = std::to_string(val: enumCase.getValue()); |
1385 | return handleConstantAttr(attr: Attribute(&enumCase.getDef()), value: val); |
1386 | } |
1387 | |
1388 | LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n" ); |
1389 | auto argName = symbolInfoMap.getValueAndRangeUse(symbol: patArgName); |
1390 | if (leaf.isUnspecified() || leaf.isOperandMatcher()) { |
1391 | LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName |
1392 | << "' (via symbol ref)\n" ); |
1393 | return argName; |
1394 | } |
1395 | if (leaf.isNativeCodeCall()) { |
1396 | auto repl = tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: argName)); |
1397 | LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl |
1398 | << "' (via NativeCodeCall)\n" ); |
1399 | return std::string(repl); |
1400 | } |
1401 | PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when rewriting op" ); |
1402 | } |
1403 | |
1404 | std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, |
1405 | int depth) { |
1406 | LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: " ); |
1407 | LLVM_DEBUG(tree.print(llvm::dbgs())); |
1408 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
1409 | |
1410 | auto fmt = tree.getNativeCodeTemplate(); |
1411 | |
1412 | SmallVector<std::string, 16> attrs; |
1413 | |
1414 | auto tail = getTrailingDirectives(tree); |
1415 | if (tail.returnType) |
1416 | PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier" ); |
1417 | auto locToUse = getLocation(tail); |
1418 | |
1419 | for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { |
1420 | if (tree.isNestedDagArg(index: i)) { |
1421 | attrs.push_back( |
1422 | Elt: handleResultPattern(resultTree: tree.getArgAsNestedDag(index: i), resultIndex: i, depth: depth + 1)); |
1423 | } else { |
1424 | attrs.push_back( |
1425 | Elt: handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i))); |
1426 | } |
1427 | LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i |
1428 | << " replacement: " << attrs[i] << "\n" ); |
1429 | } |
1430 | |
1431 | std::string symbol = tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc" , subst: locToUse), |
1432 | params: static_cast<ArrayRef<std::string>>(attrs)); |
1433 | |
1434 | // In general, NativeCodeCall without naming binding don't need this. To |
1435 | // ensure void helper function has been correctly labeled, i.e., use |
1436 | // NativeCodeCallVoid, we cache the result to a local variable so that we will |
1437 | // get a compilation error in the auto-generated file. |
1438 | // Example. |
1439 | // // In the td file |
1440 | // Pat<(...), (NativeCodeCall<Foo> ...)> |
1441 | // |
1442 | // --- |
1443 | // |
1444 | // // In the auto-generated .cpp |
1445 | // ... |
1446 | // // Causes compilation error if Foo() returns void. |
1447 | // auto nativeVar = Foo(); |
1448 | // ... |
1449 | if (tree.getNumReturnsOfNativeCode() != 0) { |
1450 | // Determine the local variable name for return value. |
1451 | std::string varName = |
1452 | SymbolInfoMap::getValuePackName(symbol: tree.getSymbol()).str(); |
1453 | if (varName.empty()) { |
1454 | varName = formatv(Fmt: "nativeVar_{0}" , Vals: nextValueId++); |
1455 | // Register the local variable for later uses. |
1456 | symbolInfoMap.bindValues(symbol: varName, numValues: tree.getNumReturnsOfNativeCode()); |
1457 | } |
1458 | |
1459 | // Catch the return value of helper function. |
1460 | os << formatv(Fmt: "auto {0} = {1}; (void){0};\n" , Vals&: varName, Vals&: symbol); |
1461 | |
1462 | if (!tree.getSymbol().empty()) |
1463 | symbol = tree.getSymbol().str(); |
1464 | else |
1465 | symbol = varName; |
1466 | } |
1467 | |
1468 | return symbol; |
1469 | } |
1470 | |
1471 | int PatternEmitter::getNodeValueCount(DagNode node) { |
1472 | if (node.isOperation()) { |
1473 | // If the op is bound to a symbol in the rewrite rule, query its result |
1474 | // count from the symbol info map. |
1475 | auto symbol = node.getSymbol(); |
1476 | if (!symbol.empty()) { |
1477 | return symbolInfoMap.getStaticValueCount(symbol); |
1478 | } |
1479 | // Otherwise this is an unbound op; we will use all its results. |
1480 | return pattern.getDialectOp(node).getNumResults(); |
1481 | } |
1482 | |
1483 | if (node.isNativeCodeCall()) |
1484 | return node.getNumReturnsOfNativeCode(); |
1485 | |
1486 | return 1; |
1487 | } |
1488 | |
1489 | PatternEmitter::TrailingDirectives |
1490 | PatternEmitter::getTrailingDirectives(DagNode tree) { |
1491 | TrailingDirectives tail = {.location: DagNode(nullptr), .returnType: DagNode(nullptr), .numDirectives: 0}; |
1492 | |
1493 | // Look backwards through the arguments. |
1494 | auto numPatArgs = tree.getNumArgs(); |
1495 | for (int i = numPatArgs - 1; i >= 0; --i) { |
1496 | auto dagArg = tree.getArgAsNestedDag(index: i); |
1497 | // A leaf is not a directive. Stop looking. |
1498 | if (!dagArg) |
1499 | break; |
1500 | |
1501 | auto isLocation = dagArg.isLocationDirective(); |
1502 | auto isReturnType = dagArg.isReturnTypeDirective(); |
1503 | // If encountered a DAG node that isn't a trailing directive, stop looking. |
1504 | if (!(isLocation || isReturnType)) |
1505 | break; |
1506 | // Save the directive, but error if one of the same type was already |
1507 | // found. |
1508 | ++tail.numDirectives; |
1509 | if (isLocation) { |
1510 | if (tail.location) |
1511 | PrintFatalError(ErrorLoc: loc, Msg: "`location` directive can only be specified " |
1512 | "once" ); |
1513 | tail.location = dagArg; |
1514 | } else if (isReturnType) { |
1515 | if (tail.returnType) |
1516 | PrintFatalError(ErrorLoc: loc, Msg: "`returnType` directive can only be specified " |
1517 | "once" ); |
1518 | tail.returnType = dagArg; |
1519 | } |
1520 | } |
1521 | |
1522 | return tail; |
1523 | } |
1524 | |
1525 | std::string |
1526 | PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { |
1527 | if (tail.location) |
1528 | return handleLocationDirective(tree: tail.location); |
1529 | |
1530 | // If no explicit location is given, use the default, all fused, location. |
1531 | return "odsLoc" ; |
1532 | } |
1533 | |
// Emits C++ that creates the op described by the result DAG `tree`. Returns
// the symbol naming the created op's result(s). That symbol may carry a `__N`
// suffix and then differs from the emitted local variable.
//
// `depth` is the nesting level of `tree` within the result pattern (0 at the
// top). `resultIndex` is the position of the source-root value being replaced;
// negative values are used for auxiliary ops.
//
// Nested DAG arguments are handled first (recursively through
// `handleResultPattern`) so their symbols can be used here. The op is then
// built with one of three strategies:
//   1. No `returnType` directive, and the op has SameOperandsAndResultType or
//      FirstAttrDerivedResultType: aggregate builder without result types.
//   2. No `returnType` directive, and the result is partially used, nested or
//      auxiliary: separate-parameter builder (which must infer result types).
//   3. Otherwise: aggregate builder with explicit result types, taken from
//      `returnType` or copied from the source root op's ODS results.
// Fatal errors are emitted on:
//   - an argument-count mismatch;
//   - a non-variadic nested DAG supplied for a variadic operand;
//   - `returnType` combined with replacing a source-root value.
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: " );
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  Operator &resultOp = tree.getDialectOp(mapper: opMap);
  // Selects whether the builder calls below pass `tblgen_props` or
  // `tblgen_attrs`.
  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
  auto numOpArgs = resultOp.getNumArgs();
  auto numPatArgs = tree.getNumArgs();

  // Trailing `location`/`returnType` directives are not op arguments.
  auto tail = getTrailingDirectives(tree);
  auto locToUse = getLocation(tail);

  auto inPattern = numPatArgs - tail.numDirectives;
  if (numOpArgs != inPattern) {
    PrintFatalError(ErrorLoc: loc,
                    Msg: formatv(Fmt: "resultant op '{0}' argument number mismatch: "
                            "{1} in pattern vs. {2} in definition" ,
                            Vals: resultOp.getOperationName(), Vals&: inPattern, Vals&: numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  ChildNodeIndexNameMap childNodeNames;

  // If the argument is a type constraint, then its an operand. Check if the
  // op's argument is variadic that the argument in the pattern is too.
  auto checkIfMatchedVariadic = [&](int i) {
    // FIXME: This does not yet check for variable/leaf case.
    // FIXME: Change so that native code call can be handled.
    const auto *operand =
        llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: i));
    if (!operand || !operand->isVariadic())
      return;

    auto child = tree.getArgAsNestedDag(index: i);
    if (!child)
      return;

    // Skip over replaceWithValues.
    while (child.isReplaceWithValue()) {
      if (!(child = child.getArgAsNestedDag(index: 0)))
        return;
    }
    if (!child.isNativeCodeCall() && !child.isVariadic())
      PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "op expects variadic operand `{0}`, while "
                                   "provided is non-variadic" ,
                                   Vals: resultOp.getArgName(index: i)));
  };

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    checkIfMatchedVariadic(i);
    if (auto child = tree.getArgAsNestedDag(index: i))
      childNodeNames[i] = handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1);
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(op: &resultOp);
  } else {
    resultValue = std::string(tree.getSymbol());
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = std::string(SymbolInfoMap::getValuePackName(symbol: resultValue));
  }

  // Create the local variable for this op. The `{{` opens a C++ block in the
  // generated code; each strategy below is responsible for closing it.
  os << formatv(Fmt: "{0} {1};\n{{\n" , Vals: resultOp.getQualCppClassName(),
                Vals&: valuePackName);

  // Right now ODS don't have general type inference support. Except a few
  // special cases listed below, DRR needs to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType" );
  bool useFirstAttr =
      resultOp.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType" );

  // Strategy 1: aggregate builder, result types deduced from the traits.
  if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
    // We know how to deduce the result type for ops with these traits and we've
    // generated builders taking aggregate parameters. Use those builders to
    // create the ops.

    // First prepare local variables for op arguments used in builder call.
    createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);

    // Then create the op. The temporary scope's close string "\n}\n" ends
    // the block opened above.
    os.scope(open: "" , close: "\n}\n" ).os
        << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_values, {3});" ,
                   Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse,
                   Vals: useProperties ? "tblgen_props" : "tblgen_attrs" );
    return resultValue;
  }

  // True when the bound symbol carries a `__N` suffix (see above).
  bool usePartialResults = valuePackName != resultValue;

  // Strategy 2: separate-parameter builder, which must infer result types.
  if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
    // For these cases (broadcastable ops, op results used both as auxiliary
    // values and replacement values, ops in nested patterns, auxiliary ops), we
    // still need to supply the result types when building the op. But because
    // we don't generate a builder automatically with ODS for them, it's the
    // developer's responsibility to make sure such a builder (with result type
    // deduction ability) exists. We go through the separate-parameter builder
    // here given that it's easier for developers to write compared to
    // aggregate-parameter builders.
    createSeparateLocalVarsForOpArgs(node: tree, childNodeNames);

    os.scope().os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}" , Vals&: valuePackName,
                             Vals: resultOp.getQualCppClassName(), Vals&: locToUse);
    supplyValuesForOpArgs(node: tree, childNodeNames, depth);
    // Close the builder call and the block opened above.
    os << "\n );\n}\n" ;
    return resultValue;
  }

  // Strategy 3: aggregate builder with explicit result types.
  // If we are provided explicit return types, use them to build the op.
  // However, if depth == 0 and resultIndex >= 0, it means we are replacing
  // the values generated from the source pattern root op. Then we must use the
  // source pattern's value types to determine the value type of the generated
  // op here.
  if (depth == 0 && resultIndex >= 0 && tail.returnType)
    PrintFatalError(ErrorLoc: loc, Msg: "Cannot specify explicit return types in an op whose "
                         "return values replace the source pattern's root op" );

  // First prepare local variables for op arguments used in builder call.
  createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);

  // Then prepare the result types. We need to specify the types for all
  // results.
  os.indent() << formatv(Fmt: "::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
                         "(void)tblgen_types;\n" );
  int numResults = resultOp.getNumResults();
  if (tail.returnType) {
    // One push_back per argument of the `returnType` directive.
    auto numRetTys = tail.returnType.getNumArgs();
    for (int i = 0; i < numRetTys; ++i) {
      auto varName = handleReturnTypeArg(returnType: tail.returnType, i, depth: depth + 1);
      os << "tblgen_types.push_back(" << varName << ");\n" ;
    }
  } else {
    if (numResults != 0) {
      // Copy the result types from the source pattern.
      for (int i = 0; i < numResults; ++i)
        os << formatv(Fmt: "for (auto v: castedOp0.getODSResults({0})) {{\n"
                      " tblgen_types.push_back(v.getType());\n}\n" ,
                      Vals: resultIndex + i);
    }
  }
  os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_types, "
                "tblgen_values, {3});\n" ,
                Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse,
                Vals: useProperties ? "tblgen_props" : "tblgen_attrs" );
  // Pairs with the `indent()` above and closes the block opened earlier.
  os.unindent() << "}\n" ;
  return resultValue;
}
1699 | |
1700 | void PatternEmitter::createSeparateLocalVarsForOpArgs( |
1701 | DagNode node, ChildNodeIndexNameMap &childNodeNames) { |
1702 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
1703 | |
1704 | // Now prepare operands used for building this op: |
1705 | // * If the operand is non-variadic, we create a `Value` local variable. |
1706 | // * If the operand is variadic, we create a `SmallVector<Value>` local |
1707 | // variable. |
1708 | |
1709 | int valueIndex = 0; // An index for uniquing local variable names. |
1710 | for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { |
1711 | const auto *operand = |
1712 | llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex)); |
1713 | // We do not need special handling for attributes. |
1714 | if (!operand) |
1715 | continue; |
1716 | |
1717 | raw_indented_ostream::DelimitedScope scope(os); |
1718 | std::string varName; |
1719 | if (operand->isVariadic()) { |
1720 | varName = std::string(formatv(Fmt: "tblgen_values_{0}" , Vals: valueIndex++)); |
1721 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> {0};\n" , Vals&: varName); |
1722 | std::string range; |
1723 | if (node.isNestedDagArg(index: argIndex)) { |
1724 | range = childNodeNames[argIndex]; |
1725 | } else { |
1726 | range = std::string(node.getArgName(index: argIndex)); |
1727 | } |
1728 | // Resolve the symbol for all range use so that we have a uniform way of |
1729 | // capturing the values. |
1730 | range = symbolInfoMap.getValueAndRangeUse(symbol: range); |
1731 | os << formatv(Fmt: "for (auto v: {0}) {{\n {1}.push_back(v);\n}\n" , Vals&: range, |
1732 | Vals&: varName); |
1733 | } else { |
1734 | varName = std::string(formatv(Fmt: "tblgen_value_{0}" , Vals: valueIndex++)); |
1735 | os << formatv(Fmt: "::mlir::Value {0} = " , Vals&: varName); |
1736 | if (node.isNestedDagArg(index: argIndex)) { |
1737 | os << symbolInfoMap.getValueAndRangeUse(symbol: childNodeNames[argIndex]); |
1738 | } else { |
1739 | DagLeaf leaf = node.getArgAsLeaf(index: argIndex); |
1740 | auto symbol = |
1741 | symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex)); |
1742 | if (leaf.isNativeCodeCall()) { |
1743 | os << std::string( |
1744 | tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol))); |
1745 | } else { |
1746 | os << symbol; |
1747 | } |
1748 | } |
1749 | os << ";\n" ; |
1750 | } |
1751 | |
1752 | // Update to use the newly created local variable for building the op later. |
1753 | childNodeNames[argIndex] = varName; |
1754 | } |
1755 | } |
1756 | |
1757 | void PatternEmitter::supplyValuesForOpArgs( |
1758 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { |
1759 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
1760 | for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); |
1761 | argIndex != numOpArgs; ++argIndex) { |
1762 | // Start each argument on its own line. |
1763 | os << ",\n " ; |
1764 | |
1765 | Argument opArg = resultOp.getArg(index: argIndex); |
1766 | // Handle the case of operand first. |
1767 | if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) { |
1768 | if (!operand->name.empty()) |
1769 | os << "/*" << operand->name << "=*/" ; |
1770 | os << childNodeNames.lookup(Val: argIndex); |
1771 | continue; |
1772 | } |
1773 | |
1774 | // The argument in the op definition. |
1775 | auto opArgName = resultOp.getArgName(index: argIndex); |
1776 | if (auto subTree = node.getArgAsNestedDag(index: argIndex)) { |
1777 | if (!subTree.isNativeCodeCall()) |
1778 | PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node " |
1779 | "for creating attribute" ); |
1780 | os << formatv(Fmt: "/*{0}=*/{1}" , Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex)); |
1781 | } else { |
1782 | auto leaf = node.getArgAsLeaf(index: argIndex); |
1783 | // The argument in the result DAG pattern. |
1784 | auto patArgName = node.getArgName(index: argIndex); |
1785 | if (leaf.isConstantAttr() || leaf.isEnumCase()) { |
1786 | // TODO: Refactor out into map to avoid recomputing these. |
1787 | if (!isa<NamedAttribute *>(Val: opArg)) |
1788 | PrintFatalError(ErrorLoc: loc, Msg: Twine("expected attribute " ) + Twine(argIndex)); |
1789 | if (!patArgName.empty()) |
1790 | os << "/*" << patArgName << "=*/" ; |
1791 | } else { |
1792 | os << "/*" << opArgName << "=*/" ; |
1793 | } |
1794 | os << handleOpArgument(leaf, patArgName); |
1795 | } |
1796 | } |
1797 | } |
1798 | |
1799 | void PatternEmitter::createAggregateLocalVarsForOpArgs( |
1800 | DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { |
1801 | Operator &resultOp = node.getDialectOp(mapper: opMap); |
1802 | |
1803 | bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); |
1804 | auto scope = os.scope(); |
1805 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> " |
1806 | "tblgen_values; (void)tblgen_values;\n" ); |
1807 | if (useProperties) { |
1808 | os << formatv(Fmt: "{0}::Properties tblgen_props; (void)tblgen_props;\n" , |
1809 | Vals: resultOp.getQualCppClassName()); |
1810 | } else { |
1811 | os << formatv(Fmt: "::llvm::SmallVector<::mlir::NamedAttribute, 4> " |
1812 | "tblgen_attrs; (void)tblgen_attrs;\n" ); |
1813 | } |
1814 | |
1815 | const char *setPropCmd = |
1816 | "tblgen_props.{0} = " |
1817 | "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n" ; |
1818 | const char *addAttrCmd = |
1819 | "if (auto tmpAttr = {1}) {\n" |
1820 | " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " |
1821 | "tmpAttr);\n}\n" ; |
1822 | const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd; |
1823 | |
1824 | int numVariadic = 0; |
1825 | bool hasOperandSegmentSizes = false; |
1826 | std::vector<std::string> sizes; |
1827 | for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { |
1828 | if (isa<NamedAttribute *>(Val: resultOp.getArg(index: argIndex))) { |
1829 | // The argument in the op definition. |
1830 | auto opArgName = resultOp.getArgName(index: argIndex); |
1831 | hasOperandSegmentSizes = |
1832 | hasOperandSegmentSizes || opArgName == "operandSegmentSizes" ; |
1833 | if (auto subTree = node.getArgAsNestedDag(index: argIndex)) { |
1834 | if (!subTree.isNativeCodeCall()) |
1835 | PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node " |
1836 | "for creating attribute" ); |
1837 | |
1838 | os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex)); |
1839 | } else { |
1840 | auto leaf = node.getArgAsLeaf(index: argIndex); |
1841 | // The argument in the result DAG pattern. |
1842 | auto patArgName = node.getArgName(index: argIndex); |
1843 | os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: handleOpArgument(leaf, patArgName)); |
1844 | } |
1845 | continue; |
1846 | } |
1847 | |
1848 | const auto *operand = |
1849 | cast<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex)); |
1850 | if (operand->isVariadic()) { |
1851 | ++numVariadic; |
1852 | std::string range; |
1853 | if (node.isNestedDagArg(index: argIndex)) { |
1854 | range = childNodeNames.lookup(Val: argIndex); |
1855 | } else { |
1856 | range = std::string(node.getArgName(index: argIndex)); |
1857 | } |
1858 | // Resolve the symbol for all range use so that we have a uniform way of |
1859 | // capturing the values. |
1860 | range = symbolInfoMap.getValueAndRangeUse(symbol: range); |
1861 | os << formatv(Fmt: "for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n" , |
1862 | Vals&: range); |
1863 | sizes.push_back(x: formatv(Fmt: "static_cast<int32_t>({0}.size())" , Vals&: range)); |
1864 | } else { |
1865 | sizes.emplace_back(args: "1" ); |
1866 | os << formatv(Fmt: "tblgen_values.push_back(" ); |
1867 | if (node.isNestedDagArg(index: argIndex)) { |
1868 | os << symbolInfoMap.getValueAndRangeUse( |
1869 | symbol: childNodeNames.lookup(Val: argIndex)); |
1870 | } else { |
1871 | DagLeaf leaf = node.getArgAsLeaf(index: argIndex); |
1872 | if (leaf.isConstantAttr()) |
1873 | // TODO: Use better location |
1874 | PrintFatalError( |
1875 | ErrorLoc: loc, |
1876 | Msg: "attribute found where value was expected, if attempting to use " |
1877 | "constant value, construct a constant op with given attribute " |
1878 | "instead" ); |
1879 | |
1880 | auto symbol = |
1881 | symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex)); |
1882 | if (leaf.isNativeCodeCall()) { |
1883 | os << std::string( |
1884 | tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol))); |
1885 | } else { |
1886 | os << symbol; |
1887 | } |
1888 | } |
1889 | os << ");\n" ; |
1890 | } |
1891 | } |
1892 | |
1893 | if (numVariadic > 1 && !hasOperandSegmentSizes) { |
1894 | // Only set size if it can't be computed. |
1895 | const auto *sameVariadicSize = |
1896 | resultOp.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize" ); |
1897 | if (!sameVariadicSize) { |
1898 | if (useProperties) { |
1899 | const char *setSizes = R"( |
1900 | tblgen_props.operandSegmentSizes = {{ {0} }; |
1901 | )" ; |
1902 | os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", " )).str()); |
1903 | } else { |
1904 | const char *setSizes = R"( |
1905 | tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), |
1906 | rewriter.getDenseI32ArrayAttr({{ {0} })); |
1907 | )" ; |
1908 | os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", " )).str()); |
1909 | } |
1910 | } |
1911 | } |
1912 | } |
1913 | |
1914 | StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, |
1915 | const RecordKeeper &records, |
1916 | RecordOperatorMap &mapper) |
1917 | : opMap(mapper), staticVerifierEmitter(os, records) {} |
1918 | |
1919 | void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { |
1920 | // PatternEmitter will use the static matcher if there's one generated. To |
1921 | // ensure that all the dependent static matchers are generated before emitting |
1922 | // the matching logic of the DagNode, we use topological order to achieve it. |
1923 | for (auto &dagInfo : topologicalOrder) { |
1924 | DagNode node = dagInfo.first; |
1925 | if (!useStaticMatcher(node)) |
1926 | continue; |
1927 | |
1928 | std::string funcName = |
1929 | formatv(Fmt: "static_dag_matcher_{0}" , Vals: staticMatcherCounter++); |
1930 | assert(!matcherNames.contains(node)); |
1931 | PatternEmitter(dagInfo.second, &opMap, os, *this) |
1932 | .emitStaticMatcher(tree: node, funcName); |
1933 | matcherNames[node] = funcName; |
1934 | } |
1935 | } |
1936 | |
1937 | void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { |
1938 | staticVerifierEmitter.emitPatternConstraints(constraints: constraints.getArrayRef()); |
1939 | } |
1940 | |
1941 | void StaticMatcherHelper::addPattern(const Record *record) { |
1942 | Pattern pat(record, &opMap); |
1943 | |
1944 | // While generating the function body of the DAG matcher, it may depends on |
1945 | // other DAG matchers. To ensure the dependent matchers are ready, we compute |
1946 | // the topological order for all the DAGs and emit the DAG matchers in this |
1947 | // order. |
1948 | llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) { |
1949 | ++refStats[node]; |
1950 | |
1951 | if (refStats[node] != 1) |
1952 | return; |
1953 | |
1954 | for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i) |
1955 | if (DagNode sibling = node.getArgAsNestedDag(index: i)) |
1956 | dfs(sibling); |
1957 | else { |
1958 | DagLeaf leaf = node.getArgAsLeaf(index: i); |
1959 | if (!leaf.isUnspecified()) |
1960 | constraints.insert(X: leaf); |
1961 | } |
1962 | |
1963 | topologicalOrder.push_back(Elt: std::make_pair(x&: node, y&: record)); |
1964 | }; |
1965 | |
1966 | dfs(pat.getSourcePattern()); |
1967 | } |
1968 | |
1969 | StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { |
1970 | if (leaf.isAttrMatcher()) { |
1971 | std::optional<StringRef> constraint = |
1972 | staticVerifierEmitter.getAttrConstraintFn(constraint: leaf.getAsConstraint()); |
1973 | assert(constraint && "attribute constraint was not uniqued" ); |
1974 | return *constraint; |
1975 | } |
1976 | assert(leaf.isOperandMatcher()); |
1977 | return staticVerifierEmitter.getTypeConstraintFn(constraint: leaf.getAsConstraint()); |
1978 | } |
1979 | |
1980 | static void emitRewriters(const RecordKeeper &records, raw_ostream &os) { |
1981 | emitSourceFileHeader(Desc: "Rewriters" , OS&: os, Record: records); |
1982 | |
1983 | auto patterns = records.getAllDerivedDefinitions(ClassName: "Pattern" ); |
1984 | |
1985 | // We put the map here because it can be shared among multiple patterns. |
1986 | RecordOperatorMap recordOpMap; |
1987 | |
1988 | // Exam all the patterns and generate static matcher for the duplicated |
1989 | // DagNode. |
1990 | StaticMatcherHelper staticMatcher(os, records, recordOpMap); |
1991 | for (const Record *p : patterns) |
1992 | staticMatcher.addPattern(record: p); |
1993 | staticMatcher.populateStaticConstraintFunctions(os); |
1994 | staticMatcher.populateStaticMatchers(os); |
1995 | |
1996 | std::vector<std::string> rewriterNames; |
1997 | rewriterNames.reserve(n: patterns.size()); |
1998 | |
1999 | std::string baseRewriterName = "GeneratedConvert" ; |
2000 | int rewriterIndex = 0; |
2001 | |
2002 | for (const Record *p : patterns) { |
2003 | std::string name; |
2004 | if (p->isAnonymous()) { |
2005 | // If no name is provided, ensure unique rewriter names simply by |
2006 | // appending unique suffix. |
2007 | name = baseRewriterName + llvm::utostr(X: rewriterIndex++); |
2008 | } else { |
2009 | name = std::string(p->getName()); |
2010 | } |
2011 | LLVM_DEBUG(llvm::dbgs() |
2012 | << "=== start generating pattern '" << name << "' ===\n" ); |
2013 | PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(rewriteName: name); |
2014 | LLVM_DEBUG(llvm::dbgs() |
2015 | << "=== done generating pattern '" << name << "' ===\n" ); |
2016 | rewriterNames.push_back(x: std::move(name)); |
2017 | } |
2018 | |
2019 | // Emit function to add the generated matchers to the pattern list. |
2020 | os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" |
2021 | "::mlir::RewritePatternSet &patterns) {\n" ; |
2022 | for (const auto &name : rewriterNames) { |
2023 | os << " patterns.add<" << name << ">(patterns.getContext());\n" ; |
2024 | } |
2025 | os << "}\n" ; |
2026 | } |
2027 | |
2028 | static mlir::GenRegistration |
2029 | genRewriters("gen-rewriters" , "Generate pattern rewriters" , |
2030 | [](const RecordKeeper &records, raw_ostream &os) { |
2031 | emitRewriters(records, os); |
2032 | return false; |
2033 | }); |
2034 | |