1//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Support/IndentedOstream.h"
14#include "mlir/TableGen/Argument.h"
15#include "mlir/TableGen/Attribute.h"
16#include "mlir/TableGen/CodeGenHelpers.h"
17#include "mlir/TableGen/Format.h"
18#include "mlir/TableGen/GenInfo.h"
19#include "mlir/TableGen/Operator.h"
20#include "mlir/TableGen/Pattern.h"
21#include "mlir/TableGen/Predicate.h"
22#include "mlir/TableGen/Property.h"
23#include "mlir/TableGen/Type.h"
24#include "llvm/ADT/FunctionExtras.h"
25#include "llvm/ADT/SetVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/StringSet.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/FormatAdapters.h"
31#include "llvm/Support/PrettyStackTrace.h"
32#include "llvm/Support/Signals.h"
33#include "llvm/TableGen/Error.h"
34#include "llvm/TableGen/Main.h"
35#include "llvm/TableGen/Record.h"
36#include "llvm/TableGen/TableGenBackend.h"
37
38using namespace mlir;
39using namespace mlir::tblgen;
40
41using llvm::formatv;
42using llvm::Record;
43using llvm::RecordKeeper;
44
45#define DEBUG_TYPE "mlir-tblgen-rewritergen"
46
47namespace llvm {
48template <>
49struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51 raw_ostream &os, StringRef style) {
52 os << v.first << ":" << v.second;
53 }
54};
55} // namespace llvm
56
57//===----------------------------------------------------------------------===//
58// PatternEmitter
59//===----------------------------------------------------------------------===//
60
61namespace {
62
63class StaticMatcherHelper;
64
65class PatternEmitter {
66public:
67 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
68 StaticMatcherHelper &helper);
69
70 // Emits the mlir::RewritePattern struct named `rewriteName`.
71 void emit(StringRef rewriteName);
72
73 // Emits the static function of DAG matcher.
74 void emitStaticMatcher(DagNode tree, std::string funcName);
75
76private:
77 // Emits the code for matching ops.
78 void emitMatchLogic(DagNode tree, StringRef opName);
79
80 // Emits the code for rewriting ops.
81 void emitRewriteLogic();
82
83 //===--------------------------------------------------------------------===//
84 // Match utilities
85 //===--------------------------------------------------------------------===//
86
87 // Emits C++ statements for matching the DAG structure.
88 void emitMatch(DagNode tree, StringRef name, int depth);
89
90 // Emit C++ function call to static DAG matcher.
91 void emitStaticMatchCall(DagNode tree, StringRef name);
92
93 // Emit C++ function call to static type/attribute constraint function.
94 void emitStaticVerifierCall(StringRef funcName, StringRef opName,
95 StringRef arg, StringRef failureStr);
96
97 // Emits C++ statements for matching using a native code call.
98 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
99
100 // Emits C++ statements for matching the op constrained by the given DAG
101 // `tree` returning the op's variable name.
102 void emitOpMatch(DagNode tree, StringRef opName, int depth);
103
104 // Emits C++ statements for matching the `argIndex`-th argument of the given
105 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
106 // bound name and the constraint of the operand respectively.
107 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
108 int operandIndex, DagLeaf operandMatcher,
109 StringRef argName, int argIndex,
110 std::optional<int> variadicSubIndex);
111
112 // Emits C++ statements for matching the operands which can be matched in
113 // either order.
114 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
115 StringRef opName, int argIndex, int &operandIndex,
116 int depth);
117
118 // Emits C++ statements for matching a variadic operand.
119 void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
120 StringRef opName, int argIndex,
121 int &operandIndex, int depth);
122
123 // Emits C++ statements for matching the `argIndex`-th argument of the given
124 // DAG `tree` as an attribute.
125 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
126 int depth);
127
128 // Emits C++ for checking a match with a corresponding match failure
129 // diagnostic.
130 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
131 const llvm::formatv_object_base &failureFmt);
132
133 // Emits C++ for checking a match with a corresponding match failure
134 // diagnostics.
135 void emitMatchCheck(StringRef opName, const std::string &matchStr,
136 const std::string &failureStr);
137
138 //===--------------------------------------------------------------------===//
139 // Rewrite utilities
140 //===--------------------------------------------------------------------===//
141
142 // The entry point for handling a result pattern rooted at `resultTree`. This
143 // method dispatches to concrete handlers according to `resultTree`'s kind and
144 // returns a symbol representing the whole value pack. Callers are expected to
145 // further resolve the symbol according to the specific use case.
146 //
147 // `depth` is the nesting level of `resultTree`; 0 means top-level result
148 // pattern. For top-level result pattern, `resultIndex` indicates which result
149 // of the matched root op this pattern is intended to replace, which can be
150 // used to deduce the result type of the op generated from this result
151 // pattern.
152 std::string handleResultPattern(DagNode resultTree, int resultIndex,
153 int depth);
154
155 // Emits the C++ statement to replace the matched DAG with a value built via
156 // calling native C++ code.
157 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
158
159 // Returns the symbol of the old value serving as the replacement.
160 StringRef handleReplaceWithValue(DagNode tree);
161
162 // Trailing directives are used at the end of DAG node argument lists to
163 // specify additional behaviour for op matchers and creators, etc.
164 struct TrailingDirectives {
165 // DAG node containing the `location` directive. Null if there is none.
166 DagNode location;
167
168 // DAG node containing the `returnType` directive. Null if there is none.
169 DagNode returnType;
170
171 // Number of found trailing directives.
172 int numDirectives;
173 };
174
175 // Collect any trailing directives.
176 TrailingDirectives getTrailingDirectives(DagNode tree);
177
178 // Returns the location value to use.
179 std::string getLocation(TrailingDirectives &tail);
180
181 // Returns the location value to use.
182 std::string handleLocationDirective(DagNode tree);
183
184 // Emit return type argument.
185 std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
186
187 // Emits the C++ statement to build a new op out of the given DAG `tree` and
188 // returns the variable name that this op is assigned to. If the root op in
189 // DAG `tree` has a specified name, the created op will be assigned to a
190 // variable of the given name. Otherwise, a unique name will be used as the
191 // result value name.
192 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
193
194 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
195
196 // Emits a local variable for each value and attribute to be used for creating
197 // an op.
198 void createSeparateLocalVarsForOpArgs(DagNode node,
199 ChildNodeIndexNameMap &childNodeNames);
200
201 // Emits the concrete arguments used to call an op's builder.
202 void supplyValuesForOpArgs(DagNode node,
203 const ChildNodeIndexNameMap &childNodeNames,
204 int depth);
205
206 // Emits the local variables for holding all values as a whole and all named
207 // attributes as a whole to be used for creating an op.
208 void createAggregateLocalVarsForOpArgs(
209 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
210
211 // Returns the C++ expression to construct a constant attribute of the given
212 // `value` for the given attribute kind `attr`.
213 std::string handleConstantAttr(Attribute attr, const Twine &value);
214
215 // Returns the C++ expression to build an argument from the given DAG `leaf`.
216 // `patArgName` is used to bound the argument to the source pattern.
217 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
218
219 //===--------------------------------------------------------------------===//
220 // General utilities
221 //===--------------------------------------------------------------------===//
222
223 // Collects all of the operations within the given dag tree.
224 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
225
226 // Returns a unique symbol for a local variable of the given `op`.
227 std::string getUniqueSymbol(const Operator *op);
228
229 //===--------------------------------------------------------------------===//
230 // Symbol utilities
231 //===--------------------------------------------------------------------===//
232
233 // Returns how many static values the given DAG `node` correspond to.
234 int getNodeValueCount(DagNode node);
235
236private:
237 // Pattern instantiation location followed by the location of multiclass
238 // prototypes used. This is intended to be used as a whole to
239 // PrintFatalError() on errors.
240 ArrayRef<SMLoc> loc;
241
242 // Op's TableGen Record to wrapper object.
243 RecordOperatorMap *opMap;
244
245 // Handy wrapper for pattern being emitted.
246 Pattern pattern;
247
248 // Map for all bound symbols' info.
249 SymbolInfoMap symbolInfoMap;
250
251 StaticMatcherHelper &staticMatcherHelper;
252
253 // The next unused ID for newly created values.
254 unsigned nextValueId = 0;
255
256 raw_indented_ostream os;
257
258 // Format contexts containing placeholder substitutions.
259 FmtContext fmtCtx;
260};
261
262// Tracks DagNode's reference multiple times across patterns. Enables generating
263// static matcher functions for DagNode's referenced multiple times rather than
264// inlining them.
265class StaticMatcherHelper {
266public:
267 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
268 RecordOperatorMap &mapper);
269
270 // Determine if we should inline the match logic or delegate to a static
271 // function.
272 bool useStaticMatcher(DagNode node) {
273 // either/variadic node must be associated to the parentOp, thus we can't
274 // emit a static matcher rooted at them.
275 if (node.isEither() || node.isVariadic())
276 return false;
277
278 return refStats[node] > kStaticMatcherThreshold;
279 }
280
281 // Get the name of the static DAG matcher function corresponding to the node.
282 std::string getMatcherName(DagNode node) {
283 assert(useStaticMatcher(node));
284 return matcherNames[node];
285 }
286
287 // Get the name of static type/attribute verification function.
288 StringRef getVerifierName(DagLeaf leaf);
289
290 // Collect the `Record`s, i.e., the DRR, so that we can get the information of
291 // the duplicated DAGs.
292 void addPattern(Record *record);
293
294 // Emit all static functions of DAG Matcher.
295 void populateStaticMatchers(raw_ostream &os);
296
297 // Emit all static functions for Constraints.
298 void populateStaticConstraintFunctions(raw_ostream &os);
299
300private:
301 static constexpr unsigned kStaticMatcherThreshold = 1;
302
303 // Consider two patterns as down below,
304 // DagNode_Root_A DagNode_Root_B
305 // \ \
306 // DagNode_C DagNode_C
307 // \ \
308 // DagNode_D DagNode_D
309 //
310 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
311 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
312 // multiple times so we'll have static matchers for both of them. When we're
313 // emitting the match logic for DagNode_C, we will check if DagNode_D has the
314 // static matcher generated. If so, then we'll generate a call to the
315 // function, inline otherwise. In this case, inlining is not what we want. As
316 // a result, generate the static matcher in topological order to ensure all
317 // the dependent static matchers are generated and we can avoid accidentally
318 // inlining.
319 //
320 // The topological order of all the DagNodes among all patterns.
321 SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
322
323 RecordOperatorMap &opMap;
324
325 // Records of the static function name of each DagNode
326 DenseMap<DagNode, std::string> matcherNames;
327
328 // After collecting all the DagNode in each pattern, `refStats` records the
329 // number of users for each DagNode. We will generate the static matcher for a
330 // DagNode while the number of users exceeds a certain threshold.
331 DenseMap<DagNode, unsigned> refStats;
332
333 // Number of static matcher generated. This is used to generate a unique name
334 // for each DagNode.
335 int staticMatcherCounter = 0;
336
337 // The DagLeaf which contains type or attr constraint.
338 SetVector<DagLeaf> constraints;
339
340 // Static type/attribute verification function emitter.
341 StaticVerifierFunctionEmitter staticVerifierEmitter;
342};
343
344} // namespace
345
346PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
347 raw_ostream &os, StaticMatcherHelper &helper)
348 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
349 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
350 fmtCtx.withBuilder(subst: "rewriter");
351}
352
353std::string PatternEmitter::handleConstantAttr(Attribute attr,
354 const Twine &value) {
355 if (!attr.isConstBuildable())
356 PrintFatalError(ErrorLoc: loc, Msg: "Attribute " + attr.getAttrDefName() +
357 " does not have the 'constBuilderCall' field");
358
359 // TODO: Verify the constants here
360 return std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, vals: value));
361}
362
363void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
364 os << formatv(
365 Fmt: "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
366 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
367 "*, 4> &tblgen_ops",
368 Vals&: funcName);
369
370 // We pass the reference of the variables that need to be captured. Hence we
371 // need to collect all the symbols in the tree first.
372 pattern.collectBoundSymbols(tree, infoMap&: symbolInfoMap, /*isSrcPattern=*/true);
373 symbolInfoMap.assignUniqueAlternativeNames();
374 for (const auto &info : symbolInfoMap)
375 os << formatv(Fmt: ", {0}", Vals: info.second.getArgDecl(name: info.first));
376
377 os << ") {\n";
378 os.indent();
379 os << "(void)tblgen_ops;\n";
380
381 // Note that a static matcher is considered at least one step from the match
382 // entry.
383 emitMatch(tree, name: "op0", /*depth=*/1);
384
385 os << "return ::mlir::success();\n";
386 os.unindent();
387 os << "}\n\n";
388}
389
390// Helper function to match patterns.
391void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
392 if (tree.isNativeCodeCall()) {
393 emitNativeCodeMatch(tree, name, depth);
394 return;
395 }
396
397 if (tree.isOperation()) {
398 emitOpMatch(tree, opName: name, depth);
399 return;
400 }
401
402 PrintFatalError(ErrorLoc: loc, Msg: "encountered non-op, non-NativeCodeCall match.");
403}
404
405void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
406 std::string funcName = staticMatcherHelper.getMatcherName(node: tree);
407 os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", Vals&: funcName,
408 Vals&: opName);
409
410 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
411 // one pass.
412
413 // In general, bound symbol should have the unique name in the pattern but
414 // for the operand, binding same symbol to multiple operands imply a
415 // constraint at the same time. In this case, we will rename those operands
416 // with different names. As a result, we need to collect all the symbolInfos
417 // from the DagNode then get the updated name of the local variables from the
418 // global symbolInfoMap.
419
420 // Collect all the bound symbols in the Dag
421 SymbolInfoMap localSymbolMap(loc);
422 pattern.collectBoundSymbols(tree, infoMap&: localSymbolMap, /*isSrcPattern=*/true);
423
424 for (const auto &info : localSymbolMap) {
425 auto name = info.first;
426 auto symboInfo = info.second;
427 auto ret = symbolInfoMap.findBoundSymbol(key: name, symbolInfo: symboInfo);
428 os << formatv(Fmt: ", {0}", Vals: ret->second.getVarName(name));
429 }
430
431 os << "))) {\n";
432 os.scope().os << "return ::mlir::failure();\n";
433 os << "}\n";
434}
435
436void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
437 StringRef opName, StringRef arg,
438 StringRef failureStr) {
439 os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
440 Vals&: funcName, Vals&: opName, Vals&: arg, Vals&: failureStr);
441 os.scope().os << "return ::mlir::failure();\n";
442 os << "}\n";
443}
444
445// Helper function to match patterns.
446void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
447 int depth) {
448 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
449 LLVM_DEBUG(tree.print(llvm::dbgs()));
450 LLVM_DEBUG(llvm::dbgs() << '\n');
451
452 // The order of generating static matcher follows the topological order so
453 // that for every dependent DagNode already have their static matcher
454 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
455 // is when we are generating the static matcher for a DagNode itself. In this
456 // case, we need to emit the function body rather than a function call.
457 if (staticMatcherHelper.useStaticMatcher(node: tree) &&
458 !staticMatcherHelper.getMatcherName(node: tree).empty()) {
459 emitStaticMatchCall(tree, opName);
460
461 // NativeCodeCall will never be at depth 0 so that we don't need to catch
462 // the root operation as emitOpMatch();
463
464 return;
465 }
466
467 // TODO(suderman): iterate through arguments, determine their types, output
468 // names.
469 SmallVector<std::string, 8> capture;
470
471 raw_indented_ostream::DelimitedScope scope(os);
472
473 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
474 std::string argName = formatv(Fmt: "arg{0}_{1}", Vals&: depth, Vals&: i);
475 if (DagNode argTree = tree.getArgAsNestedDag(index: i)) {
476 if (argTree.isEither())
477 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `either` operands");
478 if (argTree.isVariadic())
479 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `variadic` operands");
480
481 os << "::mlir::Value " << argName << ";\n";
482 } else {
483 auto leaf = tree.getArgAsLeaf(index: i);
484 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
485 os << "::mlir::Attribute " << argName << ";\n";
486 } else {
487 os << "::mlir::Value " << argName << ";\n";
488 }
489 }
490
491 capture.push_back(Elt: std::move(argName));
492 }
493
494 auto tail = getTrailingDirectives(tree);
495 if (tail.returnType)
496 PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier");
497 auto locToUse = getLocation(tail);
498
499 auto fmt = tree.getNativeCodeTemplate();
500 if (fmt.count(Str: "$_self") != 1)
501 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall must have $_self as argument for "
502 "passing the defining Operation");
503
504 auto nativeCodeCall = std::string(
505 tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc", subst: locToUse).withSelf(subst: opName.str()),
506 params: static_cast<ArrayRef<std::string>>(capture)));
507
508 emitMatchCheck(opName, matchStr: formatv(Fmt: "!::mlir::failed({0})", Vals&: nativeCodeCall),
509 failureStr: formatv(Fmt: "\"{0} return ::mlir::failure\"", Vals&: nativeCodeCall));
510
511 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
512 auto name = tree.getArgName(index: i);
513 if (!name.empty() && name != "_") {
514 os << formatv(Fmt: "{0} = {1};\n", Vals&: name, Vals&: capture[i]);
515 }
516 }
517
518 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
519 std::string argName = capture[i];
520
521 // Handle nested DAG construct first
522 if (tree.getArgAsNestedDag(index: i)) {
523 PrintFatalError(
524 ErrorLoc: loc, Msg: formatv(Fmt: "Matching nested tree in NativeCodecall not support for "
525 "{0} as arg {1}",
526 Vals&: argName, Vals&: i));
527 }
528
529 DagLeaf leaf = tree.getArgAsLeaf(index: i);
530
531 // The parameter for native function doesn't bind any constraints.
532 if (leaf.isUnspecified())
533 continue;
534
535 auto constraint = leaf.getAsConstraint();
536
537 std::string self;
538 if (leaf.isAttrMatcher() || leaf.isConstantAttr())
539 self = argName;
540 else
541 self = formatv(Fmt: "{0}.getType()", Vals&: argName);
542 StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
543 emitStaticVerifierCall(
544 funcName: verifier, opName, arg: self,
545 failureStr: formatv(Fmt: "\"operand {0} of native code call '{1}' failed to satisfy "
546 "constraint: "
547 "'{2}'\"",
548 Vals&: i, Vals: tree.getNativeCodeTemplate(),
549 Vals: escapeString(value: constraint.getSummary()))
550 .str());
551 }
552
553 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
554}
555
556// Helper function to match patterns.
557void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
558 Operator &op = tree.getDialectOp(mapper: opMap);
559 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
560 << op.getOperationName() << "' at depth " << depth
561 << '\n');
562
563 auto getCastedName = [depth]() -> std::string {
564 return formatv(Fmt: "castedOp{0}", Vals: depth);
565 };
566
567 // The order of generating static matcher follows the topological order so
568 // that for every dependent DagNode already have their static matcher
569 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
570 // is when we are generating the static matcher for a DagNode itself. In this
571 // case, we need to emit the function body rather than a function call.
572 if (staticMatcherHelper.useStaticMatcher(node: tree) &&
573 !staticMatcherHelper.getMatcherName(node: tree).empty()) {
574 emitStaticMatchCall(tree, opName);
575 // In the codegen of rewriter, we suppose that castedOp0 will capture the
576 // root operation. Manually add it if the root DagNode is a static matcher.
577 if (depth == 0)
578 os << formatv(Fmt: "auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
579 "(void){2};\n",
580 Vals&: opName, Vals: op.getQualCppClassName(), Vals: getCastedName());
581 return;
582 }
583
584 std::string castedName = getCastedName();
585 os << formatv(Fmt: "auto {0} = ::llvm::dyn_cast<{2}>({1}); "
586 "(void){0};\n",
587 Vals&: castedName, Vals&: opName, Vals: op.getQualCppClassName());
588
589 // Skip the operand matching at depth 0 as the pattern rewriter already does.
590 if (depth != 0)
591 emitMatchCheck(opName, /*matchStr=*/castedName,
592 failureStr: formatv(Fmt: "\"{0} is not {1} type\"", Vals&: castedName,
593 Vals: op.getQualCppClassName()));
594
595 // If the operand's name is set, set to that variable.
596 auto name = tree.getSymbol();
597 if (!name.empty())
598 os << formatv(Fmt: "{0} = {1};\n", Vals&: name, Vals&: castedName);
599
600 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
601 ++i, ++opArgIdx) {
602 auto opArg = op.getArg(index: opArgIdx);
603 std::string argName = formatv(Fmt: "op{0}", Vals: depth + 1);
604
605 // Handle nested DAG construct first
606 if (DagNode argTree = tree.getArgAsNestedDag(index: i)) {
607 if (argTree.isEither()) {
608 emitEitherOperandMatch(tree, eitherArgTree: argTree, opName: castedName, argIndex: opArgIdx, operandIndex&: nextOperand,
609 depth);
610 ++opArgIdx;
611 continue;
612 }
613 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) {
614 if (argTree.isVariadic()) {
615 if (!operand->isVariadic()) {
616 auto error = formatv(Fmt: "variadic DAG construct can't match op {0}'s "
617 "non-variadic operand #{1}",
618 Vals: op.getOperationName(), Vals&: opArgIdx);
619 PrintFatalError(ErrorLoc: loc, Msg: error);
620 }
621 emitVariadicOperandMatch(tree, variadicArgTree: argTree, opName: castedName, argIndex: opArgIdx,
622 operandIndex&: nextOperand, depth);
623 ++nextOperand;
624 continue;
625 }
626 if (operand->isVariableLength()) {
627 auto error = formatv(Fmt: "use nested DAG construct to match op {0}'s "
628 "variadic operand #{1} unsupported now",
629 Vals: op.getOperationName(), Vals&: opArgIdx);
630 PrintFatalError(ErrorLoc: loc, Msg: error);
631 }
632 }
633
634 os << "{\n";
635
636 // Attributes don't count for getODSOperands.
637 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
638 os.indent() << formatv(
639 Fmt: "auto *{0} = "
640 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
641 Vals&: argName, Vals&: castedName, Vals&: nextOperand);
642 // Null check of operand's definingOp
643 emitMatchCheck(
644 opName: castedName, /*matchStr=*/argName,
645 failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"",
646 Vals: nextOperand++, Vals&: castedName));
647 emitMatch(tree: argTree, name: argName, depth: depth + 1);
648 os << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
649 os.unindent() << "}\n";
650 continue;
651 }
652
653 // Next handle DAG leaf: operand or attribute
654 if (opArg.is<NamedTypeConstraint *>()) {
655 auto operandName =
656 formatv(Fmt: "{0}.getODSOperands({1})", Vals&: castedName, Vals&: nextOperand);
657 emitOperandMatch(tree, opName: castedName, operandName: operandName.str(), operandIndex: opArgIdx,
658 /*operandMatcher=*/tree.getArgAsLeaf(index: i),
659 /*argName=*/tree.getArgName(index: i), argIndex: opArgIdx,
660 /*variadicSubIndex=*/std::nullopt);
661 ++nextOperand;
662 } else if (opArg.is<NamedAttribute *>()) {
663 emitAttributeMatch(tree, opName, argIndex: opArgIdx, depth);
664 } else {
665 PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when matching op");
666 }
667 }
668 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
669 << op.getOperationName() << "' at depth " << depth
670 << '\n');
671}
672
673void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
674 StringRef operandName, int operandIndex,
675 DagLeaf operandMatcher, StringRef argName,
676 int argIndex,
677 std::optional<int> variadicSubIndex) {
678 Operator &op = tree.getDialectOp(mapper: opMap);
679 auto *operand = op.getArg(index: operandIndex).get<NamedTypeConstraint *>();
680
681 // If a constraint is specified, we need to generate C++ statements to
682 // check the constraint.
683 if (!operandMatcher.isUnspecified()) {
684 if (!operandMatcher.isOperandMatcher())
685 PrintFatalError(
686 ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an operand",
687 Vals: op.getOperationName(), Vals: argIndex + 1));
688
689 // Only need to verify if the matcher's type is different from the one
690 // of op definition.
691 Constraint constraint = operandMatcher.getAsConstraint();
692 if (operand->constraint != constraint) {
693 if (operand->isVariableLength()) {
694 auto error = formatv(
695 Fmt: "further constrain op {0}'s variadic operand #{1} unsupported now",
696 Vals: op.getOperationName(), Vals&: argIndex);
697 PrintFatalError(ErrorLoc: loc, Msg: error);
698 }
699 auto self = formatv(Fmt: "(*{0}.begin()).getType()", Vals&: operandName);
700 StringRef verifier = staticMatcherHelper.getVerifierName(leaf: operandMatcher);
701 emitStaticVerifierCall(
702 funcName: verifier, opName, arg: self.str(),
703 failureStr: formatv(
704 Fmt: "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
705 Vals: operand - op.operand_begin(), Vals: op.getOperationName(),
706 Vals: escapeString(value: constraint.getSummary()))
707 .str());
708 }
709 }
710
711 // Capture the value
712 // `$_` is a special symbol to ignore op argument matching.
713 if (!argName.empty() && argName != "_") {
714 auto res = symbolInfoMap.findBoundSymbol(key: argName, node: tree, op, argIndex: operandIndex,
715 variadicSubIndex);
716 if (res == symbolInfoMap.end())
717 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}", Vals&: argName));
718
719 os << formatv(Fmt: "{0} = {1};\n", Vals: res->second.getVarName(name: argName), Vals&: operandName);
720 }
721}
722
723void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
724 StringRef opName, int argIndex,
725 int &operandIndex, int depth) {
726 constexpr int numEitherArgs = 2;
727 if (eitherArgTree.getNumArgs() != numEitherArgs)
728 PrintFatalError(ErrorLoc: loc, Msg: "`either` only supports grouping two operands");
729
730 Operator &op = tree.getDialectOp(mapper: opMap);
731
732 std::string codeBuffer;
733 llvm::raw_string_ostream tblgenOps(codeBuffer);
734
735 std::string lambda = formatv(Fmt: "eitherLambda{0}", Vals&: depth);
736 os << formatv(
737 Fmt: "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
738 Vals&: lambda);
739
740 os.indent();
741
742 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
743 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(index: i)) {
744 if (argTree.isEither())
745 PrintFatalError(ErrorLoc: loc, Msg: "either cannot be nested");
746
747 std::string argName = formatv(Fmt: "local_op_{0}", Vals&: i).str();
748
749 os << formatv(Fmt: "auto {0} = (*v{1}.begin()).getDefiningOp();\n", Vals&: argName,
750 Vals&: i);
751
752 // Indent emitMatchCheck and emitMatch because they declare local
753 // variables.
754 os << "{\n";
755 os.indent();
756
757 emitMatchCheck(
758 opName, /*matchStr=*/argName,
759 failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"",
760 Vals: operandIndex++, Vals&: opName));
761 emitMatch(tree: argTree, name: argName, depth: depth + 1);
762
763 os.unindent() << "}\n";
764
765 // `tblgen_ops` is used to collect the matched operations. In either, we
766 // need to queue the operation only if the matching success. Thus we emit
767 // the code at the end.
768 tblgenOps << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
769 } else if (op.getArg(index: argIndex).is<NamedTypeConstraint *>()) {
770 emitOperandMatch(tree, opName, /*operandName=*/formatv(Fmt: "v{0}", Vals&: i).str(),
771 operandIndex,
772 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(index: i),
773 /*argName=*/eitherArgTree.getArgName(index: i), argIndex,
774 /*variadicSubIndex=*/std::nullopt);
775 ++operandIndex;
776 } else {
777 PrintFatalError(ErrorLoc: loc, Msg: "either can only be applied on operand");
778 }
779 }
780
781 os << tblgenOps.str();
782 os << "return ::mlir::success();\n";
783 os.unindent() << "};\n";
784
785 os << "{\n";
786 os.indent();
787
788 os << formatv(Fmt: "auto eitherOperand0 = {0}.getODSOperands({1});\n", Vals&: opName,
789 Vals: operandIndex - 2);
790 os << formatv(Fmt: "auto eitherOperand1 = {0}.getODSOperands({1});\n", Vals&: opName,
791 Vals: operandIndex - 1);
792
793 os << formatv(Fmt: "if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
794 "::mlir::failed({0}(eitherOperand1, "
795 "eitherOperand0)))\n",
796 Vals&: lambda);
797 os.indent() << "return ::mlir::failure();\n";
798
799 os.unindent().unindent() << "}\n";
800}
801
802void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
803 DagNode variadicArgTree,
804 StringRef opName, int argIndex,
805 int &operandIndex, int depth) {
806 Operator &op = tree.getDialectOp(mapper: opMap);
807
808 os << "{\n";
809 os.indent();
810
811 os << formatv(Fmt: "auto variadic_operand_range = {0}.getODSOperands({1});\n",
812 Vals&: opName, Vals&: operandIndex);
813 os << formatv(Fmt: "if (variadic_operand_range.size() != {0}) "
814 "return ::mlir::failure();\n",
815 Vals: variadicArgTree.getNumArgs());
816
817 StringRef variadicTreeName = variadicArgTree.getSymbol();
818 if (!variadicTreeName.empty()) {
819 auto res =
820 symbolInfoMap.findBoundSymbol(key: variadicTreeName, node: tree, op, argIndex: operandIndex,
821 /*variadicSubIndex=*/std::nullopt);
822 if (res == symbolInfoMap.end())
823 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}", Vals&: variadicTreeName));
824
825 os << formatv(Fmt: "{0} = variadic_operand_range;\n",
826 Vals: res->second.getVarName(name: variadicTreeName));
827 }
828
829 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
830 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(index: i)) {
831 if (!argTree.isOperation())
832 PrintFatalError(ErrorLoc: loc, Msg: "variadic only accepts operation sub-dags");
833
834 os << "{\n";
835 os.indent();
836
837 std::string argName = formatv(Fmt: "local_op_{0}", Vals&: i).str();
838 os << formatv(Fmt: "auto *{0} = "
839 "variadic_operand_range[{1}].getDefiningOp();\n",
840 Vals&: argName, Vals&: i);
841 emitMatchCheck(
842 opName, /*matchStr=*/argName,
843 failureStr: formatv(Fmt: "\"There's no operation that defines variadic operand "
844 "{0} (variadic sub-opearnd #{1}) of {2}\"",
845 Vals&: operandIndex, Vals&: i, Vals&: opName));
846 emitMatch(tree: argTree, name: argName, depth: depth + 1);
847 os << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
848
849 os.unindent() << "}\n";
850 } else if (op.getArg(index: argIndex).is<NamedTypeConstraint *>()) {
851 auto operandName = formatv(Fmt: "variadic_operand_range.slice({0}, 1)", Vals&: i);
852 emitOperandMatch(tree, opName, operandName: operandName.str(), operandIndex,
853 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(index: i),
854 /*argName=*/variadicArgTree.getArgName(index: i), argIndex, variadicSubIndex: i);
855 } else {
856 PrintFatalError(ErrorLoc: loc, Msg: "variadic can only be applied on operand");
857 }
858 }
859
860 os.unindent() << "}\n";
861}
862
/// Emits the matcher for the attribute at op argument `argIndex` of the op
/// bound to `opName` in the generated code.
///
/// The generated block fetches the attribute into `tblgen_attr` and then:
///  * substitutes the ODS default value when the attribute is missing and
///    has one,
///  * leaves it null when it is optional,
///  * otherwise fails the match if `getAttrOfType` yields null.
/// If the pattern attaches a constraint, the block runs it through its static
/// verifier. Finally, it stores the attribute into the symbol bound to this
/// argument, unless the name is empty or `_`. `depth` is not used here.
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                        int argIndex, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
  const auto &attr = namedAttr->attr;

  os << "{\n";
  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
                         "(void)tblgen_attr;\n",
                         opName, attr.getStorageType(), namedAttr->name);

  // TODO: This should use getter method to avoid duplication.
  if (attr.hasDefaultValue()) {
    os << "if (!tblgen_attr) tblgen_attr = "
       << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
                            attr.getDefaultValue()))
       << ";\n";
  } else if (attr.isOptional()) {
    // For a missing attribute that is optional according to its definition,
    // we just capture a null mlir::Attribute() to signal the missing state.
    // The `getAttrOfType()` call emitted above already yields a null
    // attribute in that case, so nothing needs to be emitted here.
  } else {
    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
                   formatv("\"expected op '{0}' to have attribute '{1}' "
                           "of type '{2}'\"",
                           op.getOperationName(), namedAttr->name,
                           attr.getStorageType()));
  }

  auto matcher = tree.getArgAsLeaf(argIndex);
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), argIndex + 1));
    }

    // If a constraint is specified, we need to generate function call to its
    // static verifier.
    StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
    if (attr.isOptional()) {
      // Avoid dereferencing a null optional attribute. If the constraint's
      // condition template uses `$_self` but never tests `!$_self`, emit an
      // early match failure for a missing attribute before the verifier call.
      // This is a purely textual heuristic on the template.
      // FIXME: This could be improved as some null dereferences could slip
      // through.
      if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
          StringRef(matcher.getConditionTemplate()).contains("$_self")) {
        os << "if (!tblgen_attr) return ::mlir::failure();\n";
      }
    }
    emitStaticVerifierCall(
        verifier, opName, "tblgen_attr",
        formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                "'{2}'\"",
                op.getOperationName(), namedAttr->name,
                escapeString(matcher.getAsConstraint().getSummary()))
            .str());
  }

  // Capture the value
  auto name = tree.getArgName(argIndex);
  // `$_` is a special symbol to ignore op argument matching.
  if (!name.empty() && name != "_") {
    os << formatv("{0} = tblgen_attr;\n", name);
  }

  os.unindent() << "}\n";
}
934
935void PatternEmitter::emitMatchCheck(
936 StringRef opName, const FmtObjectBase &matchFmt,
937 const llvm::formatv_object_base &failureFmt) {
938 emitMatchCheck(opName, matchStr: matchFmt.str(), failureStr: failureFmt.str());
939}
940
941void PatternEmitter::emitMatchCheck(StringRef opName,
942 const std::string &matchStr,
943 const std::string &failureStr) {
944
945 os << "if (!(" << matchStr << "))";
946 os.scope(open: "{\n", close: "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
947 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
948 << failureStr << ";\n});";
949}
950
951void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
952 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
953 int depth = 0;
954 emitMatch(tree, name: opName, depth);
955
956 for (auto &appliedConstraint : pattern.getConstraints()) {
957 auto &constraint = appliedConstraint.constraint;
958 auto &entities = appliedConstraint.entities;
959
960 auto condition = constraint.getConditionTemplate();
961 if (isa<TypeConstraint>(Val: constraint)) {
962 if (entities.size() != 1)
963 PrintFatalError(ErrorLoc: loc, Msg: "type constraint requires exactly one argument");
964
965 auto self = formatv(Fmt: "({0}.getType())",
966 Vals: symbolInfoMap.getValueAndRangeUse(symbol: entities.front()));
967 emitMatchCheck(
968 opName, matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self.str())),
969 failureFmt: formatv(Fmt: "\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
970 Vals&: entities.front(), Vals: escapeString(value: constraint.getSummary())));
971
972 } else if (isa<AttrConstraint>(Val: constraint)) {
973 PrintFatalError(
974 ErrorLoc: loc, Msg: "cannot use AttrConstraint in Pattern multi-entity constraints");
975 } else {
976 // TODO: replace formatv arguments with the exact specified
977 // args.
978 if (entities.size() > 4) {
979 PrintFatalError(ErrorLoc: loc, Msg: "only support up to 4-entity constraints now");
980 }
981 SmallVector<std::string, 4> names;
982 int i = 0;
983 for (int e = entities.size(); i < e; ++i)
984 names.push_back(Elt: symbolInfoMap.getValueAndRangeUse(symbol: entities[i]));
985 std::string self = appliedConstraint.self;
986 if (!self.empty())
987 self = symbolInfoMap.getValueAndRangeUse(symbol: self);
988 for (; i < 4; ++i)
989 names.push_back(Elt: "<unused>");
990 emitMatchCheck(opName,
991 matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self), vals&: names[0],
992 vals&: names[1], vals&: names[2], vals&: names[3]),
993 failureFmt: formatv(Fmt: "\"entities '{0}' failed to satisfy constraint: "
994 "'{1}'\"",
995 Vals: llvm::join(R&: entities, Separator: ", "),
996 Vals: escapeString(value: constraint.getSummary())));
997 }
998 }
999
1000 // Some of the operands could be bound to the same symbol name, we need
1001 // to enforce equality constraint on those.
1002 // TODO: we should be able to emit equality checks early
1003 // and short circuit unnecessary work if vars are not equal.
1004 for (auto symbolInfoIt = symbolInfoMap.begin();
1005 symbolInfoIt != symbolInfoMap.end();) {
1006 auto range = symbolInfoMap.getRangeOfEqualElements(key: symbolInfoIt->first);
1007 auto startRange = range.first;
1008 auto endRange = range.second;
1009
1010 auto firstOperand = symbolInfoIt->second.getVarName(name: symbolInfoIt->first);
1011 for (++startRange; startRange != endRange; ++startRange) {
1012 auto secondOperand = startRange->second.getVarName(name: symbolInfoIt->first);
1013 emitMatchCheck(
1014 opName,
1015 matchStr: formatv(Fmt: "*{0}.begin() == *{1}.begin()", Vals&: firstOperand, Vals&: secondOperand),
1016 failureStr: formatv(Fmt: "\"Operands '{0}' and '{1}' must be equal\"", Vals&: firstOperand,
1017 Vals&: secondOperand));
1018 }
1019
1020 symbolInfoIt = endRange;
1021 }
1022
1023 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
1024}
1025
1026void PatternEmitter::collectOps(DagNode tree,
1027 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1028 // Check if this tree is an operation.
1029 if (tree.isOperation()) {
1030 const Operator &op = tree.getDialectOp(mapper: opMap);
1031 LLVM_DEBUG(llvm::dbgs()
1032 << "found operation " << op.getOperationName() << '\n');
1033 ops.insert(Ptr: &op);
1034 }
1035
1036 // Recurse the arguments of the tree.
1037 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1038 if (auto child = tree.getArgAsNestedDag(index: i))
1039 collectOps(tree: child, ops);
1040}
1041
/// Emits the `::mlir::RewritePattern` subclass named `rewriteName` for the
/// current pattern.
///
/// The constructor registers the source root op name, the pattern benefit,
/// and the name-sorted list of ops the rewrite may create. It is followed by
/// a `matchAndRewrite` override that declares the capture variables, then
/// runs the match logic and the rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  // The leading comment lists the pattern's locations in reverse of the
  // order returned by `getLocation()`.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  // Open the struct and the constructor. The trailing `{{` is an escaped `{`
  // that starts the list of generated op names.
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
{0}(::mlir::MLIRContext *context)
: ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name, so the emitted list does not depend on the
  // set's iteration order.
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    auto classScope = os.scope();
    os.printReindented(R"(
::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      os << "// Match\n";
      // The root op is always the first entry of `tblgen_ops`.
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    os << "};\n";
  }
  os << "};\n\n";
}
1120
/// Emits the rewrite half of `matchAndRewrite`, in this order:
///  1. the fused `odsLoc`,
///  2. auxiliary result patterns,
///  3. supplemental patterns,
///  4. either `rewriter.eraseOp(op0)` (when the root op has no results) or
///     the collection of `tblgen_repl_values` followed by
///     `rewriter.replaceOp`.
void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  // offsets[i] is that starting index for result pattern i. It is computed
  // backwards from offsets[numResultPatterns] == numExpectedResults. A
  // negative offset marks a pattern whose values are only auxiliary. A
  // pattern straddling index 0 is rejected below.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      // Scanning backwards, the first pattern that starts exactly at result 0
      // is where replacement begins.
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  if (offsets.front() > 0) {
    const char error[] =
        "not enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  // Default location for created ops: the fused location of the first
  // `getNumOps()` entries of `tblgen_ops` recorded while matching.
  os << "auto odsLoc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)odsLoc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall() &&
        resultTree.getNumReturnsOfNativeCode() == 0)
      os << val << ";\n";
  }

  // Supplemental patterns get negative result indices
  // (-numSupplementalPatterns .. -1). Their values are never added to
  // `tblgen_repl_values`.
  auto processSupplementalPatterns = [&]() {
    int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
    for (int i = 0, offset = -numSupplementalPatterns;
         i < numSupplementalPatterns; ++i) {
      DagNode resultTree = pattern.getSupplementalPattern(i);
      auto val = handleResultPattern(resultTree, offset++, 0);
      if (resultTree.isNativeCodeCall() &&
          resultTree.getNumReturnsOfNativeCode() == 0)
        os << val << ";\n";
    }
  };

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    processSupplementalPatterns();
    // No result to replace. Just erase the op.
    os << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      // We need an explicit cast to `SmallVector` to capture the cases where
      // `{0}` resolves to an `Operation::result_range` as well as cases that
      // are not iterable (e.g. vector that gets wrapped in additional braces by
      // RewriterGen).
      // TODO: Revisit the need for materializing a vector.
      os << symbolInfoMap.getAllRangeUse(
          val,
          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
          " tblgen_repl_values.push_back(v);\n}\n",
          "\n");
    }
    processSupplementalPatterns();
    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
1221
1222std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1223 return std::string(
1224 formatv(Fmt: "tblgen_{0}_{1}", Vals: op->getCppClassName(), Vals: nextValueId++));
1225}
1226
1227std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1228 int resultIndex, int depth) {
1229 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1230 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1231 LLVM_DEBUG(llvm::dbgs() << '\n');
1232
1233 if (resultTree.isLocationDirective()) {
1234 PrintFatalError(ErrorLoc: loc,
1235 Msg: "location directive can only be used with op creation");
1236 }
1237
1238 if (resultTree.isNativeCodeCall())
1239 return handleReplaceWithNativeCodeCall(resultTree, depth);
1240
1241 if (resultTree.isReplaceWithValue())
1242 return handleReplaceWithValue(tree: resultTree).str();
1243
1244 // Normal op creation.
1245 auto symbol = handleOpCreation(tree: resultTree, resultIndex, depth);
1246 if (resultTree.getSymbol().empty()) {
1247 // This is an op not explicitly bound to a symbol in the rewrite rule.
1248 // Register the auto-generated symbol for it.
1249 symbolInfoMap.bindOpResult(symbol, op: pattern.getDialectOp(node: resultTree));
1250 }
1251 return symbol;
1252}
1253
1254StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1255 assert(tree.isReplaceWithValue());
1256
1257 if (tree.getNumArgs() != 1) {
1258 PrintFatalError(
1259 ErrorLoc: loc, Msg: "replaceWithValue directive must take exactly one argument");
1260 }
1261
1262 if (!tree.getSymbol().empty()) {
1263 PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to replaceWithValue");
1264 }
1265
1266 return tree.getArgName(index: 0);
1267}
1268
1269std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1270 assert(tree.isLocationDirective());
1271 auto lookUpArgLoc = [this, &tree](int idx) {
1272 const auto *const lookupFmt = "{0}.getLoc()";
1273 return symbolInfoMap.getValueAndRangeUse(symbol: tree.getArgName(index: idx), fmt: lookupFmt);
1274 };
1275
1276 if (tree.getNumArgs() == 0)
1277 llvm::PrintFatalError(
1278 Msg: "At least one argument to location directive required");
1279
1280 if (!tree.getSymbol().empty())
1281 PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to location");
1282
1283 if (tree.getNumArgs() == 1) {
1284 DagLeaf leaf = tree.getArgAsLeaf(index: 0);
1285 if (leaf.isStringAttr())
1286 return formatv(Fmt: "::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1287 Vals: leaf.getStringAttr())
1288 .str();
1289 return lookUpArgLoc(0);
1290 }
1291
1292 std::string ret;
1293 llvm::raw_string_ostream os(ret);
1294 std::string strAttr;
1295 os << "rewriter.getFusedLoc({";
1296 bool first = true;
1297 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1298 DagLeaf leaf = tree.getArgAsLeaf(index: i);
1299 // Handle the optional string value.
1300 if (leaf.isStringAttr()) {
1301 if (!strAttr.empty())
1302 llvm::PrintFatalError(Msg: "Only one string attribute may be specified");
1303 strAttr = leaf.getStringAttr();
1304 continue;
1305 }
1306 os << (first ? "" : ", ") << lookUpArgLoc(i);
1307 first = false;
1308 }
1309 os << "}";
1310 if (!strAttr.empty()) {
1311 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1312 }
1313 os << ")";
1314 return os.str();
1315}
1316
1317std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1318 int depth) {
1319 // Nested NativeCodeCall.
1320 if (auto dagNode = returnType.getArgAsNestedDag(index: i)) {
1321 if (!dagNode.isNativeCodeCall())
1322 PrintFatalError(ErrorLoc: loc, Msg: "nested DAG in `returnType` must be a native code "
1323 "call");
1324 return handleReplaceWithNativeCodeCall(resultTree: dagNode, depth);
1325 }
1326 // String literal.
1327 auto dagLeaf = returnType.getArgAsLeaf(index: i);
1328 if (dagLeaf.isStringAttr())
1329 return tgfmt(fmt: dagLeaf.getStringAttr(), ctx: &fmtCtx);
1330 return tgfmt(
1331 fmt: "$0.getType()", ctx: &fmtCtx,
1332 vals: handleOpArgument(leaf: returnType.getArgAsLeaf(index: i), patArgName: returnType.getArgName(index: i)));
1333}
1334
1335std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1336 StringRef patArgName) {
1337 if (leaf.isStringAttr())
1338 PrintFatalError(ErrorLoc: loc, Msg: "raw string not supported as argument");
1339 if (leaf.isConstantAttr()) {
1340 auto constAttr = leaf.getAsConstantAttr();
1341 return handleConstantAttr(attr: constAttr.getAttribute(),
1342 value: constAttr.getConstantValue());
1343 }
1344 if (leaf.isEnumAttrCase()) {
1345 auto enumCase = leaf.getAsEnumAttrCase();
1346 // This is an enum case backed by an IntegerAttr. We need to get its value
1347 // to build the constant.
1348 std::string val = std::to_string(val: enumCase.getValue());
1349 return handleConstantAttr(attr: enumCase, value: val);
1350 }
1351
1352 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1353 auto argName = symbolInfoMap.getValueAndRangeUse(symbol: patArgName);
1354 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1355 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1356 << "' (via symbol ref)\n");
1357 return argName;
1358 }
1359 if (leaf.isNativeCodeCall()) {
1360 auto repl = tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: argName));
1361 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1362 << "' (via NativeCodeCall)\n");
1363 return std::string(repl);
1364 }
1365 PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when rewriting op");
1366}
1367
/// Emits code for a NativeCodeCall result DAG and returns the C++ expression,
/// or local variable name, that stands for its value.
///
/// Steps:
///  1. Arguments are expanded first: nested DAGs recursively through
///     handleResultPattern, leaves through handleOpArgument.
///  2. The expanded arguments are substituted positionally into the call's
///     template, with `_loc` bound to the location from an optional trailing
///     `location` directive (`odsLoc` otherwise). A `returnType` directive is
///     rejected.
///  3. If the call declares at least one return value, its result is stored
///     in a local variable: the bound symbol's value-pack name, or a fresh
///     `nativeVar_N`. The bound symbol (or that variable) is returned instead
///     of the raw expression.
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
                                                            int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  auto fmt = tree.getNativeCodeTemplate();

  // Expanded C++ expression for each non-directive argument, in order.
  SmallVector<std::string, 16> attrs;

  auto tail = getTrailingDirectives(tree);
  if (tail.returnType)
    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
  auto locToUse = getLocation(tail);

  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    if (tree.isNestedDagArg(i)) {
      attrs.push_back(
          handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
    } else {
      attrs.push_back(
          handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
    }
    LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                            << " replacement: " << attrs[i] << "\n");
  }

  std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
                             static_cast<ArrayRef<std::string>>(attrs));

  // A NativeCodeCall without a bound name does not strictly need a local
  // variable. We still store every value-returning call in one, so that a
  // helper which actually returns void, but was not declared with
  // NativeCodeCallVoid, causes a compilation error in the generated file.
  // Example.
  // // In the td file
  // Pat<(...), (NativeCodeCall<Foo> ...)>
  //
  // ---
  //
  // // In the auto-generated .cpp
  // ...
  // // Causes compilation error if Foo() returns void.
  // auto nativeVar = Foo();
  // ...
  if (tree.getNumReturnsOfNativeCode() != 0) {
    // Determine the local variable name for return value.
    std::string varName =
        SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
    if (varName.empty()) {
      varName = formatv("nativeVar_{0}", nextValueId++);
      // Register the local variable for later uses.
      symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
    }

    // Catch the return value of helper function.
    os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);

    // Return the bound symbol if any (it may carry an index suffix),
    // otherwise the generated variable.
    if (!tree.getSymbol().empty())
      symbol = tree.getSymbol().str();
    else
      symbol = varName;
  }

  return symbol;
}
1434
1435int PatternEmitter::getNodeValueCount(DagNode node) {
1436 if (node.isOperation()) {
1437 // If the op is bound to a symbol in the rewrite rule, query its result
1438 // count from the symbol info map.
1439 auto symbol = node.getSymbol();
1440 if (!symbol.empty()) {
1441 return symbolInfoMap.getStaticValueCount(symbol);
1442 }
1443 // Otherwise this is an unbound op; we will use all its results.
1444 return pattern.getDialectOp(node).getNumResults();
1445 }
1446
1447 if (node.isNativeCodeCall())
1448 return node.getNumReturnsOfNativeCode();
1449
1450 return 1;
1451}
1452
1453PatternEmitter::TrailingDirectives
1454PatternEmitter::getTrailingDirectives(DagNode tree) {
1455 TrailingDirectives tail = {.location: DagNode(nullptr), .returnType: DagNode(nullptr), .numDirectives: 0};
1456
1457 // Look backwards through the arguments.
1458 auto numPatArgs = tree.getNumArgs();
1459 for (int i = numPatArgs - 1; i >= 0; --i) {
1460 auto dagArg = tree.getArgAsNestedDag(index: i);
1461 // A leaf is not a directive. Stop looking.
1462 if (!dagArg)
1463 break;
1464
1465 auto isLocation = dagArg.isLocationDirective();
1466 auto isReturnType = dagArg.isReturnTypeDirective();
1467 // If encountered a DAG node that isn't a trailing directive, stop looking.
1468 if (!(isLocation || isReturnType))
1469 break;
1470 // Save the directive, but error if one of the same type was already
1471 // found.
1472 ++tail.numDirectives;
1473 if (isLocation) {
1474 if (tail.location)
1475 PrintFatalError(ErrorLoc: loc, Msg: "`location` directive can only be specified "
1476 "once");
1477 tail.location = dagArg;
1478 } else if (isReturnType) {
1479 if (tail.returnType)
1480 PrintFatalError(ErrorLoc: loc, Msg: "`returnType` directive can only be specified "
1481 "once");
1482 tail.returnType = dagArg;
1483 }
1484 }
1485
1486 return tail;
1487}
1488
1489std::string
1490PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1491 if (tail.location)
1492 return handleLocationDirective(tree: tail.location);
1493
1494 // If no explicit location is given, use the default, all fused, location.
1495 return "odsLoc";
1496}
1497
1498std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1499 int depth) {
1500 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1501 LLVM_DEBUG(tree.print(llvm::dbgs()));
1502 LLVM_DEBUG(llvm::dbgs() << '\n');
1503
1504 Operator &resultOp = tree.getDialectOp(mapper: opMap);
1505 auto numOpArgs = resultOp.getNumArgs();
1506 auto numPatArgs = tree.getNumArgs();
1507
1508 auto tail = getTrailingDirectives(tree);
1509 auto locToUse = getLocation(tail);
1510
1511 auto inPattern = numPatArgs - tail.numDirectives;
1512 if (numOpArgs != inPattern) {
1513 PrintFatalError(ErrorLoc: loc,
1514 Msg: formatv(Fmt: "resultant op '{0}' argument number mismatch: "
1515 "{1} in pattern vs. {2} in definition",
1516 Vals: resultOp.getOperationName(), Vals&: inPattern, Vals&: numOpArgs));
1517 }
1518
1519 // A map to collect all nested DAG child nodes' names, with operand index as
1520 // the key. This includes both bound and unbound child nodes.
1521 ChildNodeIndexNameMap childNodeNames;
1522
1523 // If the argument is a type constraint, then its an operand. Check if the
1524 // op's argument is variadic that the argument in the pattern is too.
1525 auto checkIfMatchedVariadic = [&](int i) {
1526 // FIXME: This does not yet check for variable/leaf case.
1527 // FIXME: Change so that native code call can be handled.
1528 const auto *operand =
1529 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: i));
1530 if (!operand || !operand->isVariadic())
1531 return;
1532
1533 auto child = tree.getArgAsNestedDag(index: i);
1534 if (!child)
1535 return;
1536
1537 // Skip over replaceWithValues.
1538 while (child.isReplaceWithValue()) {
1539 if (!(child = child.getArgAsNestedDag(index: 0)))
1540 return;
1541 }
1542 if (!child.isNativeCodeCall() && !child.isVariadic())
1543 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "op expects variadic operand `{0}`, while "
1544 "provided is non-variadic",
1545 Vals: resultOp.getArgName(index: i)));
1546 };
1547
1548 // First go through all the child nodes who are nested DAG constructs to
1549 // create ops for them and remember the symbol names for them, so that we can
1550 // use the results in the current node. This happens in a recursive manner.
1551 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1552 checkIfMatchedVariadic(i);
1553 if (auto child = tree.getArgAsNestedDag(index: i))
1554 childNodeNames[i] = handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1);
1555 }
1556
1557 // The name of the local variable holding this op.
1558 std::string valuePackName;
1559 // The symbol for holding the result of this pattern. Note that the result of
1560 // this pattern is not necessarily the same as the variable created by this
1561 // pattern because we can use `__N` suffix to refer only a specific result if
1562 // the generated op is a multi-result op.
1563 std::string resultValue;
1564 if (tree.getSymbol().empty()) {
1565 // No symbol is explicitly bound to this op in the pattern. Generate a
1566 // unique name.
1567 valuePackName = resultValue = getUniqueSymbol(op: &resultOp);
1568 } else {
1569 resultValue = std::string(tree.getSymbol());
1570 // Strip the index to get the name for the value pack and use it to name the
1571 // local variable for the op.
1572 valuePackName = std::string(SymbolInfoMap::getValuePackName(symbol: resultValue));
1573 }
1574
1575 // Create the local variable for this op.
1576 os << formatv(Fmt: "{0} {1};\n{{\n", Vals: resultOp.getQualCppClassName(),
1577 Vals&: valuePackName);
1578
1579 // Right now ODS don't have general type inference support. Except a few
1580 // special cases listed below, DRR needs to supply types for all results
1581 // when building an op.
1582 bool isSameOperandsAndResultType =
1583 resultOp.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType");
1584 bool useFirstAttr =
1585 resultOp.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType");
1586
1587 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1588 // We know how to deduce the result type for ops with these traits and we've
1589 // generated builders taking aggregate parameters. Use those builders to
1590 // create the ops.
1591
1592 // First prepare local variables for op arguments used in builder call.
1593 createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);
1594
1595 // Then create the op.
1596 os.scope(open: "", close: "\n}\n").os << formatv(
1597 Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1598 Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse);
1599 return resultValue;
1600 }
1601
1602 bool usePartialResults = valuePackName != resultValue;
1603
1604 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1605 // For these cases (broadcastable ops, op results used both as auxiliary
1606 // values and replacement values, ops in nested patterns, auxiliary ops), we
1607 // still need to supply the result types when building the op. But because
1608 // we don't generate a builder automatically with ODS for them, it's the
1609 // developer's responsibility to make sure such a builder (with result type
1610 // deduction ability) exists. We go through the separate-parameter builder
1611 // here given that it's easier for developers to write compared to
1612 // aggregate-parameter builders.
1613 createSeparateLocalVarsForOpArgs(node: tree, childNodeNames);
1614
1615 os.scope().os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}", Vals&: valuePackName,
1616 Vals: resultOp.getQualCppClassName(), Vals&: locToUse);
1617 supplyValuesForOpArgs(node: tree, childNodeNames, depth);
1618 os << "\n );\n}\n";
1619 return resultValue;
1620 }
1621
1622 // If we are provided explicit return types, use them to build the op.
1623 // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1624 // the values generated from the source pattern root op. Then we must use the
1625 // source pattern's value types to determine the value type of the generated
1626 // op here.
1627 if (depth == 0 && resultIndex >= 0 && tail.returnType)
1628 PrintFatalError(ErrorLoc: loc, Msg: "Cannot specify explicit return types in an op whose "
1629 "return values replace the source pattern's root op");
1630
1631 // First prepare local variables for op arguments used in builder call.
1632 createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);
1633
1634 // Then prepare the result types. We need to specify the types for all
1635 // results.
1636 os.indent() << formatv(Fmt: "::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1637 "(void)tblgen_types;\n");
1638 int numResults = resultOp.getNumResults();
1639 if (tail.returnType) {
1640 auto numRetTys = tail.returnType.getNumArgs();
1641 for (int i = 0; i < numRetTys; ++i) {
1642 auto varName = handleReturnTypeArg(returnType: tail.returnType, i, depth: depth + 1);
1643 os << "tblgen_types.push_back(" << varName << ");\n";
1644 }
1645 } else {
1646 if (numResults != 0) {
1647 // Copy the result types from the source pattern.
1648 for (int i = 0; i < numResults; ++i)
1649 os << formatv(Fmt: "for (auto v: castedOp0.getODSResults({0})) {{\n"
1650 " tblgen_types.push_back(v.getType());\n}\n",
1651 Vals: resultIndex + i);
1652 }
1653 }
1654 os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_types, "
1655 "tblgen_values, tblgen_attrs);\n",
1656 Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse);
1657 os.unindent() << "}\n";
1658 return resultValue;
1659}
1660
1661void PatternEmitter::createSeparateLocalVarsForOpArgs(
1662 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1663 Operator &resultOp = node.getDialectOp(mapper: opMap);
1664
1665 // Now prepare operands used for building this op:
1666 // * If the operand is non-variadic, we create a `Value` local variable.
1667 // * If the operand is variadic, we create a `SmallVector<Value>` local
1668 // variable.
1669
1670 int valueIndex = 0; // An index for uniquing local variable names.
1671 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1672 const auto *operand =
1673 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex));
1674 // We do not need special handling for attributes.
1675 if (!operand)
1676 continue;
1677
1678 raw_indented_ostream::DelimitedScope scope(os);
1679 std::string varName;
1680 if (operand->isVariadic()) {
1681 varName = std::string(formatv(Fmt: "tblgen_values_{0}", Vals: valueIndex++));
1682 os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> {0};\n", Vals&: varName);
1683 std::string range;
1684 if (node.isNestedDagArg(index: argIndex)) {
1685 range = childNodeNames[argIndex];
1686 } else {
1687 range = std::string(node.getArgName(index: argIndex));
1688 }
1689 // Resolve the symbol for all range use so that we have a uniform way of
1690 // capturing the values.
1691 range = symbolInfoMap.getValueAndRangeUse(symbol: range);
1692 os << formatv(Fmt: "for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", Vals&: range,
1693 Vals&: varName);
1694 } else {
1695 varName = std::string(formatv(Fmt: "tblgen_value_{0}", Vals: valueIndex++));
1696 os << formatv(Fmt: "::mlir::Value {0} = ", Vals&: varName);
1697 if (node.isNestedDagArg(index: argIndex)) {
1698 os << symbolInfoMap.getValueAndRangeUse(symbol: childNodeNames[argIndex]);
1699 } else {
1700 DagLeaf leaf = node.getArgAsLeaf(index: argIndex);
1701 auto symbol =
1702 symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex));
1703 if (leaf.isNativeCodeCall()) {
1704 os << std::string(
1705 tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol)));
1706 } else {
1707 os << symbol;
1708 }
1709 }
1710 os << ";\n";
1711 }
1712
1713 // Update to use the newly created local variable for building the op later.
1714 childNodeNames[argIndex] = varName;
1715 }
1716}
1717
1718void PatternEmitter::supplyValuesForOpArgs(
1719 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1720 Operator &resultOp = node.getDialectOp(mapper: opMap);
1721 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1722 argIndex != numOpArgs; ++argIndex) {
1723 // Start each argument on its own line.
1724 os << ",\n ";
1725
1726 Argument opArg = resultOp.getArg(index: argIndex);
1727 // Handle the case of operand first.
1728 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) {
1729 if (!operand->name.empty())
1730 os << "/*" << operand->name << "=*/";
1731 os << childNodeNames.lookup(Val: argIndex);
1732 continue;
1733 }
1734
1735 // The argument in the op definition.
1736 auto opArgName = resultOp.getArgName(index: argIndex);
1737 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1738 if (!subTree.isNativeCodeCall())
1739 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1740 "for creating attribute");
1741 os << formatv(Fmt: "/*{0}=*/{1}", Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex));
1742 } else {
1743 auto leaf = node.getArgAsLeaf(index: argIndex);
1744 // The argument in the result DAG pattern.
1745 auto patArgName = node.getArgName(index: argIndex);
1746 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1747 // TODO: Refactor out into map to avoid recomputing these.
1748 if (!opArg.is<NamedAttribute *>())
1749 PrintFatalError(ErrorLoc: loc, Msg: Twine("expected attribute ") + Twine(argIndex));
1750 if (!patArgName.empty())
1751 os << "/*" << patArgName << "=*/";
1752 } else {
1753 os << "/*" << opArgName << "=*/";
1754 }
1755 os << handleOpArgument(leaf, patArgName);
1756 }
1757 }
1758}
1759
1760void PatternEmitter::createAggregateLocalVarsForOpArgs(
1761 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1762 Operator &resultOp = node.getDialectOp(mapper: opMap);
1763
1764 auto scope = os.scope();
1765 os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> "
1766 "tblgen_values; (void)tblgen_values;\n");
1767 os << formatv(Fmt: "::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1768 "tblgen_attrs; (void)tblgen_attrs;\n");
1769
1770 const char *addAttrCmd =
1771 "if (auto tmpAttr = {1}) {\n"
1772 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1773 "tmpAttr);\n}\n";
1774 int numVariadic = 0;
1775 bool hasOperandSegmentSizes = false;
1776 std::vector<std::string> sizes;
1777 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1778 if (resultOp.getArg(index: argIndex).is<NamedAttribute *>()) {
1779 // The argument in the op definition.
1780 auto opArgName = resultOp.getArgName(index: argIndex);
1781 hasOperandSegmentSizes =
1782 hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
1783 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1784 if (!subTree.isNativeCodeCall())
1785 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1786 "for creating attribute");
1787 os << formatv(Fmt: addAttrCmd, Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex));
1788 } else {
1789 auto leaf = node.getArgAsLeaf(index: argIndex);
1790 // The argument in the result DAG pattern.
1791 auto patArgName = node.getArgName(index: argIndex);
1792 os << formatv(Fmt: addAttrCmd, Vals&: opArgName,
1793 Vals: handleOpArgument(leaf, patArgName));
1794 }
1795 continue;
1796 }
1797
1798 const auto *operand =
1799 resultOp.getArg(index: argIndex).get<NamedTypeConstraint *>();
1800 std::string varName;
1801 if (operand->isVariadic()) {
1802 ++numVariadic;
1803 std::string range;
1804 if (node.isNestedDagArg(index: argIndex)) {
1805 range = childNodeNames.lookup(Val: argIndex);
1806 } else {
1807 range = std::string(node.getArgName(index: argIndex));
1808 }
1809 // Resolve the symbol for all range use so that we have a uniform way of
1810 // capturing the values.
1811 range = symbolInfoMap.getValueAndRangeUse(symbol: range);
1812 os << formatv(Fmt: "for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1813 Vals&: range);
1814 sizes.push_back(x: formatv(Fmt: "static_cast<int32_t>({0}.size())", Vals&: range));
1815 } else {
1816 sizes.emplace_back(args: "1");
1817 os << formatv(Fmt: "tblgen_values.push_back(");
1818 if (node.isNestedDagArg(index: argIndex)) {
1819 os << symbolInfoMap.getValueAndRangeUse(
1820 symbol: childNodeNames.lookup(Val: argIndex));
1821 } else {
1822 DagLeaf leaf = node.getArgAsLeaf(index: argIndex);
1823 if (leaf.isConstantAttr())
1824 // TODO: Use better location
1825 PrintFatalError(
1826 ErrorLoc: loc,
1827 Msg: "attribute found where value was expected, if attempting to use "
1828 "constant value, construct a constant op with given attribute "
1829 "instead");
1830
1831 auto symbol =
1832 symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex));
1833 if (leaf.isNativeCodeCall()) {
1834 os << std::string(
1835 tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol)));
1836 } else {
1837 os << symbol;
1838 }
1839 }
1840 os << ");\n";
1841 }
1842 }
1843
1844 if (numVariadic > 1 && !hasOperandSegmentSizes) {
1845 // Only set size if it can't be computed.
1846 const auto *sameVariadicSize =
1847 resultOp.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize");
1848 if (!sameVariadicSize) {
1849 const char *setSizes = R"(
1850 tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1851 rewriter.getDenseI32ArrayAttr({{ {0} }));
1852 )";
1853 os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", ")).str());
1854 }
1855 }
1856}
1857
1858StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1859 const RecordKeeper &recordKeeper,
1860 RecordOperatorMap &mapper)
1861 : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
1862
1863void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1864 // PatternEmitter will use the static matcher if there's one generated. To
1865 // ensure that all the dependent static matchers are generated before emitting
1866 // the matching logic of the DagNode, we use topological order to achieve it.
1867 for (auto &dagInfo : topologicalOrder) {
1868 DagNode node = dagInfo.first;
1869 if (!useStaticMatcher(node))
1870 continue;
1871
1872 std::string funcName =
1873 formatv(Fmt: "static_dag_matcher_{0}", Vals: staticMatcherCounter++);
1874 assert(!matcherNames.contains(node));
1875 PatternEmitter(dagInfo.second, &opMap, os, *this)
1876 .emitStaticMatcher(tree: node, funcName);
1877 matcherNames[node] = funcName;
1878 }
1879}
1880
1881void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1882 staticVerifierEmitter.emitPatternConstraints(constraints: constraints.getArrayRef());
1883}
1884
1885void StaticMatcherHelper::addPattern(Record *record) {
1886 Pattern pat(record, &opMap);
1887
1888 // While generating the function body of the DAG matcher, it may depends on
1889 // other DAG matchers. To ensure the dependent matchers are ready, we compute
1890 // the topological order for all the DAGs and emit the DAG matchers in this
1891 // order.
1892 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1893 ++refStats[node];
1894
1895 if (refStats[node] != 1)
1896 return;
1897
1898 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1899 if (DagNode sibling = node.getArgAsNestedDag(index: i))
1900 dfs(sibling);
1901 else {
1902 DagLeaf leaf = node.getArgAsLeaf(index: i);
1903 if (!leaf.isUnspecified())
1904 constraints.insert(X: leaf);
1905 }
1906
1907 topologicalOrder.push_back(Elt: std::make_pair(x&: node, y&: record));
1908 };
1909
1910 dfs(pat.getSourcePattern());
1911}
1912
1913StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1914 if (leaf.isAttrMatcher()) {
1915 std::optional<StringRef> constraint =
1916 staticVerifierEmitter.getAttrConstraintFn(constraint: leaf.getAsConstraint());
1917 assert(constraint && "attribute constraint was not uniqued");
1918 return *constraint;
1919 }
1920 assert(leaf.isOperandMatcher());
1921 return staticVerifierEmitter.getTypeConstraintFn(constraint: leaf.getAsConstraint());
1922}
1923
1924static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1925 emitSourceFileHeader(Desc: "Rewriters", OS&: os, Record: recordKeeper);
1926
1927 const auto &patterns = recordKeeper.getAllDerivedDefinitions(ClassName: "Pattern");
1928
1929 // We put the map here because it can be shared among multiple patterns.
1930 RecordOperatorMap recordOpMap;
1931
1932 // Exam all the patterns and generate static matcher for the duplicated
1933 // DagNode.
1934 StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
1935 for (Record *p : patterns)
1936 staticMatcher.addPattern(record: p);
1937 staticMatcher.populateStaticConstraintFunctions(os);
1938 staticMatcher.populateStaticMatchers(os);
1939
1940 std::vector<std::string> rewriterNames;
1941 rewriterNames.reserve(n: patterns.size());
1942
1943 std::string baseRewriterName = "GeneratedConvert";
1944 int rewriterIndex = 0;
1945
1946 for (Record *p : patterns) {
1947 std::string name;
1948 if (p->isAnonymous()) {
1949 // If no name is provided, ensure unique rewriter names simply by
1950 // appending unique suffix.
1951 name = baseRewriterName + llvm::utostr(X: rewriterIndex++);
1952 } else {
1953 name = std::string(p->getName());
1954 }
1955 LLVM_DEBUG(llvm::dbgs()
1956 << "=== start generating pattern '" << name << "' ===\n");
1957 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(rewriteName: name);
1958 LLVM_DEBUG(llvm::dbgs()
1959 << "=== done generating pattern '" << name << "' ===\n");
1960 rewriterNames.push_back(x: std::move(name));
1961 }
1962
1963 // Emit function to add the generated matchers to the pattern list.
1964 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1965 "::mlir::RewritePatternSet &patterns) {\n";
1966 for (const auto &name : rewriterNames) {
1967 os << " patterns.add<" << name << ">(patterns.getContext());\n";
1968 }
1969 os << "}\n";
1970}
1971
1972static mlir::GenRegistration
1973 genRewriters("gen-rewriters", "Generate pattern rewriters",
1974 [](const RecordKeeper &records, raw_ostream &os) {
1975 emitRewriters(recordKeeper: records, os);
1976 return false;
1977 });
1978

source code of mlir/tools/mlir-tblgen/RewriterGen.cpp