1//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Support/IndentedOstream.h"
14#include "mlir/TableGen/Argument.h"
15#include "mlir/TableGen/Attribute.h"
16#include "mlir/TableGen/CodeGenHelpers.h"
17#include "mlir/TableGen/Format.h"
18#include "mlir/TableGen/GenInfo.h"
19#include "mlir/TableGen/Operator.h"
20#include "mlir/TableGen/Pattern.h"
21#include "mlir/TableGen/Predicate.h"
22#include "mlir/TableGen/Property.h"
23#include "mlir/TableGen/Type.h"
24#include "llvm/ADT/FunctionExtras.h"
25#include "llvm/ADT/SetVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/StringSet.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/FormatAdapters.h"
31#include "llvm/Support/PrettyStackTrace.h"
32#include "llvm/Support/Signals.h"
33#include "llvm/TableGen/Error.h"
34#include "llvm/TableGen/Main.h"
35#include "llvm/TableGen/Record.h"
36#include "llvm/TableGen/TableGenBackend.h"
37
38using namespace mlir;
39using namespace mlir::tblgen;
40
41using llvm::formatv;
42using llvm::Record;
43using llvm::RecordKeeper;
44
45#define DEBUG_TYPE "mlir-tblgen-rewritergen"
46
47namespace llvm {
48template <>
49struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51 raw_ostream &os, StringRef style) {
52 os << v.first << ":" << v.second;
53 }
54};
55} // namespace llvm
56
57//===----------------------------------------------------------------------===//
58// PatternEmitter
59//===----------------------------------------------------------------------===//
60
61namespace {
62
63class StaticMatcherHelper;
64
65class PatternEmitter {
66public:
67 PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
68 StaticMatcherHelper &helper);
69
70 // Emits the mlir::RewritePattern struct named `rewriteName`.
71 void emit(StringRef rewriteName);
72
73 // Emits the static function of DAG matcher.
74 void emitStaticMatcher(DagNode tree, std::string funcName);
75
76private:
77 // Emits the code for matching ops.
78 void emitMatchLogic(DagNode tree, StringRef opName);
79
80 // Emits the code for rewriting ops.
81 void emitRewriteLogic();
82
83 //===--------------------------------------------------------------------===//
84 // Match utilities
85 //===--------------------------------------------------------------------===//
86
87 // Emits C++ statements for matching the DAG structure.
88 void emitMatch(DagNode tree, StringRef name, int depth);
89
90 // Emit C++ function call to static DAG matcher.
91 void emitStaticMatchCall(DagNode tree, StringRef name);
92
93 // Emit C++ function call to static type/attribute constraint function.
94 void emitStaticVerifierCall(StringRef funcName, StringRef opName,
95 StringRef arg, StringRef failureStr);
96
97 // Emits C++ statements for matching using a native code call.
98 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
99
100 // Emits C++ statements for matching the op constrained by the given DAG
101 // `tree` returning the op's variable name.
102 void emitOpMatch(DagNode tree, StringRef opName, int depth);
103
104 // Emits C++ statements for matching the `argIndex`-th argument of the given
105 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
106 // bound name and the constraint of the operand respectively.
107 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
108 int operandIndex, DagLeaf operandMatcher,
109 StringRef argName, int argIndex,
110 std::optional<int> variadicSubIndex);
111
112 // Emits C++ statements for matching the operands which can be matched in
113 // either order.
114 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
115 StringRef opName, int argIndex, int &operandIndex,
116 int depth);
117
118 // Emits C++ statements for matching a variadic operand.
119 void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
120 StringRef opName, int argIndex,
121 int &operandIndex, int depth);
122
123 // Emits C++ statements for matching the `argIndex`-th argument of the given
124 // DAG `tree` as an attribute.
125 void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
126 int depth);
127
128 // Emits C++ statements for matching the `argIndex`-th argument of the given
129 // DAG `tree` as a property.
130 void emitPropertyMatch(DagNode tree, StringRef castedName, int argIndex,
131 int depth);
132
133 // Emits C++ for checking a match with a corresponding match failure
134 // diagnostic.
135 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
136 const llvm::formatv_object_base &failureFmt);
137
138 // Emits C++ for checking a match with a corresponding match failure
139 // diagnostics.
140 void emitMatchCheck(StringRef opName, const std::string &matchStr,
141 const std::string &failureStr);
142
143 //===--------------------------------------------------------------------===//
144 // Rewrite utilities
145 //===--------------------------------------------------------------------===//
146
147 // The entry point for handling a result pattern rooted at `resultTree`. This
148 // method dispatches to concrete handlers according to `resultTree`'s kind and
149 // returns a symbol representing the whole value pack. Callers are expected to
150 // further resolve the symbol according to the specific use case.
151 //
152 // `depth` is the nesting level of `resultTree`; 0 means top-level result
153 // pattern. For top-level result pattern, `resultIndex` indicates which result
154 // of the matched root op this pattern is intended to replace, which can be
155 // used to deduce the result type of the op generated from this result
156 // pattern.
157 std::string handleResultPattern(DagNode resultTree, int resultIndex,
158 int depth);
159
160 // Emits the C++ statement to replace the matched DAG with a value built via
161 // calling native C++ code.
162 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
163
164 // Returns the symbol of the old value serving as the replacement.
165 StringRef handleReplaceWithValue(DagNode tree);
166
167 // Emits the C++ statement to replace the matched DAG with an array of
168 // matched values.
169 std::string handleVariadic(DagNode tree, int depth);
170
171 // Trailing directives are used at the end of DAG node argument lists to
172 // specify additional behaviour for op matchers and creators, etc.
173 struct TrailingDirectives {
174 // DAG node containing the `location` directive. Null if there is none.
175 DagNode location;
176
177 // DAG node containing the `returnType` directive. Null if there is none.
178 DagNode returnType;
179
180 // Number of found trailing directives.
181 int numDirectives;
182 };
183
184 // Collect any trailing directives.
185 TrailingDirectives getTrailingDirectives(DagNode tree);
186
187 // Returns the location value to use.
188 std::string getLocation(TrailingDirectives &tail);
189
190 // Returns the location value to use.
191 std::string handleLocationDirective(DagNode tree);
192
193 // Emit return type argument.
194 std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
195
196 // Emits the C++ statement to build a new op out of the given DAG `tree` and
197 // returns the variable name that this op is assigned to. If the root op in
198 // DAG `tree` has a specified name, the created op will be assigned to a
199 // variable of the given name. Otherwise, a unique name will be used as the
200 // result value name.
201 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
202
203 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
204
205 // Emits a local variable for each value and attribute to be used for creating
206 // an op.
207 void createSeparateLocalVarsForOpArgs(DagNode node,
208 ChildNodeIndexNameMap &childNodeNames);
209
210 // Emits the concrete arguments used to call an op's builder.
211 void supplyValuesForOpArgs(DagNode node,
212 const ChildNodeIndexNameMap &childNodeNames,
213 int depth);
214
215 // Emits the local variables for holding all values as a whole and all named
216 // attributes as a whole to be used for creating an op.
217 void createAggregateLocalVarsForOpArgs(
218 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
219
220 // Returns the C++ expression to construct a constant attribute of the given
221 // `value` for the given attribute kind `attr`.
222 std::string handleConstantAttr(Attribute attr, const Twine &value);
223
224 // Returns the C++ expression to build an argument from the given DAG `leaf`.
225 // `patArgName` is used to bound the argument to the source pattern.
226 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
227
228 //===--------------------------------------------------------------------===//
229 // General utilities
230 //===--------------------------------------------------------------------===//
231
232 // Collects all of the operations within the given dag tree.
233 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
234
235 // Returns a unique symbol for a local variable of the given `op`.
236 std::string getUniqueSymbol(const Operator *op);
237
238 //===--------------------------------------------------------------------===//
239 // Symbol utilities
240 //===--------------------------------------------------------------------===//
241
242 // Returns how many static values the given DAG `node` correspond to.
243 int getNodeValueCount(DagNode node);
244
245private:
246 // Pattern instantiation location followed by the location of multiclass
247 // prototypes used. This is intended to be used as a whole to
248 // PrintFatalError() on errors.
249 ArrayRef<SMLoc> loc;
250
251 // Op's TableGen Record to wrapper object.
252 RecordOperatorMap *opMap;
253
254 // Handy wrapper for pattern being emitted.
255 Pattern pattern;
256
257 // Map for all bound symbols' info.
258 SymbolInfoMap symbolInfoMap;
259
260 StaticMatcherHelper &staticMatcherHelper;
261
262 // The next unused ID for newly created values.
263 unsigned nextValueId = 0;
264
265 raw_indented_ostream os;
266
267 // Format contexts containing placeholder substitutions.
268 FmtContext fmtCtx;
269};
270
271// Tracks DagNode's reference multiple times across patterns. Enables generating
272// static matcher functions for DagNode's referenced multiple times rather than
273// inlining them.
274class StaticMatcherHelper {
275public:
276 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records,
277 RecordOperatorMap &mapper);
278
279 // Determine if we should inline the match logic or delegate to a static
280 // function.
281 bool useStaticMatcher(DagNode node) {
282 // either/variadic node must be associated to the parentOp, thus we can't
283 // emit a static matcher rooted at them.
284 if (node.isEither() || node.isVariadic())
285 return false;
286
287 return refStats[node] > kStaticMatcherThreshold;
288 }
289
290 // Get the name of the static DAG matcher function corresponding to the node.
291 std::string getMatcherName(DagNode node) {
292 assert(useStaticMatcher(node));
293 return matcherNames[node];
294 }
295
296 // Get the name of static type/attribute verification function.
297 StringRef getVerifierName(DagLeaf leaf);
298
299 // Collect the `Record`s, i.e., the DRR, so that we can get the information of
300 // the duplicated DAGs.
301 void addPattern(const Record *record);
302
303 // Emit all static functions of DAG Matcher.
304 void populateStaticMatchers(raw_ostream &os);
305
306 // Emit all static functions for Constraints.
307 void populateStaticConstraintFunctions(raw_ostream &os);
308
309private:
310 static constexpr unsigned kStaticMatcherThreshold = 1;
311
312 // Consider two patterns as down below,
313 // DagNode_Root_A DagNode_Root_B
314 // \ \
315 // DagNode_C DagNode_C
316 // \ \
317 // DagNode_D DagNode_D
318 //
319 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
320 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
321 // multiple times so we'll have static matchers for both of them. When we're
322 // emitting the match logic for DagNode_C, we will check if DagNode_D has the
323 // static matcher generated. If so, then we'll generate a call to the
324 // function, inline otherwise. In this case, inlining is not what we want. As
325 // a result, generate the static matcher in topological order to ensure all
326 // the dependent static matchers are generated and we can avoid accidentally
327 // inlining.
328 //
329 // The topological order of all the DagNodes among all patterns.
330 SmallVector<std::pair<DagNode, const Record *>> topologicalOrder;
331
332 RecordOperatorMap &opMap;
333
334 // Records of the static function name of each DagNode
335 DenseMap<DagNode, std::string> matcherNames;
336
337 // After collecting all the DagNode in each pattern, `refStats` records the
338 // number of users for each DagNode. We will generate the static matcher for a
339 // DagNode while the number of users exceeds a certain threshold.
340 DenseMap<DagNode, unsigned> refStats;
341
342 // Number of static matcher generated. This is used to generate a unique name
343 // for each DagNode.
344 int staticMatcherCounter = 0;
345
346 // The DagLeaf which contains type, attr, or prop constraint.
347 SetVector<DagLeaf> constraints;
348
349 // Static type/attribute verification function emitter.
350 StaticVerifierFunctionEmitter staticVerifierEmitter;
351};
352
353} // namespace
354
355PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper,
356 raw_ostream &os, StaticMatcherHelper &helper)
357 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
358 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
359 fmtCtx.withBuilder(subst: "rewriter");
360}
361
362std::string PatternEmitter::handleConstantAttr(Attribute attr,
363 const Twine &value) {
364 if (!attr.isConstBuildable())
365 PrintFatalError(ErrorLoc: loc, Msg: "Attribute " + attr.getAttrDefName() +
366 " does not have the 'constBuilderCall' field");
367
368 // TODO: Verify the constants here
369 return std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, vals: value));
370}
371
372void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
373 os << formatv(
374 Fmt: "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
375 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
376 "*, 4> &tblgen_ops",
377 Vals&: funcName);
378
379 // We pass the reference of the variables that need to be captured. Hence we
380 // need to collect all the symbols in the tree first.
381 pattern.collectBoundSymbols(tree, infoMap&: symbolInfoMap, /*isSrcPattern=*/true);
382 symbolInfoMap.assignUniqueAlternativeNames();
383 for (const auto &info : symbolInfoMap)
384 os << formatv(Fmt: ", {0}", Vals: info.second.getArgDecl(name: info.first));
385
386 os << ") {\n";
387 os.indent();
388 os << "(void)tblgen_ops;\n";
389
390 // Note that a static matcher is considered at least one step from the match
391 // entry.
392 emitMatch(tree, name: "op0", /*depth=*/1);
393
394 os << "return ::mlir::success();\n";
395 os.unindent();
396 os << "}\n\n";
397}
398
399// Helper function to match patterns.
400void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
401 if (tree.isNativeCodeCall()) {
402 emitNativeCodeMatch(tree, name, depth);
403 return;
404 }
405
406 if (tree.isOperation()) {
407 emitOpMatch(tree, opName: name, depth);
408 return;
409 }
410
411 PrintFatalError(ErrorLoc: loc, Msg: "encountered non-op, non-NativeCodeCall match.");
412}
413
414void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
415 std::string funcName = staticMatcherHelper.getMatcherName(node: tree);
416 os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", Vals&: funcName,
417 Vals&: opName);
418
419 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
420 // one pass.
421
422 // In general, bound symbol should have the unique name in the pattern but
423 // for the operand, binding same symbol to multiple operands imply a
424 // constraint at the same time. In this case, we will rename those operands
425 // with different names. As a result, we need to collect all the symbolInfos
426 // from the DagNode then get the updated name of the local variables from the
427 // global symbolInfoMap.
428
429 // Collect all the bound symbols in the Dag
430 SymbolInfoMap localSymbolMap(loc);
431 pattern.collectBoundSymbols(tree, infoMap&: localSymbolMap, /*isSrcPattern=*/true);
432
433 for (const auto &info : localSymbolMap) {
434 auto name = info.first;
435 auto symboInfo = info.second;
436 auto ret = symbolInfoMap.findBoundSymbol(key: name, symbolInfo: symboInfo);
437 os << formatv(Fmt: ", {0}", Vals: ret->second.getVarName(name));
438 }
439
440 os << "))) {\n";
441 os.scope().os << "return ::mlir::failure();\n";
442 os << "}\n";
443}
444
445void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
446 StringRef opName, StringRef arg,
447 StringRef failureStr) {
448 os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
449 Vals&: funcName, Vals&: opName, Vals&: arg, Vals&: failureStr);
450 os.scope().os << "return ::mlir::failure();\n";
451 os << "}\n";
452}
453
454// Helper function to match patterns.
455void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
456 int depth) {
457 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
458 LLVM_DEBUG(tree.print(llvm::dbgs()));
459 LLVM_DEBUG(llvm::dbgs() << '\n');
460
461 // The order of generating static matcher follows the topological order so
462 // that for every dependent DagNode already have their static matcher
463 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
464 // is when we are generating the static matcher for a DagNode itself. In this
465 // case, we need to emit the function body rather than a function call.
466 if (staticMatcherHelper.useStaticMatcher(node: tree) &&
467 !staticMatcherHelper.getMatcherName(node: tree).empty()) {
468 emitStaticMatchCall(tree, opName);
469
470 // NativeCodeCall will never be at depth 0 so that we don't need to catch
471 // the root operation as emitOpMatch();
472
473 return;
474 }
475
476 // TODO(suderman): iterate through arguments, determine their types, output
477 // names.
478 SmallVector<std::string, 8> capture;
479
480 raw_indented_ostream::DelimitedScope scope(os);
481
482 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
483 std::string argName = formatv(Fmt: "arg{0}_{1}", Vals&: depth, Vals&: i);
484 if (DagNode argTree = tree.getArgAsNestedDag(index: i)) {
485 if (argTree.isEither())
486 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `either` operands");
487 if (argTree.isVariadic())
488 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `variadic` operands");
489
490 os << "::mlir::Value " << argName << ";\n";
491 } else {
492 auto leaf = tree.getArgAsLeaf(index: i);
493 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
494 os << "::mlir::Attribute " << argName << ";\n";
495 } else if (leaf.isPropMatcher()) {
496 StringRef interfaceType = leaf.getAsPropConstraint().getInterfaceType();
497 if (interfaceType.empty())
498 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have a property operand "
499 "with unspecified interface type");
500 os << interfaceType << " " << argName;
501 if (leaf.isPropDefinition()) {
502 Property propDef = leaf.getAsProperty();
503 // Ensure properties that aren't zero-arg-constructable still work.
504 if (propDef.hasDefaultValue())
505 os << " = " << propDef.getDefaultValue();
506 }
507 os << ";\n";
508 } else {
509 os << "::mlir::Value " << argName << ";\n";
510 }
511 }
512
513 capture.push_back(Elt: std::move(argName));
514 }
515
516 auto tail = getTrailingDirectives(tree);
517 if (tail.returnType)
518 PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier");
519 auto locToUse = getLocation(tail);
520
521 auto fmt = tree.getNativeCodeTemplate();
522 if (fmt.count(Str: "$_self") != 1)
523 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall must have $_self as argument for "
524 "passing the defining Operation");
525
526 auto nativeCodeCall = std::string(
527 tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc", subst: locToUse).withSelf(subst: opName.str()),
528 params: static_cast<ArrayRef<std::string>>(capture)));
529
530 emitMatchCheck(opName, matchStr: formatv(Fmt: "!::mlir::failed({0})", Vals&: nativeCodeCall),
531 failureStr: formatv(Fmt: "\"{0} return ::mlir::failure\"", Vals&: nativeCodeCall));
532
533 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
534 auto name = tree.getArgName(index: i);
535 if (!name.empty() && name != "_") {
536 os << formatv(Fmt: "{0} = {1};\n", Vals&: name, Vals&: capture[i]);
537 }
538 }
539
540 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
541 std::string argName = capture[i];
542
543 // Handle nested DAG construct first
544 if (tree.getArgAsNestedDag(index: i)) {
545 PrintFatalError(
546 ErrorLoc: loc, Msg: formatv(Fmt: "Matching nested tree in NativeCodecall not support for "
547 "{0} as arg {1}",
548 Vals&: argName, Vals&: i));
549 }
550
551 DagLeaf leaf = tree.getArgAsLeaf(index: i);
552
553 // The parameter for native function doesn't bind any constraints.
554 if (leaf.isUnspecified())
555 continue;
556
557 auto constraint = leaf.getAsConstraint();
558
559 std::string self;
560 if (leaf.isAttrMatcher() || leaf.isConstantAttr() || leaf.isPropMatcher())
561 self = argName;
562 else
563 self = formatv(Fmt: "{0}.getType()", Vals&: argName);
564 StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
565 emitStaticVerifierCall(
566 funcName: verifier, opName, arg: self,
567 failureStr: formatv(Fmt: "\"operand {0} of native code call '{1}' failed to satisfy "
568 "constraint: "
569 "'{2}'\"",
570 Vals&: i, Vals: tree.getNativeCodeTemplate(),
571 Vals: escapeString(value: constraint.getSummary()))
572 .str());
573 }
574
575 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
576}
577
578// Helper function to match patterns.
579void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
580 Operator &op = tree.getDialectOp(mapper: opMap);
581 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
582 << op.getOperationName() << "' at depth " << depth
583 << '\n');
584
585 auto getCastedName = [depth]() -> std::string {
586 return formatv(Fmt: "castedOp{0}", Vals: depth);
587 };
588
589 // The order of generating static matcher follows the topological order so
590 // that for every dependent DagNode already have their static matcher
591 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
592 // is when we are generating the static matcher for a DagNode itself. In this
593 // case, we need to emit the function body rather than a function call.
594 if (staticMatcherHelper.useStaticMatcher(node: tree) &&
595 !staticMatcherHelper.getMatcherName(node: tree).empty()) {
596 emitStaticMatchCall(tree, opName);
597 // In the codegen of rewriter, we suppose that castedOp0 will capture the
598 // root operation. Manually add it if the root DagNode is a static matcher.
599 if (depth == 0)
600 os << formatv(Fmt: "auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
601 "(void){2};\n",
602 Vals&: opName, Vals: op.getQualCppClassName(), Vals: getCastedName());
603 return;
604 }
605
606 std::string castedName = getCastedName();
607 os << formatv(Fmt: "auto {0} = ::llvm::dyn_cast<{2}>({1}); "
608 "(void){0};\n",
609 Vals&: castedName, Vals&: opName, Vals: op.getQualCppClassName());
610
611 // Skip the operand matching at depth 0 as the pattern rewriter already does.
612 if (depth != 0)
613 emitMatchCheck(opName, /*matchStr=*/castedName,
614 failureStr: formatv(Fmt: "\"{0} is not {1} type\"", Vals&: castedName,
615 Vals: op.getQualCppClassName()));
616
617 // If the operand's name is set, set to that variable.
618 auto name = tree.getSymbol();
619 if (!name.empty())
620 os << formatv(Fmt: "{0} = {1};\n", Vals&: name, Vals&: castedName);
621
622 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
623 ++i, ++opArgIdx) {
624 auto opArg = op.getArg(index: opArgIdx);
625 std::string argName = formatv(Fmt: "op{0}", Vals: depth + 1);
626
627 // Handle nested DAG construct first
628 if (DagNode argTree = tree.getArgAsNestedDag(index: i)) {
629 if (argTree.isEither()) {
630 emitEitherOperandMatch(tree, eitherArgTree: argTree, opName: castedName, argIndex: opArgIdx, operandIndex&: nextOperand,
631 depth);
632 ++opArgIdx;
633 continue;
634 }
635 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) {
636 if (argTree.isVariadic()) {
637 if (!operand->isVariadic()) {
638 auto error = formatv(Fmt: "variadic DAG construct can't match op {0}'s "
639 "non-variadic operand #{1}",
640 Vals: op.getOperationName(), Vals&: opArgIdx);
641 PrintFatalError(ErrorLoc: loc, Msg: error);
642 }
643 emitVariadicOperandMatch(tree, variadicArgTree: argTree, opName: castedName, argIndex: opArgIdx,
644 operandIndex&: nextOperand, depth);
645 ++nextOperand;
646 continue;
647 }
648 if (operand->isVariableLength()) {
649 auto error = formatv(Fmt: "use nested DAG construct to match op {0}'s "
650 "variadic operand #{1} unsupported now",
651 Vals: op.getOperationName(), Vals&: opArgIdx);
652 PrintFatalError(ErrorLoc: loc, Msg: error);
653 }
654 }
655
656 os << "{\n";
657
658 // Attributes don't count for getODSOperands.
659 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
660 os.indent() << formatv(
661 Fmt: "auto *{0} = "
662 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
663 Vals&: argName, Vals&: castedName, Vals&: nextOperand);
664 // Null check of operand's definingOp
665 emitMatchCheck(
666 opName: castedName, /*matchStr=*/argName,
667 failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"",
668 Vals: nextOperand++, Vals&: castedName));
669 emitMatch(tree: argTree, name: argName, depth: depth + 1);
670 os << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
671 os.unindent() << "}\n";
672 continue;
673 }
674
675 // Next handle DAG leaf: operand or attribute
676 if (isa<NamedTypeConstraint *>(Val: opArg)) {
677 auto operandName =
678 formatv(Fmt: "{0}.getODSOperands({1})", Vals&: castedName, Vals&: nextOperand);
679 emitOperandMatch(tree, opName: castedName, operandName: operandName.str(), operandIndex: nextOperand,
680 /*operandMatcher=*/tree.getArgAsLeaf(index: i),
681 /*argName=*/tree.getArgName(index: i), argIndex: opArgIdx,
682 /*variadicSubIndex=*/std::nullopt);
683 ++nextOperand;
684 } else if (isa<NamedAttribute *>(Val: opArg)) {
685 emitAttributeMatch(tree, castedName, argIndex: opArgIdx, depth);
686 } else if (isa<NamedProperty *>(Val: opArg)) {
687 emitPropertyMatch(tree, castedName, argIndex: opArgIdx, depth);
688 } else {
689 PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when matching op");
690 }
691 }
692 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
693 << op.getOperationName() << "' at depth " << depth
694 << '\n');
695}
696
697void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
698 StringRef operandName, int operandIndex,
699 DagLeaf operandMatcher, StringRef argName,
700 int argIndex,
701 std::optional<int> variadicSubIndex) {
702 Operator &op = tree.getDialectOp(mapper: opMap);
703 NamedTypeConstraint operand = op.getOperand(index: operandIndex);
704
705 // If a constraint is specified, we need to generate C++ statements to
706 // check the constraint.
707 if (!operandMatcher.isUnspecified()) {
708 if (!operandMatcher.isOperandMatcher())
709 PrintFatalError(
710 ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an operand",
711 Vals: op.getOperationName(), Vals: argIndex + 1));
712
713 // Only need to verify if the matcher's type is different from the one
714 // of op definition.
715 Constraint constraint = operandMatcher.getAsConstraint();
716 if (operand.constraint != constraint) {
717 if (operand.isVariableLength()) {
718 auto error = formatv(
719 Fmt: "further constrain op {0}'s variadic operand #{1} unsupported now",
720 Vals: op.getOperationName(), Vals&: argIndex);
721 PrintFatalError(ErrorLoc: loc, Msg: error);
722 }
723 auto self = formatv(Fmt: "(*{0}.begin()).getType()", Vals&: operandName);
724 StringRef verifier = staticMatcherHelper.getVerifierName(leaf: operandMatcher);
725 emitStaticVerifierCall(
726 funcName: verifier, opName, arg: self.str(),
727 failureStr: formatv(
728 Fmt: "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
729 Vals&: operandIndex, Vals: op.getOperationName(),
730 Vals: escapeString(value: constraint.getSummary()))
731 .str());
732 }
733 }
734
735 // Capture the value
736 // `$_` is a special symbol to ignore op argument matching.
737 if (!argName.empty() && argName != "_") {
738 auto res = symbolInfoMap.findBoundSymbol(key: argName, node: tree, op, argIndex,
739 variadicSubIndex);
740 if (res == symbolInfoMap.end())
741 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}", Vals&: argName));
742
743 os << formatv(Fmt: "{0} = {1};\n", Vals: res->second.getVarName(name: argName), Vals&: operandName);
744 }
745}
746
747void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
748 StringRef opName, int argIndex,
749 int &operandIndex, int depth) {
750 constexpr int numEitherArgs = 2;
751 if (eitherArgTree.getNumArgs() != numEitherArgs)
752 PrintFatalError(ErrorLoc: loc, Msg: "`either` only supports grouping two operands");
753
754 Operator &op = tree.getDialectOp(mapper: opMap);
755
756 std::string codeBuffer;
757 llvm::raw_string_ostream tblgenOps(codeBuffer);
758
759 std::string lambda = formatv(Fmt: "eitherLambda{0}", Vals&: depth);
760 os << formatv(
761 Fmt: "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
762 Vals&: lambda);
763
764 os.indent();
765
766 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
767 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(index: i)) {
768 if (argTree.isEither())
769 PrintFatalError(ErrorLoc: loc, Msg: "either cannot be nested");
770
771 std::string argName = formatv(Fmt: "local_op_{0}", Vals&: i).str();
772
773 os << formatv(Fmt: "auto {0} = (*v{1}.begin()).getDefiningOp();\n", Vals&: argName,
774 Vals&: i);
775
776 // Indent emitMatchCheck and emitMatch because they declare local
777 // variables.
778 os << "{\n";
779 os.indent();
780
781 emitMatchCheck(
782 opName, /*matchStr=*/argName,
783 failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"",
784 Vals: operandIndex++, Vals&: opName));
785 emitMatch(tree: argTree, name: argName, depth: depth + 1);
786
787 os.unindent() << "}\n";
788
789 // `tblgen_ops` is used to collect the matched operations. In either, we
790 // need to queue the operation only if the matching success. Thus we emit
791 // the code at the end.
792 tblgenOps << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
793 } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) {
794 emitOperandMatch(tree, opName, /*operandName=*/formatv(Fmt: "v{0}", Vals&: i).str(),
795 operandIndex,
796 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(index: i),
797 /*argName=*/eitherArgTree.getArgName(index: i), argIndex,
798 /*variadicSubIndex=*/std::nullopt);
799 ++operandIndex;
800 } else {
801 PrintFatalError(ErrorLoc: loc, Msg: "either can only be applied on operand");
802 }
803 }
804
805 os << tblgenOps.str();
806 os << "return ::mlir::success();\n";
807 os.unindent() << "};\n";
808
809 os << "{\n";
810 os.indent();
811
812 os << formatv(Fmt: "auto eitherOperand0 = {0}.getODSOperands({1});\n", Vals&: opName,
813 Vals: operandIndex - 2);
814 os << formatv(Fmt: "auto eitherOperand1 = {0}.getODSOperands({1});\n", Vals&: opName,
815 Vals: operandIndex - 1);
816
817 os << formatv(Fmt: "if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
818 "::mlir::failed({0}(eitherOperand1, "
819 "eitherOperand0)))\n",
820 Vals&: lambda);
821 os.indent() << "return ::mlir::failure();\n";
822
823 os.unindent().unindent() << "}\n";
824}
825
826void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
827 DagNode variadicArgTree,
828 StringRef opName, int argIndex,
829 int &operandIndex, int depth) {
830 Operator &op = tree.getDialectOp(mapper: opMap);
831
832 os << "{\n";
833 os.indent();
834
835 os << formatv(Fmt: "auto variadic_operand_range = {0}.getODSOperands({1});\n",
836 Vals&: opName, Vals&: operandIndex);
837 os << formatv(Fmt: "if (variadic_operand_range.size() != {0}) "
838 "return ::mlir::failure();\n",
839 Vals: variadicArgTree.getNumArgs());
840
841 StringRef variadicTreeName = variadicArgTree.getSymbol();
842 if (!variadicTreeName.empty()) {
843 auto res =
844 symbolInfoMap.findBoundSymbol(key: variadicTreeName, node: tree, op, argIndex,
845 /*variadicSubIndex=*/std::nullopt);
846 if (res == symbolInfoMap.end())
847 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}", Vals&: variadicTreeName));
848
849 os << formatv(Fmt: "{0} = variadic_operand_range;\n",
850 Vals: res->second.getVarName(name: variadicTreeName));
851 }
852
853 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
854 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(index: i)) {
855 if (!argTree.isOperation())
856 PrintFatalError(ErrorLoc: loc, Msg: "variadic only accepts operation sub-dags");
857
858 os << "{\n";
859 os.indent();
860
861 std::string argName = formatv(Fmt: "local_op_{0}", Vals&: i).str();
862 os << formatv(Fmt: "auto *{0} = "
863 "variadic_operand_range[{1}].getDefiningOp();\n",
864 Vals&: argName, Vals&: i);
865 emitMatchCheck(
866 opName, /*matchStr=*/argName,
867 failureStr: formatv(Fmt: "\"There's no operation that defines variadic operand "
868 "{0} (variadic sub-opearnd #{1}) of {2}\"",
869 Vals&: operandIndex, Vals&: i, Vals&: opName));
870 emitMatch(tree: argTree, name: argName, depth: depth + 1);
871 os << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
872
873 os.unindent() << "}\n";
874 } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) {
875 auto operandName = formatv(Fmt: "variadic_operand_range.slice({0}, 1)", Vals&: i);
876 emitOperandMatch(tree, opName, operandName: operandName.str(), operandIndex,
877 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(index: i),
878 /*argName=*/variadicArgTree.getArgName(index: i), argIndex, variadicSubIndex: i);
879 } else {
880 PrintFatalError(ErrorLoc: loc, Msg: "variadic can only be applied on operand");
881 }
882 }
883
884 os.unindent() << "}\n";
885}
886
887void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
888 int argIndex, int depth) {
889 Operator &op = tree.getDialectOp(mapper: opMap);
890 auto *namedAttr = cast<NamedAttribute *>(Val: op.getArg(index: argIndex));
891 const auto &attr = namedAttr->attr;
892
893 os << "{\n";
894 if (op.getDialect().usePropertiesForAttributes()) {
895 os.indent() << formatv(
896 Fmt: "[[maybe_unused]] auto tblgen_attr = {0}.getProperties().{1}();\n",
897 Vals&: castedName, Vals: op.getGetterName(name: namedAttr->name));
898 } else {
899 os.indent() << formatv(Fmt: "[[maybe_unused]] auto tblgen_attr = "
900 "{0}->getAttrOfType<{1}>(\"{2}\");\n",
901 Vals&: castedName, Vals: attr.getStorageType(), Vals&: namedAttr->name);
902 }
903
904 // TODO: This should use getter method to avoid duplication.
905 if (attr.hasDefaultValue()) {
906 os << "if (!tblgen_attr) tblgen_attr = "
907 << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx,
908 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fmtCtx)))
909 << ";\n";
910 } else if (attr.isOptional()) {
911 // For a missing attribute that is optional according to definition, we
912 // should just capture a mlir::Attribute() to signal the missing state.
913 // That is precisely what getDiscardableAttr() returns on missing
914 // attributes.
915 } else {
916 emitMatchCheck(opName: castedName, matchFmt: tgfmt(fmt: "tblgen_attr", ctx: &fmtCtx),
917 failureFmt: formatv(Fmt: "\"expected op '{0}' to have attribute '{1}' "
918 "of type '{2}'\"",
919 Vals: op.getOperationName(), Vals&: namedAttr->name,
920 Vals: attr.getStorageType()));
921 }
922
923 auto matcher = tree.getArgAsLeaf(index: argIndex);
924 if (!matcher.isUnspecified()) {
925 if (!matcher.isAttrMatcher()) {
926 PrintFatalError(
927 ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an attribute",
928 Vals: op.getOperationName(), Vals: argIndex + 1));
929 }
930
931 // If a constraint is specified, we need to generate function call to its
932 // static verifier.
933 StringRef verifier = staticMatcherHelper.getVerifierName(leaf: matcher);
934 if (attr.isOptional()) {
935 // Avoid dereferencing null attribute. This is using a simple heuristic to
936 // avoid common cases of attempting to dereference null attribute. This
937 // will return where there is no check if attribute is null unless the
938 // attribute's value is not used.
939 // FIXME: This could be improved as some null dereferences could slip
940 // through.
941 if (!StringRef(matcher.getConditionTemplate()).contains(Other: "!$_self") &&
942 StringRef(matcher.getConditionTemplate()).contains(Other: "$_self")) {
943 os << "if (!tblgen_attr) return ::mlir::failure();\n";
944 }
945 }
946 emitStaticVerifierCall(
947 funcName: verifier, opName: castedName, arg: "tblgen_attr",
948 failureStr: formatv(Fmt: "\"op '{0}' attribute '{1}' failed to satisfy constraint: "
949 "'{2}'\"",
950 Vals: op.getOperationName(), Vals&: namedAttr->name,
951 Vals: escapeString(value: matcher.getAsConstraint().getSummary()))
952 .str());
953 }
954
955 // Capture the value
956 auto name = tree.getArgName(index: argIndex);
957 // `$_` is a special symbol to ignore op argument matching.
958 if (!name.empty() && name != "_") {
959 os << formatv(Fmt: "{0} = tblgen_attr;\n", Vals&: name);
960 }
961
962 os.unindent() << "}\n";
963}
964
965void PatternEmitter::emitPropertyMatch(DagNode tree, StringRef castedName,
966 int argIndex, int depth) {
967 Operator &op = tree.getDialectOp(mapper: opMap);
968 auto *namedProp = cast<NamedProperty *>(Val: op.getArg(index: argIndex));
969
970 os << "{\n";
971 os.indent() << formatv(
972 Fmt: "[[maybe_unused]] auto tblgen_prop = {0}.getProperties().{1}();\n",
973 Vals&: castedName, Vals: op.getGetterName(name: namedProp->name));
974
975 auto matcher = tree.getArgAsLeaf(index: argIndex);
976 if (!matcher.isUnspecified()) {
977 if (!matcher.isPropMatcher()) {
978 PrintFatalError(
979 ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be a property",
980 Vals: op.getOperationName(), Vals: argIndex + 1));
981 }
982
983 // If a constraint is specified, we need to generate function call to its
984 // static verifier.
985 StringRef verifier = staticMatcherHelper.getVerifierName(leaf: matcher);
986 emitStaticVerifierCall(
987 funcName: verifier, opName: castedName, arg: "tblgen_prop",
988 failureStr: formatv(Fmt: "\"op '{0}' property '{1}' failed to satisfy constraint: "
989 "'{2}'\"",
990 Vals: op.getOperationName(), Vals&: namedProp->name,
991 Vals: escapeString(value: matcher.getAsConstraint().getSummary()))
992 .str());
993 }
994
995 // Capture the value
996 auto name = tree.getArgName(index: argIndex);
997 // `$_` is a special symbol to ignore op argument matching.
998 if (!name.empty() && name != "_") {
999 os << formatv(Fmt: "{0} = tblgen_prop;\n", Vals&: name);
1000 }
1001
1002 os.unindent() << "}\n";
1003}
1004
1005void PatternEmitter::emitMatchCheck(
1006 StringRef opName, const FmtObjectBase &matchFmt,
1007 const llvm::formatv_object_base &failureFmt) {
1008 emitMatchCheck(opName, matchStr: matchFmt.str(), failureStr: failureFmt.str());
1009}
1010
1011void PatternEmitter::emitMatchCheck(StringRef opName,
1012 const std::string &matchStr,
1013 const std::string &failureStr) {
1014
1015 os << "if (!(" << matchStr << "))";
1016 os.scope(open: "{\n", close: "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
1017 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
1018 << failureStr << ";\n});";
1019}
1020
1021void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
1022 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
1023 int depth = 0;
1024 emitMatch(tree, name: opName, depth);
1025
1026 for (auto &appliedConstraint : pattern.getConstraints()) {
1027 auto &constraint = appliedConstraint.constraint;
1028 auto &entities = appliedConstraint.entities;
1029
1030 auto condition = constraint.getConditionTemplate();
1031 if (isa<TypeConstraint>(Val: constraint)) {
1032 if (entities.size() != 1)
1033 PrintFatalError(ErrorLoc: loc, Msg: "type constraint requires exactly one argument");
1034
1035 auto self = formatv(Fmt: "({0}.getType())",
1036 Vals: symbolInfoMap.getValueAndRangeUse(symbol: entities.front()));
1037 emitMatchCheck(
1038 opName, matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self.str())),
1039 failureFmt: formatv(Fmt: "\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
1040 Vals&: entities.front(), Vals: escapeString(value: constraint.getSummary())));
1041
1042 } else if (isa<AttrConstraint>(Val: constraint)) {
1043 PrintFatalError(
1044 ErrorLoc: loc, Msg: "cannot use AttrConstraint in Pattern multi-entity constraints");
1045 } else {
1046 // TODO: replace formatv arguments with the exact specified
1047 // args.
1048 if (entities.size() > 4) {
1049 PrintFatalError(ErrorLoc: loc, Msg: "only support up to 4-entity constraints now");
1050 }
1051 SmallVector<std::string, 4> names;
1052 int i = 0;
1053 for (int e = entities.size(); i < e; ++i)
1054 names.push_back(Elt: symbolInfoMap.getValueAndRangeUse(symbol: entities[i]));
1055 std::string self = appliedConstraint.self;
1056 if (!self.empty())
1057 self = symbolInfoMap.getValueAndRangeUse(symbol: self);
1058 for (; i < 4; ++i)
1059 names.push_back(Elt: "<unused>");
1060 emitMatchCheck(opName,
1061 matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self), vals&: names[0],
1062 vals&: names[1], vals&: names[2], vals&: names[3]),
1063 failureFmt: formatv(Fmt: "\"entities '{0}' failed to satisfy constraint: "
1064 "'{1}'\"",
1065 Vals: llvm::join(R&: entities, Separator: ", "),
1066 Vals: escapeString(value: constraint.getSummary())));
1067 }
1068 }
1069
1070 // Some of the operands could be bound to the same symbol name, we need
1071 // to enforce equality constraint on those.
1072 // TODO: we should be able to emit equality checks early
1073 // and short circuit unnecessary work if vars are not equal.
1074 for (auto symbolInfoIt = symbolInfoMap.begin();
1075 symbolInfoIt != symbolInfoMap.end();) {
1076 auto range = symbolInfoMap.getRangeOfEqualElements(key: symbolInfoIt->first);
1077 auto startRange = range.first;
1078 auto endRange = range.second;
1079
1080 auto firstOperand = symbolInfoIt->second.getVarName(name: symbolInfoIt->first);
1081 for (++startRange; startRange != endRange; ++startRange) {
1082 auto secondOperand = startRange->second.getVarName(name: symbolInfoIt->first);
1083 emitMatchCheck(
1084 opName,
1085 matchStr: formatv(Fmt: "*{0}.begin() == *{1}.begin()", Vals&: firstOperand, Vals&: secondOperand),
1086 failureStr: formatv(Fmt: "\"Operands '{0}' and '{1}' must be equal\"", Vals&: firstOperand,
1087 Vals&: secondOperand));
1088 }
1089
1090 symbolInfoIt = endRange;
1091 }
1092
1093 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
1094}
1095
1096void PatternEmitter::collectOps(DagNode tree,
1097 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1098 // Check if this tree is an operation.
1099 if (tree.isOperation()) {
1100 const Operator &op = tree.getDialectOp(mapper: opMap);
1101 LLVM_DEBUG(llvm::dbgs()
1102 << "found operation " << op.getOperationName() << '\n');
1103 ops.insert(Ptr: &op);
1104 }
1105
1106 // Recurse the arguments of the tree.
1107 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1108 if (auto child = tree.getArgAsNestedDag(index: i))
1109 collectOps(tree: child, ops);
1110}
1111
1112void PatternEmitter::emit(StringRef rewriteName) {
1113 // Get the DAG tree for the source pattern.
1114 DagNode sourceTree = pattern.getSourcePattern();
1115
1116 const Operator &rootOp = pattern.getSourceRootOp();
1117 auto rootName = rootOp.getOperationName();
1118
1119 // Collect the set of result operations.
1120 llvm::SmallPtrSet<const Operator *, 4> resultOps;
1121 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
1122 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
1123 collectOps(tree: pattern.getResultPattern(index: i), ops&: resultOps);
1124 }
1125 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
1126
1127 // Emit RewritePattern for Pattern.
1128 auto locs = pattern.getLocation();
1129 os << formatv(Fmt: "/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
1130 Vals: llvm::reverse(C&: locs));
1131 os << formatv(Fmt: R"(struct {0} : public ::mlir::RewritePattern {
1132 {0}(::mlir::MLIRContext *context)
1133 : ::mlir::RewritePattern("{1}", {2}, context, {{)",
1134 Vals&: rewriteName, Vals&: rootName, Vals: pattern.getBenefit());
1135 // Sort result operators by name.
1136 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
1137 resultOps.end());
1138 llvm::sort(C&: sortedResultOps, Comp: [&](const Operator *lhs, const Operator *rhs) {
1139 return lhs->getOperationName() < rhs->getOperationName();
1140 });
1141 llvm::interleaveComma(c: sortedResultOps, os, each_fn: [&](const Operator *op) {
1142 os << '"' << op->getOperationName() << '"';
1143 });
1144 os << "}) {}\n";
1145
1146 // Emit matchAndRewrite() function.
1147 {
1148 auto classScope = os.scope();
1149 os.printReindented(str: R"(
1150 ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0,
1151 ::mlir::PatternRewriter &rewriter) const override {)")
1152 << '\n';
1153 {
1154 auto functionScope = os.scope();
1155
1156 // Register all symbols bound in the source pattern.
1157 pattern.collectSourcePatternBoundSymbols(infoMap&: symbolInfoMap);
1158
1159 LLVM_DEBUG(llvm::dbgs()
1160 << "start creating local variables for capturing matches\n");
1161 os << "// Variables for capturing values and attributes used while "
1162 "creating ops\n";
1163 // Create local variables for storing the arguments and results bound
1164 // to symbols.
1165 for (const auto &symbolInfoPair : symbolInfoMap) {
1166 const auto &symbol = symbolInfoPair.first;
1167 const auto &info = symbolInfoPair.second;
1168
1169 os << info.getVarDecl(name: symbol);
1170 }
1171 // TODO: capture ops with consistent numbering so that it can be
1172 // reused for fused loc.
1173 os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
1174 LLVM_DEBUG(llvm::dbgs()
1175 << "done creating local variables for capturing matches\n");
1176
1177 os << "// Match\n";
1178 os << "tblgen_ops.push_back(op0);\n";
1179 emitMatchLogic(tree: sourceTree, opName: "op0");
1180
1181 os << "\n// Rewrite\n";
1182 emitRewriteLogic();
1183
1184 os << "return ::mlir::success();\n";
1185 }
1186 os << "}\n";
1187 }
1188 os << "};\n\n";
1189}
1190
1191void PatternEmitter::emitRewriteLogic() {
1192 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
1193 const Operator &rootOp = pattern.getSourceRootOp();
1194 int numExpectedResults = rootOp.getNumResults();
1195 int numResultPatterns = pattern.getNumResultPatterns();
1196
1197 // First register all symbols bound to ops generated in result patterns.
1198 pattern.collectResultPatternBoundSymbols(infoMap&: symbolInfoMap);
1199
1200 // Only the last N static values generated are used to replace the matched
1201 // root N-result op. We need to calculate the starting index (of the results
1202 // of the matched op) each result pattern is to replace.
1203 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
1204 // If we don't need to replace any value at all, set the replacement starting
1205 // index as the number of result patterns so we skip all of them when trying
1206 // to replace the matched op's results.
1207 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
1208 for (int i = numResultPatterns - 1; i >= 0; --i) {
1209 auto numValues = getNodeValueCount(node: pattern.getResultPattern(index: i));
1210 offsets[i] = offsets[i + 1] - numValues;
1211 if (offsets[i] == 0) {
1212 if (replStartIndex == -1)
1213 replStartIndex = i;
1214 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
1215 auto error = formatv(
1216 Fmt: "cannot use the same multi-result op '{0}' to generate both "
1217 "auxiliary values and values to be used for replacing the matched op",
1218 Vals: pattern.getResultPattern(index: i).getSymbol());
1219 PrintFatalError(ErrorLoc: loc, Msg: error);
1220 }
1221 }
1222
1223 if (offsets.front() > 0) {
1224 const char error[] =
1225 "not enough values generated to replace the matched op";
1226 PrintFatalError(ErrorLoc: loc, Msg: error);
1227 }
1228
1229 os << "auto odsLoc = rewriter.getFusedLoc({";
1230 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
1231 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
1232 }
1233 os << "}); (void)odsLoc;\n";
1234
1235 // Process auxiliary result patterns.
1236 for (int i = 0; i < replStartIndex; ++i) {
1237 DagNode resultTree = pattern.getResultPattern(index: i);
1238 auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0);
1239 // Normal op creation will be streamed to `os` by the above call; but
1240 // NativeCodeCall will only be materialized to `os` if it is used. Here
1241 // we are handling auxiliary patterns so we want the side effect even if
1242 // NativeCodeCall is not replacing matched root op's results.
1243 if (resultTree.isNativeCodeCall() &&
1244 resultTree.getNumReturnsOfNativeCode() == 0)
1245 os << val << ";\n";
1246 }
1247
1248 auto processSupplementalPatterns = [&]() {
1249 int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1250 for (int i = 0, offset = -numSupplementalPatterns;
1251 i < numSupplementalPatterns; ++i) {
1252 DagNode resultTree = pattern.getSupplementalPattern(index: i);
1253 auto val = handleResultPattern(resultTree, resultIndex: offset++, depth: 0);
1254 if (resultTree.isNativeCodeCall() &&
1255 resultTree.getNumReturnsOfNativeCode() == 0)
1256 os << val << ";\n";
1257 }
1258 };
1259
1260 if (numExpectedResults == 0) {
1261 assert(replStartIndex >= numResultPatterns &&
1262 "invalid auxiliary vs. replacement pattern division!");
1263 processSupplementalPatterns();
1264 // No result to replace. Just erase the op.
1265 os << "rewriter.eraseOp(op0);\n";
1266 } else {
1267 // Process replacement result patterns.
1268 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
1269 for (int i = replStartIndex; i < numResultPatterns; ++i) {
1270 DagNode resultTree = pattern.getResultPattern(index: i);
1271 auto val = handleResultPattern(resultTree, resultIndex: offsets[i], depth: 0);
1272 os << "\n";
1273 // Resolve each symbol for all range use so that we can loop over them.
1274 // We need an explicit cast to `SmallVector` to capture the cases where
1275 // `{0}` resolves to an `Operation::result_range` as well as cases that
1276 // are not iterable (e.g. vector that gets wrapped in additional braces by
1277 // RewriterGen).
1278 // TODO: Revisit the need for materializing a vector.
1279 os << symbolInfoMap.getAllRangeUse(
1280 symbol: val,
1281 fmt: "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
1282 " tblgen_repl_values.push_back(v);\n}\n",
1283 separator: "\n");
1284 }
1285 processSupplementalPatterns();
1286 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
1287 }
1288
1289 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
1290}
1291
1292std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1293 return std::string(
1294 formatv(Fmt: "tblgen_{0}_{1}", Vals: op->getCppClassName(), Vals: nextValueId++));
1295}
1296
1297std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1298 int resultIndex, int depth) {
1299 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1300 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1301 LLVM_DEBUG(llvm::dbgs() << '\n');
1302
1303 if (resultTree.isLocationDirective()) {
1304 PrintFatalError(ErrorLoc: loc,
1305 Msg: "location directive can only be used with op creation");
1306 }
1307
1308 if (resultTree.isNativeCodeCall())
1309 return handleReplaceWithNativeCodeCall(resultTree, depth);
1310
1311 if (resultTree.isReplaceWithValue())
1312 return handleReplaceWithValue(tree: resultTree).str();
1313
1314 if (resultTree.isVariadic())
1315 return handleVariadic(tree: resultTree, depth);
1316
1317 // Normal op creation.
1318 auto symbol = handleOpCreation(tree: resultTree, resultIndex, depth);
1319 if (resultTree.getSymbol().empty()) {
1320 // This is an op not explicitly bound to a symbol in the rewrite rule.
1321 // Register the auto-generated symbol for it.
1322 symbolInfoMap.bindOpResult(symbol, op: pattern.getDialectOp(node: resultTree));
1323 }
1324 return symbol;
1325}
1326
1327std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
1328 assert(tree.isVariadic());
1329
1330 std::string output;
1331 llvm::raw_string_ostream oss(output);
1332 auto name = std::string(formatv(Fmt: "tblgen_variadic_values_{0}", Vals: nextValueId++));
1333 symbolInfoMap.bindValue(symbol: name);
1334 oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
1335 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1336 if (auto child = tree.getArgAsNestedDag(index: i)) {
1337 oss << name << ".push_back(" << handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1)
1338 << ");\n";
1339 } else {
1340 oss << name << ".push_back("
1341 << handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i))
1342 << ");\n";
1343 }
1344 }
1345
1346 os << oss.str();
1347 return name;
1348}
1349
1350StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1351 assert(tree.isReplaceWithValue());
1352
1353 if (tree.getNumArgs() != 1) {
1354 PrintFatalError(
1355 ErrorLoc: loc, Msg: "replaceWithValue directive must take exactly one argument");
1356 }
1357
1358 if (!tree.getSymbol().empty()) {
1359 PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to replaceWithValue");
1360 }
1361
1362 return tree.getArgName(index: 0);
1363}
1364
1365std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1366 assert(tree.isLocationDirective());
1367 auto lookUpArgLoc = [this, &tree](int idx) {
1368 const auto *const lookupFmt = "{0}.getLoc()";
1369 return symbolInfoMap.getValueAndRangeUse(symbol: tree.getArgName(index: idx), fmt: lookupFmt);
1370 };
1371
1372 if (tree.getNumArgs() == 0)
1373 llvm::PrintFatalError(
1374 Msg: "At least one argument to location directive required");
1375
1376 if (!tree.getSymbol().empty())
1377 PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to location");
1378
1379 if (tree.getNumArgs() == 1) {
1380 DagLeaf leaf = tree.getArgAsLeaf(index: 0);
1381 if (leaf.isStringAttr())
1382 return formatv(Fmt: "::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1383 Vals: leaf.getStringAttr())
1384 .str();
1385 return lookUpArgLoc(0);
1386 }
1387
1388 std::string ret;
1389 llvm::raw_string_ostream os(ret);
1390 std::string strAttr;
1391 os << "rewriter.getFusedLoc({";
1392 bool first = true;
1393 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1394 DagLeaf leaf = tree.getArgAsLeaf(index: i);
1395 // Handle the optional string value.
1396 if (leaf.isStringAttr()) {
1397 if (!strAttr.empty())
1398 llvm::PrintFatalError(Msg: "Only one string attribute may be specified");
1399 strAttr = leaf.getStringAttr();
1400 continue;
1401 }
1402 os << (first ? "" : ", ") << lookUpArgLoc(i);
1403 first = false;
1404 }
1405 os << "}";
1406 if (!strAttr.empty()) {
1407 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1408 }
1409 os << ")";
1410 return os.str();
1411}
1412
1413std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1414 int depth) {
1415 // Nested NativeCodeCall.
1416 if (auto dagNode = returnType.getArgAsNestedDag(index: i)) {
1417 if (!dagNode.isNativeCodeCall())
1418 PrintFatalError(ErrorLoc: loc, Msg: "nested DAG in `returnType` must be a native code "
1419 "call");
1420 return handleReplaceWithNativeCodeCall(resultTree: dagNode, depth);
1421 }
1422 // String literal.
1423 auto dagLeaf = returnType.getArgAsLeaf(index: i);
1424 if (dagLeaf.isStringAttr())
1425 return tgfmt(fmt: dagLeaf.getStringAttr(), ctx: &fmtCtx);
1426 return tgfmt(
1427 fmt: "$0.getType()", ctx: &fmtCtx,
1428 vals: handleOpArgument(leaf: returnType.getArgAsLeaf(index: i), patArgName: returnType.getArgName(index: i)));
1429}
1430
1431std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1432 StringRef patArgName) {
1433 if (leaf.isStringAttr())
1434 PrintFatalError(ErrorLoc: loc, Msg: "raw string not supported as argument");
1435 if (leaf.isConstantAttr()) {
1436 auto constAttr = leaf.getAsConstantAttr();
1437 return handleConstantAttr(attr: constAttr.getAttribute(),
1438 value: constAttr.getConstantValue());
1439 }
1440 if (leaf.isEnumCase()) {
1441 auto enumCase = leaf.getAsEnumCase();
1442 // This is an enum case backed by an IntegerAttr. We need to get its value
1443 // to build the constant.
1444 std::string val = std::to_string(val: enumCase.getValue());
1445 return handleConstantAttr(attr: Attribute(&enumCase.getDef()), value: val);
1446 }
1447 if (leaf.isConstantProp()) {
1448 auto constantProp = leaf.getAsConstantProp();
1449 return constantProp.getValue().str();
1450 }
1451
1452 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1453 auto argName = symbolInfoMap.getValueAndRangeUse(symbol: patArgName);
1454 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1455 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1456 << "' (via symbol ref)\n");
1457 return argName;
1458 }
1459 if (leaf.isNativeCodeCall()) {
1460 auto repl = tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: argName));
1461 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1462 << "' (via NativeCodeCall)\n");
1463 return std::string(repl);
1464 }
1465 PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when rewriting op");
1466}
1467
1468std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1469 int depth) {
1470 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1471 LLVM_DEBUG(tree.print(llvm::dbgs()));
1472 LLVM_DEBUG(llvm::dbgs() << '\n');
1473
1474 auto fmt = tree.getNativeCodeTemplate();
1475
1476 SmallVector<std::string, 16> attrs;
1477
1478 auto tail = getTrailingDirectives(tree);
1479 if (tail.returnType)
1480 PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier");
1481 auto locToUse = getLocation(tail);
1482
1483 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1484 if (tree.isNestedDagArg(index: i)) {
1485 attrs.push_back(
1486 Elt: handleResultPattern(resultTree: tree.getArgAsNestedDag(index: i), resultIndex: i, depth: depth + 1));
1487 } else {
1488 attrs.push_back(
1489 Elt: handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i)));
1490 }
1491 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1492 << " replacement: " << attrs[i] << "\n");
1493 }
1494
1495 std::string symbol = tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc", subst: locToUse),
1496 params: static_cast<ArrayRef<std::string>>(attrs));
1497
1498 // In general, NativeCodeCall without naming binding don't need this. To
1499 // ensure void helper function has been correctly labeled, i.e., use
1500 // NativeCodeCallVoid, we cache the result to a local variable so that we will
1501 // get a compilation error in the auto-generated file.
1502 // Example.
1503 // // In the td file
1504 // Pat<(...), (NativeCodeCall<Foo> ...)>
1505 //
1506 // ---
1507 //
1508 // // In the auto-generated .cpp
1509 // ...
1510 // // Causes compilation error if Foo() returns void.
1511 // auto nativeVar = Foo();
1512 // ...
1513 if (tree.getNumReturnsOfNativeCode() != 0) {
1514 // Determine the local variable name for return value.
1515 std::string varName =
1516 SymbolInfoMap::getValuePackName(symbol: tree.getSymbol()).str();
1517 if (varName.empty()) {
1518 varName = formatv(Fmt: "nativeVar_{0}", Vals: nextValueId++);
1519 // Register the local variable for later uses.
1520 symbolInfoMap.bindValues(symbol: varName, numValues: tree.getNumReturnsOfNativeCode());
1521 }
1522
1523 // Catch the return value of helper function.
1524 os << formatv(Fmt: "auto {0} = {1}; (void){0};\n", Vals&: varName, Vals&: symbol);
1525
1526 if (!tree.getSymbol().empty())
1527 symbol = tree.getSymbol().str();
1528 else
1529 symbol = varName;
1530 }
1531
1532 return symbol;
1533}
1534
1535int PatternEmitter::getNodeValueCount(DagNode node) {
1536 if (node.isOperation()) {
1537 // If the op is bound to a symbol in the rewrite rule, query its result
1538 // count from the symbol info map.
1539 auto symbol = node.getSymbol();
1540 if (!symbol.empty()) {
1541 return symbolInfoMap.getStaticValueCount(symbol);
1542 }
1543 // Otherwise this is an unbound op; we will use all its results.
1544 return pattern.getDialectOp(node).getNumResults();
1545 }
1546
1547 if (node.isNativeCodeCall())
1548 return node.getNumReturnsOfNativeCode();
1549
1550 return 1;
1551}
1552
1553PatternEmitter::TrailingDirectives
1554PatternEmitter::getTrailingDirectives(DagNode tree) {
1555 TrailingDirectives tail = {.location: DagNode(nullptr), .returnType: DagNode(nullptr), .numDirectives: 0};
1556
1557 // Look backwards through the arguments.
1558 auto numPatArgs = tree.getNumArgs();
1559 for (int i = numPatArgs - 1; i >= 0; --i) {
1560 auto dagArg = tree.getArgAsNestedDag(index: i);
1561 // A leaf is not a directive. Stop looking.
1562 if (!dagArg)
1563 break;
1564
1565 auto isLocation = dagArg.isLocationDirective();
1566 auto isReturnType = dagArg.isReturnTypeDirective();
1567 // If encountered a DAG node that isn't a trailing directive, stop looking.
1568 if (!(isLocation || isReturnType))
1569 break;
1570 // Save the directive, but error if one of the same type was already
1571 // found.
1572 ++tail.numDirectives;
1573 if (isLocation) {
1574 if (tail.location)
1575 PrintFatalError(ErrorLoc: loc, Msg: "`location` directive can only be specified "
1576 "once");
1577 tail.location = dagArg;
1578 } else if (isReturnType) {
1579 if (tail.returnType)
1580 PrintFatalError(ErrorLoc: loc, Msg: "`returnType` directive can only be specified "
1581 "once");
1582 tail.returnType = dagArg;
1583 }
1584 }
1585
1586 return tail;
1587}
1588
1589std::string
1590PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1591 if (tail.location)
1592 return handleLocationDirective(tree: tail.location);
1593
1594 // If no explicit location is given, use the default, all fused, location.
1595 return "odsLoc";
1596}
1597
1598std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1599 int depth) {
1600 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1601 LLVM_DEBUG(tree.print(llvm::dbgs()));
1602 LLVM_DEBUG(llvm::dbgs() << '\n');
1603
1604 Operator &resultOp = tree.getDialectOp(mapper: opMap);
1605 bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
1606 auto numOpArgs = resultOp.getNumArgs();
1607 auto numPatArgs = tree.getNumArgs();
1608
1609 auto tail = getTrailingDirectives(tree);
1610 auto locToUse = getLocation(tail);
1611
1612 auto inPattern = numPatArgs - tail.numDirectives;
1613 if (numOpArgs != inPattern) {
1614 PrintFatalError(ErrorLoc: loc,
1615 Msg: formatv(Fmt: "resultant op '{0}' argument number mismatch: "
1616 "{1} in pattern vs. {2} in definition",
1617 Vals: resultOp.getOperationName(), Vals&: inPattern, Vals&: numOpArgs));
1618 }
1619
1620 // A map to collect all nested DAG child nodes' names, with operand index as
1621 // the key. This includes both bound and unbound child nodes.
1622 ChildNodeIndexNameMap childNodeNames;
1623
1624 // If the argument is a type constraint, then its an operand. Check if the
1625 // op's argument is variadic that the argument in the pattern is too.
1626 auto checkIfMatchedVariadic = [&](int i) {
1627 // FIXME: This does not yet check for variable/leaf case.
1628 // FIXME: Change so that native code call can be handled.
1629 const auto *operand =
1630 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: i));
1631 if (!operand || !operand->isVariadic())
1632 return;
1633
1634 auto child = tree.getArgAsNestedDag(index: i);
1635 if (!child)
1636 return;
1637
1638 // Skip over replaceWithValues.
1639 while (child.isReplaceWithValue()) {
1640 if (!(child = child.getArgAsNestedDag(index: 0)))
1641 return;
1642 }
1643 if (!child.isNativeCodeCall() && !child.isVariadic())
1644 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "op expects variadic operand `{0}`, while "
1645 "provided is non-variadic",
1646 Vals: resultOp.getArgName(index: i)));
1647 };
1648
1649 // First go through all the child nodes who are nested DAG constructs to
1650 // create ops for them and remember the symbol names for them, so that we can
1651 // use the results in the current node. This happens in a recursive manner.
1652 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1653 checkIfMatchedVariadic(i);
1654 if (auto child = tree.getArgAsNestedDag(index: i))
1655 childNodeNames[i] = handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1);
1656 }
1657
1658 // The name of the local variable holding this op.
1659 std::string valuePackName;
1660 // The symbol for holding the result of this pattern. Note that the result of
1661 // this pattern is not necessarily the same as the variable created by this
1662 // pattern because we can use `__N` suffix to refer only a specific result if
1663 // the generated op is a multi-result op.
1664 std::string resultValue;
1665 if (tree.getSymbol().empty()) {
1666 // No symbol is explicitly bound to this op in the pattern. Generate a
1667 // unique name.
1668 valuePackName = resultValue = getUniqueSymbol(op: &resultOp);
1669 } else {
1670 resultValue = std::string(tree.getSymbol());
1671 // Strip the index to get the name for the value pack and use it to name the
1672 // local variable for the op.
1673 valuePackName = std::string(SymbolInfoMap::getValuePackName(symbol: resultValue));
1674 }
1675
1676 // Create the local variable for this op.
1677 os << formatv(Fmt: "{0} {1};\n{{\n", Vals: resultOp.getQualCppClassName(),
1678 Vals&: valuePackName);
1679
1680 // Right now ODS don't have general type inference support. Except a few
1681 // special cases listed below, DRR needs to supply types for all results
1682 // when building an op.
1683 bool isSameOperandsAndResultType =
1684 resultOp.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType");
1685 bool useFirstAttr =
1686 resultOp.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType");
1687
1688 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1689 // We know how to deduce the result type for ops with these traits and we've
1690 // generated builders taking aggregate parameters. Use those builders to
1691 // create the ops.
1692
1693 // First prepare local variables for op arguments used in builder call.
1694 createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);
1695
1696 // Then create the op.
1697 os.scope(open: "", close: "\n}\n").os
1698 << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
1699 Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse,
1700 Vals: useProperties ? "tblgen_props" : "tblgen_attrs");
1701 return resultValue;
1702 }
1703
1704 bool usePartialResults = valuePackName != resultValue;
1705
1706 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1707 // For these cases (broadcastable ops, op results used both as auxiliary
1708 // values and replacement values, ops in nested patterns, auxiliary ops), we
1709 // still need to supply the result types when building the op. But because
1710 // we don't generate a builder automatically with ODS for them, it's the
1711 // developer's responsibility to make sure such a builder (with result type
1712 // deduction ability) exists. We go through the separate-parameter builder
1713 // here given that it's easier for developers to write compared to
1714 // aggregate-parameter builders.
1715 createSeparateLocalVarsForOpArgs(node: tree, childNodeNames);
1716
1717 os.scope().os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}", Vals&: valuePackName,
1718 Vals: resultOp.getQualCppClassName(), Vals&: locToUse);
1719 supplyValuesForOpArgs(node: tree, childNodeNames, depth);
1720 os << "\n );\n}\n";
1721 return resultValue;
1722 }
1723
1724 // If we are provided explicit return types, use them to build the op.
1725 // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1726 // the values generated from the source pattern root op. Then we must use the
1727 // source pattern's value types to determine the value type of the generated
1728 // op here.
1729 if (depth == 0 && resultIndex >= 0 && tail.returnType)
1730 PrintFatalError(ErrorLoc: loc, Msg: "Cannot specify explicit return types in an op whose "
1731 "return values replace the source pattern's root op");
1732
1733 // First prepare local variables for op arguments used in builder call.
1734 createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);
1735
1736 // Then prepare the result types. We need to specify the types for all
1737 // results.
1738 os.indent() << formatv(Fmt: "::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1739 "(void)tblgen_types;\n");
1740 int numResults = resultOp.getNumResults();
1741 if (tail.returnType) {
1742 auto numRetTys = tail.returnType.getNumArgs();
1743 for (int i = 0; i < numRetTys; ++i) {
1744 auto varName = handleReturnTypeArg(returnType: tail.returnType, i, depth: depth + 1);
1745 os << "tblgen_types.push_back(" << varName << ");\n";
1746 }
1747 } else {
1748 if (numResults != 0) {
1749 // Copy the result types from the source pattern.
1750 for (int i = 0; i < numResults; ++i)
1751 os << formatv(Fmt: "for (auto v: castedOp0.getODSResults({0})) {{\n"
1752 " tblgen_types.push_back(v.getType());\n}\n",
1753 Vals: resultIndex + i);
1754 }
1755 }
1756 os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_types, "
1757 "tblgen_values, {3});\n",
1758 Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse,
1759 Vals: useProperties ? "tblgen_props" : "tblgen_attrs");
1760 os.unindent() << "}\n";
1761 return resultValue;
1762}
1763
1764void PatternEmitter::createSeparateLocalVarsForOpArgs(
1765 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1766 Operator &resultOp = node.getDialectOp(mapper: opMap);
1767
1768 // Now prepare operands used for building this op:
1769 // * If the operand is non-variadic, we create a `Value` local variable.
1770 // * If the operand is variadic, we create a `SmallVector<Value>` local
1771 // variable.
1772
1773 int valueIndex = 0; // An index for uniquing local variable names.
1774 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1775 const auto *operand =
1776 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex));
1777 // We do not need special handling for attributes or properties.
1778 if (!operand)
1779 continue;
1780
1781 raw_indented_ostream::DelimitedScope scope(os);
1782 std::string varName;
1783 if (operand->isVariadic()) {
1784 varName = std::string(formatv(Fmt: "tblgen_values_{0}", Vals: valueIndex++));
1785 os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> {0};\n", Vals&: varName);
1786 std::string range;
1787 if (node.isNestedDagArg(index: argIndex)) {
1788 range = childNodeNames[argIndex];
1789 } else {
1790 range = std::string(node.getArgName(index: argIndex));
1791 }
1792 // Resolve the symbol for all range use so that we have a uniform way of
1793 // capturing the values.
1794 range = symbolInfoMap.getValueAndRangeUse(symbol: range);
1795 os << formatv(Fmt: "for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", Vals&: range,
1796 Vals&: varName);
1797 } else {
1798 varName = std::string(formatv(Fmt: "tblgen_value_{0}", Vals: valueIndex++));
1799 os << formatv(Fmt: "::mlir::Value {0} = ", Vals&: varName);
1800 if (node.isNestedDagArg(index: argIndex)) {
1801 os << symbolInfoMap.getValueAndRangeUse(symbol: childNodeNames[argIndex]);
1802 } else {
1803 DagLeaf leaf = node.getArgAsLeaf(index: argIndex);
1804 auto symbol =
1805 symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex));
1806 if (leaf.isNativeCodeCall()) {
1807 os << std::string(
1808 tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol)));
1809 } else {
1810 os << symbol;
1811 }
1812 }
1813 os << ";\n";
1814 }
1815
1816 // Update to use the newly created local variable for building the op later.
1817 childNodeNames[argIndex] = varName;
1818 }
1819}
1820
1821void PatternEmitter::supplyValuesForOpArgs(
1822 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1823 Operator &resultOp = node.getDialectOp(mapper: opMap);
1824 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1825 argIndex != numOpArgs; ++argIndex) {
1826 // Start each argument on its own line.
1827 os << ",\n ";
1828
1829 Argument opArg = resultOp.getArg(index: argIndex);
1830 // Handle the case of operand first.
1831 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) {
1832 if (!operand->name.empty())
1833 os << "/*" << operand->name << "=*/";
1834 os << childNodeNames.lookup(Val: argIndex);
1835 continue;
1836 }
1837
1838 // The argument in the op definition.
1839 auto opArgName = resultOp.getArgName(index: argIndex);
1840 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1841 if (!subTree.isNativeCodeCall())
1842 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1843 "for creating attributes and properties");
1844 os << formatv(Fmt: "/*{0}=*/{1}", Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex));
1845 } else {
1846 auto leaf = node.getArgAsLeaf(index: argIndex);
1847 // The argument in the result DAG pattern.
1848 auto patArgName = node.getArgName(index: argIndex);
1849 if (leaf.isConstantAttr() || leaf.isEnumCase()) {
1850 // TODO: Refactor out into map to avoid recomputing these.
1851 if (!isa<NamedAttribute *>(Val: opArg))
1852 PrintFatalError(ErrorLoc: loc, Msg: Twine("expected attribute ") + Twine(argIndex));
1853 if (!patArgName.empty())
1854 os << "/*" << patArgName << "=*/";
1855 } else if (leaf.isConstantProp()) {
1856 if (!isa<NamedProperty *>(Val: opArg))
1857 PrintFatalError(ErrorLoc: loc, Msg: Twine("expected property ") + Twine(argIndex));
1858 if (!patArgName.empty())
1859 os << "/*" << patArgName << "=*/";
1860 } else {
1861 os << "/*" << opArgName << "=*/";
1862 }
1863 os << handleOpArgument(leaf, patArgName);
1864 }
1865 }
1866}
1867
1868void PatternEmitter::createAggregateLocalVarsForOpArgs(
1869 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1870 Operator &resultOp = node.getDialectOp(mapper: opMap);
1871
1872 bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
1873 auto scope = os.scope();
1874 os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> "
1875 "tblgen_values; (void)tblgen_values;\n");
1876 if (useProperties) {
1877 os << formatv(Fmt: "{0}::Properties tblgen_props; (void)tblgen_props;\n",
1878 Vals: resultOp.getQualCppClassName());
1879 } else {
1880 os << formatv(Fmt: "::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1881 "tblgen_attrs; (void)tblgen_attrs;\n");
1882 }
1883
1884 const char *setPropCmd =
1885 "tblgen_props.{0} = "
1886 "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
1887 const char *addAttrCmd =
1888 "if (auto tmpAttr = {1}) {\n"
1889 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1890 "tmpAttr);\n}\n";
1891 const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
1892 const char *propSetterCmd = "tblgen_props.{0}({1});\n";
1893
1894 int numVariadic = 0;
1895 bool hasOperandSegmentSizes = false;
1896 std::vector<std::string> sizes;
1897 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1898 if (isa<NamedAttribute *>(Val: resultOp.getArg(index: argIndex))) {
1899 // The argument in the op definition.
1900 auto opArgName = resultOp.getArgName(index: argIndex);
1901 hasOperandSegmentSizes =
1902 hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
1903 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1904 if (!subTree.isNativeCodeCall())
1905 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1906 "for creating attribute");
1907
1908 os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex));
1909 } else {
1910 auto leaf = node.getArgAsLeaf(index: argIndex);
1911 // The argument in the result DAG pattern.
1912 auto patArgName = node.getArgName(index: argIndex);
1913 os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: handleOpArgument(leaf, patArgName));
1914 }
1915 continue;
1916 }
1917
1918 if (isa<NamedProperty *>(Val: resultOp.getArg(index: argIndex))) {
1919 // The argument in the op definition.
1920 auto opArgName = resultOp.getArgName(index: argIndex);
1921 auto setterName = resultOp.getSetterName(name: opArgName);
1922 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1923 if (!subTree.isNativeCodeCall())
1924 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1925 "for creating property");
1926
1927 os << formatv(Fmt: propSetterCmd, Vals&: setterName,
1928 Vals: childNodeNames.lookup(Val: argIndex));
1929 } else {
1930 auto leaf = node.getArgAsLeaf(index: argIndex);
1931 // The argument in the result DAG pattern.
1932 auto patArgName = node.getArgName(index: argIndex);
1933 // The argument in the result DAG pattern.
1934 os << formatv(Fmt: propSetterCmd, Vals&: setterName,
1935 Vals: handleOpArgument(leaf, patArgName));
1936 }
1937 continue;
1938 }
1939
1940 const auto *operand =
1941 cast<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex));
1942 if (operand->isVariadic()) {
1943 ++numVariadic;
1944 std::string range;
1945 if (node.isNestedDagArg(index: argIndex)) {
1946 range = childNodeNames.lookup(Val: argIndex);
1947 } else {
1948 range = std::string(node.getArgName(index: argIndex));
1949 }
1950 // Resolve the symbol for all range use so that we have a uniform way of
1951 // capturing the values.
1952 range = symbolInfoMap.getValueAndRangeUse(symbol: range);
1953 os << formatv(Fmt: "for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1954 Vals&: range);
1955 sizes.push_back(x: formatv(Fmt: "static_cast<int32_t>({0}.size())", Vals&: range));
1956 } else {
1957 sizes.emplace_back(args: "1");
1958 os << formatv(Fmt: "tblgen_values.push_back(");
1959 if (node.isNestedDagArg(index: argIndex)) {
1960 os << symbolInfoMap.getValueAndRangeUse(
1961 symbol: childNodeNames.lookup(Val: argIndex));
1962 } else {
1963 DagLeaf leaf = node.getArgAsLeaf(index: argIndex);
1964 if (leaf.isConstantAttr())
1965 // TODO: Use better location
1966 PrintFatalError(
1967 ErrorLoc: loc,
1968 Msg: "attribute found where value was expected, if attempting to use "
1969 "constant value, construct a constant op with given attribute "
1970 "instead");
1971
1972 auto symbol =
1973 symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex));
1974 if (leaf.isNativeCodeCall()) {
1975 os << std::string(
1976 tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol)));
1977 } else {
1978 os << symbol;
1979 }
1980 }
1981 os << ");\n";
1982 }
1983 }
1984
1985 if (numVariadic > 1 && !hasOperandSegmentSizes) {
1986 // Only set size if it can't be computed.
1987 const auto *sameVariadicSize =
1988 resultOp.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize");
1989 if (!sameVariadicSize) {
1990 if (useProperties) {
1991 const char *setSizes = R"(
1992 tblgen_props.operandSegmentSizes = {{ {0} };
1993 )";
1994 os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", ")).str());
1995 } else {
1996 const char *setSizes = R"(
1997 tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1998 rewriter.getDenseI32ArrayAttr({{ {0} }));
1999 )";
2000 os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", ")).str());
2001 }
2002 }
2003 }
2004}
2005
2006StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
2007 const RecordKeeper &records,
2008 RecordOperatorMap &mapper)
2009 : opMap(mapper), staticVerifierEmitter(os, records) {}
2010
2011void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
2012 // PatternEmitter will use the static matcher if there's one generated. To
2013 // ensure that all the dependent static matchers are generated before emitting
2014 // the matching logic of the DagNode, we use topological order to achieve it.
2015 for (auto &dagInfo : topologicalOrder) {
2016 DagNode node = dagInfo.first;
2017 if (!useStaticMatcher(node))
2018 continue;
2019
2020 std::string funcName =
2021 formatv(Fmt: "static_dag_matcher_{0}", Vals: staticMatcherCounter++);
2022 assert(!matcherNames.contains(node));
2023 PatternEmitter(dagInfo.second, &opMap, os, *this)
2024 .emitStaticMatcher(tree: node, funcName);
2025 matcherNames[node] = funcName;
2026 }
2027}
2028
2029void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
2030 staticVerifierEmitter.emitPatternConstraints(constraints: constraints.getArrayRef());
2031}
2032
2033void StaticMatcherHelper::addPattern(const Record *record) {
2034 Pattern pat(record, &opMap);
2035
2036 // While generating the function body of the DAG matcher, it may depends on
2037 // other DAG matchers. To ensure the dependent matchers are ready, we compute
2038 // the topological order for all the DAGs and emit the DAG matchers in this
2039 // order.
2040 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
2041 ++refStats[node];
2042
2043 if (refStats[node] != 1)
2044 return;
2045
2046 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
2047 if (DagNode sibling = node.getArgAsNestedDag(index: i))
2048 dfs(sibling);
2049 else {
2050 DagLeaf leaf = node.getArgAsLeaf(index: i);
2051 if (!leaf.isUnspecified())
2052 constraints.insert(X: leaf);
2053 }
2054
2055 topologicalOrder.push_back(Elt: std::make_pair(x&: node, y&: record));
2056 };
2057
2058 dfs(pat.getSourcePattern());
2059}
2060
2061StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
2062 if (leaf.isAttrMatcher()) {
2063 std::optional<StringRef> constraint =
2064 staticVerifierEmitter.getAttrConstraintFn(constraint: leaf.getAsConstraint());
2065 assert(constraint && "attribute constraint was not uniqued");
2066 return *constraint;
2067 }
2068 if (leaf.isPropMatcher()) {
2069 std::optional<StringRef> constraint =
2070 staticVerifierEmitter.getPropConstraintFn(constraint: leaf.getAsConstraint());
2071 assert(constraint && "prop constraint was not uniqued");
2072 return *constraint;
2073 }
2074 assert(leaf.isOperandMatcher());
2075 return staticVerifierEmitter.getTypeConstraintFn(constraint: leaf.getAsConstraint());
2076}
2077
2078static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
2079 emitSourceFileHeader(Desc: "Rewriters", OS&: os, Record: records);
2080
2081 auto patterns = records.getAllDerivedDefinitions(ClassName: "Pattern");
2082
2083 // We put the map here because it can be shared among multiple patterns.
2084 RecordOperatorMap recordOpMap;
2085
2086 // Exam all the patterns and generate static matcher for the duplicated
2087 // DagNode.
2088 StaticMatcherHelper staticMatcher(os, records, recordOpMap);
2089 for (const Record *p : patterns)
2090 staticMatcher.addPattern(record: p);
2091 staticMatcher.populateStaticConstraintFunctions(os);
2092 staticMatcher.populateStaticMatchers(os);
2093
2094 std::vector<std::string> rewriterNames;
2095 rewriterNames.reserve(n: patterns.size());
2096
2097 std::string baseRewriterName = "GeneratedConvert";
2098 int rewriterIndex = 0;
2099
2100 for (const Record *p : patterns) {
2101 std::string name;
2102 if (p->isAnonymous()) {
2103 // If no name is provided, ensure unique rewriter names simply by
2104 // appending unique suffix.
2105 name = baseRewriterName + llvm::utostr(X: rewriterIndex++);
2106 } else {
2107 name = std::string(p->getName());
2108 }
2109 LLVM_DEBUG(llvm::dbgs()
2110 << "=== start generating pattern '" << name << "' ===\n");
2111 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(rewriteName: name);
2112 LLVM_DEBUG(llvm::dbgs()
2113 << "=== done generating pattern '" << name << "' ===\n");
2114 rewriterNames.push_back(x: std::move(name));
2115 }
2116
2117 // Emit function to add the generated matchers to the pattern list.
2118 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
2119 "::mlir::RewritePatternSet &patterns) {\n";
2120 for (const auto &name : rewriterNames) {
2121 os << " patterns.add<" << name << ">(patterns.getContext());\n";
2122 }
2123 os << "}\n";
2124}
2125
2126static mlir::GenRegistration
2127 genRewriters("gen-rewriters", "Generate pattern rewriters",
2128 [](const RecordKeeper &records, raw_ostream &os) {
2129 emitRewriters(records, os);
2130 return false;
2131 });
2132

source code of mlir/tools/mlir-tblgen/RewriterGen.cpp