1//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Support/IndentedOstream.h"
14#include "mlir/TableGen/Argument.h"
15#include "mlir/TableGen/Attribute.h"
16#include "mlir/TableGen/CodeGenHelpers.h"
17#include "mlir/TableGen/Format.h"
18#include "mlir/TableGen/GenInfo.h"
19#include "mlir/TableGen/Operator.h"
20#include "mlir/TableGen/Pattern.h"
21#include "mlir/TableGen/Predicate.h"
22#include "mlir/TableGen/Property.h"
23#include "mlir/TableGen/Type.h"
24#include "llvm/ADT/FunctionExtras.h"
25#include "llvm/ADT/SetVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/StringSet.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/FormatAdapters.h"
31#include "llvm/Support/PrettyStackTrace.h"
32#include "llvm/Support/Signals.h"
33#include "llvm/TableGen/Error.h"
34#include "llvm/TableGen/Main.h"
35#include "llvm/TableGen/Record.h"
36#include "llvm/TableGen/TableGenBackend.h"
37
38using namespace mlir;
39using namespace mlir::tblgen;
40
41using llvm::formatv;
42using llvm::Record;
43using llvm::RecordKeeper;
44
45#define DEBUG_TYPE "mlir-tblgen-rewritergen"
46
47namespace llvm {
48template <>
49struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51 raw_ostream &os, StringRef style) {
52 os << v.first << ":" << v.second;
53 }
54};
55} // namespace llvm
56
57//===----------------------------------------------------------------------===//
58// PatternEmitter
59//===----------------------------------------------------------------------===//
60
61namespace {
62
class StaticMatcherHelper;

// Emits the C++ code (a RewritePattern and/or static matcher functions) for a
// single DRR pattern record.
class PatternEmitter {
public:
  PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
                 StaticMatcherHelper &helper);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits the static DAG matcher function named `funcName` that matches the
  // subtree `tree`.
  void emitStaticMatcher(DagNode tree, std::string funcName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree, StringRef opName);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the DAG structure. Dispatches to the
  // NativeCodeCall or op matcher; any other kind of DAG is a fatal error.
  void emitMatch(DagNode tree, StringRef name, int depth);

  // Emits a C++ call to the static DAG matcher generated for `tree`; the
  // emitted code returns failure if that call fails.
  void emitStaticMatchCall(DagNode tree, StringRef name);

  // Emits a C++ call to the static type/attribute constraint function
  // `funcName`; the emitted code returns failure if that call fails.
  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
                              StringRef arg, StringRef failureStr);

  // Emits C++ statements for matching using a native code call.
  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `opName` is the C++ variable holding the candidate operation.
  void emitOpMatch(DagNode tree, StringRef opName, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
  // bound name and the constraint of the operand respectively.
  void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
                        int operandIndex, DagLeaf operandMatcher,
                        StringRef argName, int argIndex,
                        std::optional<int> variadicSubIndex);

  // Emits C++ statements for matching the operands which can be matched in
  // either order. `operandIndex` is advanced past the consumed operands.
  void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                              StringRef opName, int argIndex, int &operandIndex,
                              int depth);

  // Emits C++ statements for matching a variadic operand.
  void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
                                StringRef opName, int argIndex,
                                int &operandIndex, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
                          int depth);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic.
  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic. `failureStr` is presumably spliced verbatim into the emitted
  // code, so it must already be a valid C++ expression (e.g. a quoted,
  // escaped string literal) — NOTE(review): confirm against the definition.
  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                      const std::string &failureStr);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to replace the matched DAG with an array of
  // matched values.
  std::string handleVariadic(DagNode tree, int depth);

  // Trailing directives are used at the end of DAG node argument lists to
  // specify additional behaviour for op matchers and creators, etc.
  struct TrailingDirectives {
    // DAG node containing the `location` directive. Null if there is none.
    DagNode location;

    // DAG node containing the `returnType` directive. Null if there is none.
    DagNode returnType;

    // Number of found trailing directives.
    int numDirectives;
  };

  // Collect any trailing directives.
  TrailingDirectives getTrailingDirectives(DagNode tree);

  // Returns the location value to use, given the trailing directives `tail`
  // collected from a DAG node.
  std::string getLocation(TrailingDirectives &tail);

  // Returns the location value described by the `location` directive node
  // `tree` (NOTE(review): definition not visible here; confirm).
  std::string handleLocationDirective(DagNode tree);

  // Emit return type argument.
  std::string handleReturnTypeArg(DagNode returnType, int i, int depth);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Presumably maps a child node's argument index to the name of the local
  // variable holding its value — NOTE(review): confirm against uses.
  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames,
                             int depth);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, const Twine &value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // Shared helper that decides when to call static matchers/verifiers and
  // provides their names.
  StaticMatcherHelper &staticMatcherHelper;

  // The next unused ID for newly created values.
  unsigned nextValueId = 0;

  // Indentation-aware stream that all generated code is written to.
  raw_indented_ostream os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;
};
265
266// Tracks DagNode's reference multiple times across patterns. Enables generating
267// static matcher functions for DagNode's referenced multiple times rather than
268// inlining them.
269class StaticMatcherHelper {
270public:
271 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records,
272 RecordOperatorMap &mapper);
273
274 // Determine if we should inline the match logic or delegate to a static
275 // function.
276 bool useStaticMatcher(DagNode node) {
277 // either/variadic node must be associated to the parentOp, thus we can't
278 // emit a static matcher rooted at them.
279 if (node.isEither() || node.isVariadic())
280 return false;
281
282 return refStats[node] > kStaticMatcherThreshold;
283 }
284
285 // Get the name of the static DAG matcher function corresponding to the node.
286 std::string getMatcherName(DagNode node) {
287 assert(useStaticMatcher(node));
288 return matcherNames[node];
289 }
290
291 // Get the name of static type/attribute verification function.
292 StringRef getVerifierName(DagLeaf leaf);
293
294 // Collect the `Record`s, i.e., the DRR, so that we can get the information of
295 // the duplicated DAGs.
296 void addPattern(const Record *record);
297
298 // Emit all static functions of DAG Matcher.
299 void populateStaticMatchers(raw_ostream &os);
300
301 // Emit all static functions for Constraints.
302 void populateStaticConstraintFunctions(raw_ostream &os);
303
304private:
305 static constexpr unsigned kStaticMatcherThreshold = 1;
306
307 // Consider two patterns as down below,
308 // DagNode_Root_A DagNode_Root_B
309 // \ \
310 // DagNode_C DagNode_C
311 // \ \
312 // DagNode_D DagNode_D
313 //
314 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
315 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
316 // multiple times so we'll have static matchers for both of them. When we're
317 // emitting the match logic for DagNode_C, we will check if DagNode_D has the
318 // static matcher generated. If so, then we'll generate a call to the
319 // function, inline otherwise. In this case, inlining is not what we want. As
320 // a result, generate the static matcher in topological order to ensure all
321 // the dependent static matchers are generated and we can avoid accidentally
322 // inlining.
323 //
324 // The topological order of all the DagNodes among all patterns.
325 SmallVector<std::pair<DagNode, const Record *>> topologicalOrder;
326
327 RecordOperatorMap &opMap;
328
329 // Records of the static function name of each DagNode
330 DenseMap<DagNode, std::string> matcherNames;
331
332 // After collecting all the DagNode in each pattern, `refStats` records the
333 // number of users for each DagNode. We will generate the static matcher for a
334 // DagNode while the number of users exceeds a certain threshold.
335 DenseMap<DagNode, unsigned> refStats;
336
337 // Number of static matcher generated. This is used to generate a unique name
338 // for each DagNode.
339 int staticMatcherCounter = 0;
340
341 // The DagLeaf which contains type or attr constraint.
342 SetVector<DagLeaf> constraints;
343
344 // Static type/attribute verification function emitter.
345 StaticVerifierFunctionEmitter staticVerifierEmitter;
346};
347
348} // namespace
349
350PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper,
351 raw_ostream &os, StaticMatcherHelper &helper)
352 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
353 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
354 fmtCtx.withBuilder(subst: "rewriter");
355}
356
357std::string PatternEmitter::handleConstantAttr(Attribute attr,
358 const Twine &value) {
359 if (!attr.isConstBuildable())
360 PrintFatalError(ErrorLoc: loc, Msg: "Attribute " + attr.getAttrDefName() +
361 " does not have the 'constBuilderCall' field");
362
363 // TODO: Verify the constants here
364 return std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx, vals: value));
365}
366
367void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
368 os << formatv(
369 Fmt: "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
370 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
371 "*, 4> &tblgen_ops",
372 Vals&: funcName);
373
374 // We pass the reference of the variables that need to be captured. Hence we
375 // need to collect all the symbols in the tree first.
376 pattern.collectBoundSymbols(tree, infoMap&: symbolInfoMap, /*isSrcPattern=*/true);
377 symbolInfoMap.assignUniqueAlternativeNames();
378 for (const auto &info : symbolInfoMap)
379 os << formatv(Fmt: ", {0}", Vals: info.second.getArgDecl(name: info.first));
380
381 os << ") {\n";
382 os.indent();
383 os << "(void)tblgen_ops;\n";
384
385 // Note that a static matcher is considered at least one step from the match
386 // entry.
387 emitMatch(tree, name: "op0", /*depth=*/1);
388
389 os << "return ::mlir::success();\n";
390 os.unindent();
391 os << "}\n\n";
392}
393
394// Helper function to match patterns.
395void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
396 if (tree.isNativeCodeCall()) {
397 emitNativeCodeMatch(tree, name, depth);
398 return;
399 }
400
401 if (tree.isOperation()) {
402 emitOpMatch(tree, opName: name, depth);
403 return;
404 }
405
406 PrintFatalError(ErrorLoc: loc, Msg: "encountered non-op, non-NativeCodeCall match.");
407}
408
409void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
410 std::string funcName = staticMatcherHelper.getMatcherName(node: tree);
411 os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", Vals&: funcName,
412 Vals&: opName);
413
414 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
415 // one pass.
416
417 // In general, bound symbol should have the unique name in the pattern but
418 // for the operand, binding same symbol to multiple operands imply a
419 // constraint at the same time. In this case, we will rename those operands
420 // with different names. As a result, we need to collect all the symbolInfos
421 // from the DagNode then get the updated name of the local variables from the
422 // global symbolInfoMap.
423
424 // Collect all the bound symbols in the Dag
425 SymbolInfoMap localSymbolMap(loc);
426 pattern.collectBoundSymbols(tree, infoMap&: localSymbolMap, /*isSrcPattern=*/true);
427
428 for (const auto &info : localSymbolMap) {
429 auto name = info.first;
430 auto symboInfo = info.second;
431 auto ret = symbolInfoMap.findBoundSymbol(key: name, symbolInfo: symboInfo);
432 os << formatv(Fmt: ", {0}", Vals: ret->second.getVarName(name));
433 }
434
435 os << "))) {\n";
436 os.scope().os << "return ::mlir::failure();\n";
437 os << "}\n";
438}
439
440void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
441 StringRef opName, StringRef arg,
442 StringRef failureStr) {
443 os << formatv(Fmt: "if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
444 Vals&: funcName, Vals&: opName, Vals&: arg, Vals&: failureStr);
445 os.scope().os << "return ::mlir::failure();\n";
446 os << "}\n";
447}
448
449// Helper function to match patterns.
450void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
451 int depth) {
452 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
453 LLVM_DEBUG(tree.print(llvm::dbgs()));
454 LLVM_DEBUG(llvm::dbgs() << '\n');
455
456 // The order of generating static matcher follows the topological order so
457 // that for every dependent DagNode already have their static matcher
458 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
459 // is when we are generating the static matcher for a DagNode itself. In this
460 // case, we need to emit the function body rather than a function call.
461 if (staticMatcherHelper.useStaticMatcher(node: tree) &&
462 !staticMatcherHelper.getMatcherName(node: tree).empty()) {
463 emitStaticMatchCall(tree, opName);
464
465 // NativeCodeCall will never be at depth 0 so that we don't need to catch
466 // the root operation as emitOpMatch();
467
468 return;
469 }
470
471 // TODO(suderman): iterate through arguments, determine their types, output
472 // names.
473 SmallVector<std::string, 8> capture;
474
475 raw_indented_ostream::DelimitedScope scope(os);
476
477 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
478 std::string argName = formatv(Fmt: "arg{0}_{1}", Vals&: depth, Vals&: i);
479 if (DagNode argTree = tree.getArgAsNestedDag(index: i)) {
480 if (argTree.isEither())
481 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `either` operands");
482 if (argTree.isVariadic())
483 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall cannot have `variadic` operands");
484
485 os << "::mlir::Value " << argName << ";\n";
486 } else {
487 auto leaf = tree.getArgAsLeaf(index: i);
488 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
489 os << "::mlir::Attribute " << argName << ";\n";
490 } else {
491 os << "::mlir::Value " << argName << ";\n";
492 }
493 }
494
495 capture.push_back(Elt: std::move(argName));
496 }
497
498 auto tail = getTrailingDirectives(tree);
499 if (tail.returnType)
500 PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier");
501 auto locToUse = getLocation(tail);
502
503 auto fmt = tree.getNativeCodeTemplate();
504 if (fmt.count(Str: "$_self") != 1)
505 PrintFatalError(ErrorLoc: loc, Msg: "NativeCodeCall must have $_self as argument for "
506 "passing the defining Operation");
507
508 auto nativeCodeCall = std::string(
509 tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc", subst: locToUse).withSelf(subst: opName.str()),
510 params: static_cast<ArrayRef<std::string>>(capture)));
511
512 emitMatchCheck(opName, matchStr: formatv(Fmt: "!::mlir::failed({0})", Vals&: nativeCodeCall),
513 failureStr: formatv(Fmt: "\"{0} return ::mlir::failure\"", Vals&: nativeCodeCall));
514
515 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
516 auto name = tree.getArgName(index: i);
517 if (!name.empty() && name != "_") {
518 os << formatv(Fmt: "{0} = {1};\n", Vals&: name, Vals&: capture[i]);
519 }
520 }
521
522 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
523 std::string argName = capture[i];
524
525 // Handle nested DAG construct first
526 if (tree.getArgAsNestedDag(index: i)) {
527 PrintFatalError(
528 ErrorLoc: loc, Msg: formatv(Fmt: "Matching nested tree in NativeCodecall not support for "
529 "{0} as arg {1}",
530 Vals&: argName, Vals&: i));
531 }
532
533 DagLeaf leaf = tree.getArgAsLeaf(index: i);
534
535 // The parameter for native function doesn't bind any constraints.
536 if (leaf.isUnspecified())
537 continue;
538
539 auto constraint = leaf.getAsConstraint();
540
541 std::string self;
542 if (leaf.isAttrMatcher() || leaf.isConstantAttr())
543 self = argName;
544 else
545 self = formatv(Fmt: "{0}.getType()", Vals&: argName);
546 StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
547 emitStaticVerifierCall(
548 funcName: verifier, opName, arg: self,
549 failureStr: formatv(Fmt: "\"operand {0} of native code call '{1}' failed to satisfy "
550 "constraint: "
551 "'{2}'\"",
552 Vals&: i, Vals: tree.getNativeCodeTemplate(),
553 Vals: escapeString(value: constraint.getSummary()))
554 .str());
555 }
556
557 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
558}
559
560// Helper function to match patterns.
561void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
562 Operator &op = tree.getDialectOp(mapper: opMap);
563 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
564 << op.getOperationName() << "' at depth " << depth
565 << '\n');
566
567 auto getCastedName = [depth]() -> std::string {
568 return formatv(Fmt: "castedOp{0}", Vals: depth);
569 };
570
571 // The order of generating static matcher follows the topological order so
572 // that for every dependent DagNode already have their static matcher
573 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
574 // is when we are generating the static matcher for a DagNode itself. In this
575 // case, we need to emit the function body rather than a function call.
576 if (staticMatcherHelper.useStaticMatcher(node: tree) &&
577 !staticMatcherHelper.getMatcherName(node: tree).empty()) {
578 emitStaticMatchCall(tree, opName);
579 // In the codegen of rewriter, we suppose that castedOp0 will capture the
580 // root operation. Manually add it if the root DagNode is a static matcher.
581 if (depth == 0)
582 os << formatv(Fmt: "auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
583 "(void){2};\n",
584 Vals&: opName, Vals: op.getQualCppClassName(), Vals: getCastedName());
585 return;
586 }
587
588 std::string castedName = getCastedName();
589 os << formatv(Fmt: "auto {0} = ::llvm::dyn_cast<{2}>({1}); "
590 "(void){0};\n",
591 Vals&: castedName, Vals&: opName, Vals: op.getQualCppClassName());
592
593 // Skip the operand matching at depth 0 as the pattern rewriter already does.
594 if (depth != 0)
595 emitMatchCheck(opName, /*matchStr=*/castedName,
596 failureStr: formatv(Fmt: "\"{0} is not {1} type\"", Vals&: castedName,
597 Vals: op.getQualCppClassName()));
598
599 // If the operand's name is set, set to that variable.
600 auto name = tree.getSymbol();
601 if (!name.empty())
602 os << formatv(Fmt: "{0} = {1};\n", Vals&: name, Vals&: castedName);
603
604 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
605 ++i, ++opArgIdx) {
606 auto opArg = op.getArg(index: opArgIdx);
607 std::string argName = formatv(Fmt: "op{0}", Vals: depth + 1);
608
609 // Handle nested DAG construct first
610 if (DagNode argTree = tree.getArgAsNestedDag(index: i)) {
611 if (argTree.isEither()) {
612 emitEitherOperandMatch(tree, eitherArgTree: argTree, opName: castedName, argIndex: opArgIdx, operandIndex&: nextOperand,
613 depth);
614 ++opArgIdx;
615 continue;
616 }
617 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) {
618 if (argTree.isVariadic()) {
619 if (!operand->isVariadic()) {
620 auto error = formatv(Fmt: "variadic DAG construct can't match op {0}'s "
621 "non-variadic operand #{1}",
622 Vals: op.getOperationName(), Vals&: opArgIdx);
623 PrintFatalError(ErrorLoc: loc, Msg: error);
624 }
625 emitVariadicOperandMatch(tree, variadicArgTree: argTree, opName: castedName, argIndex: opArgIdx,
626 operandIndex&: nextOperand, depth);
627 ++nextOperand;
628 continue;
629 }
630 if (operand->isVariableLength()) {
631 auto error = formatv(Fmt: "use nested DAG construct to match op {0}'s "
632 "variadic operand #{1} unsupported now",
633 Vals: op.getOperationName(), Vals&: opArgIdx);
634 PrintFatalError(ErrorLoc: loc, Msg: error);
635 }
636 }
637
638 os << "{\n";
639
640 // Attributes don't count for getODSOperands.
641 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
642 os.indent() << formatv(
643 Fmt: "auto *{0} = "
644 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
645 Vals&: argName, Vals&: castedName, Vals&: nextOperand);
646 // Null check of operand's definingOp
647 emitMatchCheck(
648 opName: castedName, /*matchStr=*/argName,
649 failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"",
650 Vals: nextOperand++, Vals&: castedName));
651 emitMatch(tree: argTree, name: argName, depth: depth + 1);
652 os << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
653 os.unindent() << "}\n";
654 continue;
655 }
656
657 // Next handle DAG leaf: operand or attribute
658 if (isa<NamedTypeConstraint *>(Val: opArg)) {
659 auto operandName =
660 formatv(Fmt: "{0}.getODSOperands({1})", Vals&: castedName, Vals&: nextOperand);
661 emitOperandMatch(tree, opName: castedName, operandName: operandName.str(), operandIndex: nextOperand,
662 /*operandMatcher=*/tree.getArgAsLeaf(index: i),
663 /*argName=*/tree.getArgName(index: i), argIndex: opArgIdx,
664 /*variadicSubIndex=*/std::nullopt);
665 ++nextOperand;
666 } else if (isa<NamedAttribute *>(Val: opArg)) {
667 emitAttributeMatch(tree, castedName, argIndex: opArgIdx, depth);
668 } else {
669 PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when matching op");
670 }
671 }
672 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
673 << op.getOperationName() << "' at depth " << depth
674 << '\n');
675}
676
677void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
678 StringRef operandName, int operandIndex,
679 DagLeaf operandMatcher, StringRef argName,
680 int argIndex,
681 std::optional<int> variadicSubIndex) {
682 Operator &op = tree.getDialectOp(mapper: opMap);
683 NamedTypeConstraint operand = op.getOperand(index: operandIndex);
684
685 // If a constraint is specified, we need to generate C++ statements to
686 // check the constraint.
687 if (!operandMatcher.isUnspecified()) {
688 if (!operandMatcher.isOperandMatcher())
689 PrintFatalError(
690 ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an operand",
691 Vals: op.getOperationName(), Vals: argIndex + 1));
692
693 // Only need to verify if the matcher's type is different from the one
694 // of op definition.
695 Constraint constraint = operandMatcher.getAsConstraint();
696 if (operand.constraint != constraint) {
697 if (operand.isVariableLength()) {
698 auto error = formatv(
699 Fmt: "further constrain op {0}'s variadic operand #{1} unsupported now",
700 Vals: op.getOperationName(), Vals&: argIndex);
701 PrintFatalError(ErrorLoc: loc, Msg: error);
702 }
703 auto self = formatv(Fmt: "(*{0}.begin()).getType()", Vals&: operandName);
704 StringRef verifier = staticMatcherHelper.getVerifierName(leaf: operandMatcher);
705 emitStaticVerifierCall(
706 funcName: verifier, opName, arg: self.str(),
707 failureStr: formatv(
708 Fmt: "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
709 Vals&: operandIndex, Vals: op.getOperationName(),
710 Vals: escapeString(value: constraint.getSummary()))
711 .str());
712 }
713 }
714
715 // Capture the value
716 // `$_` is a special symbol to ignore op argument matching.
717 if (!argName.empty() && argName != "_") {
718 auto res = symbolInfoMap.findBoundSymbol(key: argName, node: tree, op, argIndex,
719 variadicSubIndex);
720 if (res == symbolInfoMap.end())
721 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}", Vals&: argName));
722
723 os << formatv(Fmt: "{0} = {1};\n", Vals: res->second.getVarName(name: argName), Vals&: operandName);
724 }
725}
726
727void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
728 StringRef opName, int argIndex,
729 int &operandIndex, int depth) {
730 constexpr int numEitherArgs = 2;
731 if (eitherArgTree.getNumArgs() != numEitherArgs)
732 PrintFatalError(ErrorLoc: loc, Msg: "`either` only supports grouping two operands");
733
734 Operator &op = tree.getDialectOp(mapper: opMap);
735
736 std::string codeBuffer;
737 llvm::raw_string_ostream tblgenOps(codeBuffer);
738
739 std::string lambda = formatv(Fmt: "eitherLambda{0}", Vals&: depth);
740 os << formatv(
741 Fmt: "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
742 Vals&: lambda);
743
744 os.indent();
745
746 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
747 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(index: i)) {
748 if (argTree.isEither())
749 PrintFatalError(ErrorLoc: loc, Msg: "either cannot be nested");
750
751 std::string argName = formatv(Fmt: "local_op_{0}", Vals&: i).str();
752
753 os << formatv(Fmt: "auto {0} = (*v{1}.begin()).getDefiningOp();\n", Vals&: argName,
754 Vals&: i);
755
756 // Indent emitMatchCheck and emitMatch because they declare local
757 // variables.
758 os << "{\n";
759 os.indent();
760
761 emitMatchCheck(
762 opName, /*matchStr=*/argName,
763 failureStr: formatv(Fmt: "\"There's no operation that defines operand {0} of {1}\"",
764 Vals: operandIndex++, Vals&: opName));
765 emitMatch(tree: argTree, name: argName, depth: depth + 1);
766
767 os.unindent() << "}\n";
768
769 // `tblgen_ops` is used to collect the matched operations. In either, we
770 // need to queue the operation only if the matching success. Thus we emit
771 // the code at the end.
772 tblgenOps << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
773 } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) {
774 emitOperandMatch(tree, opName, /*operandName=*/formatv(Fmt: "v{0}", Vals&: i).str(),
775 operandIndex,
776 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(index: i),
777 /*argName=*/eitherArgTree.getArgName(index: i), argIndex,
778 /*variadicSubIndex=*/std::nullopt);
779 ++operandIndex;
780 } else {
781 PrintFatalError(ErrorLoc: loc, Msg: "either can only be applied on operand");
782 }
783 }
784
785 os << tblgenOps.str();
786 os << "return ::mlir::success();\n";
787 os.unindent() << "};\n";
788
789 os << "{\n";
790 os.indent();
791
792 os << formatv(Fmt: "auto eitherOperand0 = {0}.getODSOperands({1});\n", Vals&: opName,
793 Vals: operandIndex - 2);
794 os << formatv(Fmt: "auto eitherOperand1 = {0}.getODSOperands({1});\n", Vals&: opName,
795 Vals: operandIndex - 1);
796
797 os << formatv(Fmt: "if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
798 "::mlir::failed({0}(eitherOperand1, "
799 "eitherOperand0)))\n",
800 Vals&: lambda);
801 os.indent() << "return ::mlir::failure();\n";
802
803 os.unindent().unindent() << "}\n";
804}
805
806void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
807 DagNode variadicArgTree,
808 StringRef opName, int argIndex,
809 int &operandIndex, int depth) {
810 Operator &op = tree.getDialectOp(mapper: opMap);
811
812 os << "{\n";
813 os.indent();
814
815 os << formatv(Fmt: "auto variadic_operand_range = {0}.getODSOperands({1});\n",
816 Vals&: opName, Vals&: operandIndex);
817 os << formatv(Fmt: "if (variadic_operand_range.size() != {0}) "
818 "return ::mlir::failure();\n",
819 Vals: variadicArgTree.getNumArgs());
820
821 StringRef variadicTreeName = variadicArgTree.getSymbol();
822 if (!variadicTreeName.empty()) {
823 auto res =
824 symbolInfoMap.findBoundSymbol(key: variadicTreeName, node: tree, op, argIndex,
825 /*variadicSubIndex=*/std::nullopt);
826 if (res == symbolInfoMap.end())
827 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "symbol not found: {0}", Vals&: variadicTreeName));
828
829 os << formatv(Fmt: "{0} = variadic_operand_range;\n",
830 Vals: res->second.getVarName(name: variadicTreeName));
831 }
832
833 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
834 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(index: i)) {
835 if (!argTree.isOperation())
836 PrintFatalError(ErrorLoc: loc, Msg: "variadic only accepts operation sub-dags");
837
838 os << "{\n";
839 os.indent();
840
841 std::string argName = formatv(Fmt: "local_op_{0}", Vals&: i).str();
842 os << formatv(Fmt: "auto *{0} = "
843 "variadic_operand_range[{1}].getDefiningOp();\n",
844 Vals&: argName, Vals&: i);
845 emitMatchCheck(
846 opName, /*matchStr=*/argName,
847 failureStr: formatv(Fmt: "\"There's no operation that defines variadic operand "
848 "{0} (variadic sub-opearnd #{1}) of {2}\"",
849 Vals&: operandIndex, Vals&: i, Vals&: opName));
850 emitMatch(tree: argTree, name: argName, depth: depth + 1);
851 os << formatv(Fmt: "tblgen_ops.push_back({0});\n", Vals&: argName);
852
853 os.unindent() << "}\n";
854 } else if (isa<NamedTypeConstraint *>(Val: op.getArg(index: argIndex))) {
855 auto operandName = formatv(Fmt: "variadic_operand_range.slice({0}, 1)", Vals&: i);
856 emitOperandMatch(tree, opName, operandName: operandName.str(), operandIndex,
857 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(index: i),
858 /*argName=*/variadicArgTree.getArgName(index: i), argIndex, variadicSubIndex: i);
859 } else {
860 PrintFatalError(ErrorLoc: loc, Msg: "variadic can only be applied on operand");
861 }
862 }
863
864 os.unindent() << "}\n";
865}
866
867void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
868 int argIndex, int depth) {
869 Operator &op = tree.getDialectOp(mapper: opMap);
870 auto *namedAttr = cast<NamedAttribute *>(Val: op.getArg(index: argIndex));
871 const auto &attr = namedAttr->attr;
872
873 os << "{\n";
874 if (op.getDialect().usePropertiesForAttributes()) {
875 os.indent() << formatv(
876 Fmt: "[[maybe_unused]] auto tblgen_attr = {0}.getProperties().{1}();\n",
877 Vals&: castedName, Vals: op.getGetterName(name: namedAttr->name));
878 } else {
879 os.indent() << formatv(Fmt: "[[maybe_unused]] auto tblgen_attr = "
880 "{0}->getAttrOfType<{1}>(\"{2}\");\n",
881 Vals&: castedName, Vals: attr.getStorageType(), Vals&: namedAttr->name);
882 }
883
884 // TODO: This should use getter method to avoid duplication.
885 if (attr.hasDefaultValue()) {
886 os << "if (!tblgen_attr) tblgen_attr = "
887 << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fmtCtx,
888 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fmtCtx)))
889 << ";\n";
890 } else if (attr.isOptional()) {
891 // For a missing attribute that is optional according to definition, we
892 // should just capture a mlir::Attribute() to signal the missing state.
893 // That is precisely what getDiscardableAttr() returns on missing
894 // attributes.
895 } else {
896 emitMatchCheck(opName: castedName, matchFmt: tgfmt(fmt: "tblgen_attr", ctx: &fmtCtx),
897 failureFmt: formatv(Fmt: "\"expected op '{0}' to have attribute '{1}' "
898 "of type '{2}'\"",
899 Vals: op.getOperationName(), Vals&: namedAttr->name,
900 Vals: attr.getStorageType()));
901 }
902
903 auto matcher = tree.getArgAsLeaf(index: argIndex);
904 if (!matcher.isUnspecified()) {
905 if (!matcher.isAttrMatcher()) {
906 PrintFatalError(
907 ErrorLoc: loc, Msg: formatv(Fmt: "the {1}-th argument of op '{0}' should be an attribute",
908 Vals: op.getOperationName(), Vals: argIndex + 1));
909 }
910
911 // If a constraint is specified, we need to generate function call to its
912 // static verifier.
913 StringRef verifier = staticMatcherHelper.getVerifierName(leaf: matcher);
914 if (attr.isOptional()) {
915 // Avoid dereferencing null attribute. This is using a simple heuristic to
916 // avoid common cases of attempting to dereference null attribute. This
917 // will return where there is no check if attribute is null unless the
918 // attribute's value is not used.
919 // FIXME: This could be improved as some null dereferences could slip
920 // through.
921 if (!StringRef(matcher.getConditionTemplate()).contains(Other: "!$_self") &&
922 StringRef(matcher.getConditionTemplate()).contains(Other: "$_self")) {
923 os << "if (!tblgen_attr) return ::mlir::failure();\n";
924 }
925 }
926 emitStaticVerifierCall(
927 funcName: verifier, opName: castedName, arg: "tblgen_attr",
928 failureStr: formatv(Fmt: "\"op '{0}' attribute '{1}' failed to satisfy constraint: "
929 "'{2}'\"",
930 Vals: op.getOperationName(), Vals&: namedAttr->name,
931 Vals: escapeString(value: matcher.getAsConstraint().getSummary()))
932 .str());
933 }
934
935 // Capture the value
936 auto name = tree.getArgName(index: argIndex);
937 // `$_` is a special symbol to ignore op argument matching.
938 if (!name.empty() && name != "_") {
939 os << formatv(Fmt: "{0} = tblgen_attr;\n", Vals&: name);
940 }
941
942 os.unindent() << "}\n";
943}
944
945void PatternEmitter::emitMatchCheck(
946 StringRef opName, const FmtObjectBase &matchFmt,
947 const llvm::formatv_object_base &failureFmt) {
948 emitMatchCheck(opName, matchStr: matchFmt.str(), failureStr: failureFmt.str());
949}
950
951void PatternEmitter::emitMatchCheck(StringRef opName,
952 const std::string &matchStr,
953 const std::string &failureStr) {
954
955 os << "if (!(" << matchStr << "))";
956 os.scope(open: "{\n", close: "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
957 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
958 << failureStr << ";\n});";
959}
960
961void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
962 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
963 int depth = 0;
964 emitMatch(tree, name: opName, depth);
965
966 for (auto &appliedConstraint : pattern.getConstraints()) {
967 auto &constraint = appliedConstraint.constraint;
968 auto &entities = appliedConstraint.entities;
969
970 auto condition = constraint.getConditionTemplate();
971 if (isa<TypeConstraint>(Val: constraint)) {
972 if (entities.size() != 1)
973 PrintFatalError(ErrorLoc: loc, Msg: "type constraint requires exactly one argument");
974
975 auto self = formatv(Fmt: "({0}.getType())",
976 Vals: symbolInfoMap.getValueAndRangeUse(symbol: entities.front()));
977 emitMatchCheck(
978 opName, matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self.str())),
979 failureFmt: formatv(Fmt: "\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
980 Vals&: entities.front(), Vals: escapeString(value: constraint.getSummary())));
981
982 } else if (isa<AttrConstraint>(Val: constraint)) {
983 PrintFatalError(
984 ErrorLoc: loc, Msg: "cannot use AttrConstraint in Pattern multi-entity constraints");
985 } else {
986 // TODO: replace formatv arguments with the exact specified
987 // args.
988 if (entities.size() > 4) {
989 PrintFatalError(ErrorLoc: loc, Msg: "only support up to 4-entity constraints now");
990 }
991 SmallVector<std::string, 4> names;
992 int i = 0;
993 for (int e = entities.size(); i < e; ++i)
994 names.push_back(Elt: symbolInfoMap.getValueAndRangeUse(symbol: entities[i]));
995 std::string self = appliedConstraint.self;
996 if (!self.empty())
997 self = symbolInfoMap.getValueAndRangeUse(symbol: self);
998 for (; i < 4; ++i)
999 names.push_back(Elt: "<unused>");
1000 emitMatchCheck(opName,
1001 matchFmt: tgfmt(fmt: condition, ctx: &fmtCtx.withSelf(subst: self), vals&: names[0],
1002 vals&: names[1], vals&: names[2], vals&: names[3]),
1003 failureFmt: formatv(Fmt: "\"entities '{0}' failed to satisfy constraint: "
1004 "'{1}'\"",
1005 Vals: llvm::join(R&: entities, Separator: ", "),
1006 Vals: escapeString(value: constraint.getSummary())));
1007 }
1008 }
1009
1010 // Some of the operands could be bound to the same symbol name, we need
1011 // to enforce equality constraint on those.
1012 // TODO: we should be able to emit equality checks early
1013 // and short circuit unnecessary work if vars are not equal.
1014 for (auto symbolInfoIt = symbolInfoMap.begin();
1015 symbolInfoIt != symbolInfoMap.end();) {
1016 auto range = symbolInfoMap.getRangeOfEqualElements(key: symbolInfoIt->first);
1017 auto startRange = range.first;
1018 auto endRange = range.second;
1019
1020 auto firstOperand = symbolInfoIt->second.getVarName(name: symbolInfoIt->first);
1021 for (++startRange; startRange != endRange; ++startRange) {
1022 auto secondOperand = startRange->second.getVarName(name: symbolInfoIt->first);
1023 emitMatchCheck(
1024 opName,
1025 matchStr: formatv(Fmt: "*{0}.begin() == *{1}.begin()", Vals&: firstOperand, Vals&: secondOperand),
1026 failureStr: formatv(Fmt: "\"Operands '{0}' and '{1}' must be equal\"", Vals&: firstOperand,
1027 Vals&: secondOperand));
1028 }
1029
1030 symbolInfoIt = endRange;
1031 }
1032
1033 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
1034}
1035
1036void PatternEmitter::collectOps(DagNode tree,
1037 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1038 // Check if this tree is an operation.
1039 if (tree.isOperation()) {
1040 const Operator &op = tree.getDialectOp(mapper: opMap);
1041 LLVM_DEBUG(llvm::dbgs()
1042 << "found operation " << op.getOperationName() << '\n');
1043 ops.insert(Ptr: &op);
1044 }
1045
1046 // Recurse the arguments of the tree.
1047 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1048 if (auto child = tree.getArgAsNestedDag(index: i))
1049 collectOps(tree: child, ops);
1050}
1051
/// Emits the complete C++ `::mlir::RewritePattern` subclass named
/// `rewriteName` for the current pattern.
///
/// The generated struct has:
///  - a constructor that registers the root op name, the pattern benefit, and
///    the name-sorted list of ops the result patterns may create;
///  - a matchAndRewrite() override that declares capture variables for all
///    symbols bound in the source pattern, emits the match logic starting from
///    `op0`, and then emits the rewrite logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  // The header comment lists the pattern's source locations (reversed, with
  // instantiation context) so generated code can be traced back to the .td.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                llvm::reverse(locs));
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name. This presumably keeps the emitted list
  // independent of the pointer-based iteration order of the SmallPtrSet.
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    auto classScope = os.scope();
    os.printReindented(R"(
    ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0,
        ::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      // `tblgen_ops` collects every matched op; the rewrite logic fuses their
      // locations into `odsLoc`.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      os << "// Match\n";
      // The root op is always the first entry of `tblgen_ops`.
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    os << "}\n";
  }
  os << "};\n\n";
}
1130
/// Emits the rewrite half of matchAndRewrite(). It materializes every result
/// pattern, then erases the matched root op (when it has no results) or
/// replaces it with the values produced by the trailing result patterns.
///
/// Result patterns are split into two groups:
///  - auxiliary patterns (indices below `replStartIndex`), emitted only for
///    their side effects;
///  - replacement patterns, whose generated values replace the root op's
///    results in order.
/// Supplemental patterns are emitted after the replacement values are
/// computed, just before the root op is erased or replaced.
void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  // offsets[i] is that starting index for result pattern i. It is computed
  // right-to-left, starting from offsets[numResultPatterns] ==
  // numExpectedResults and subtracting each pattern's value count.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      // The first (right-to-left) pattern whose values start exactly at result
      // #0 begins the replacement window.
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      // This pattern's values straddle the start of the replacement window.
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  // A positive offset for the first pattern means the result patterns produce
  // fewer values than the root op has results.
  if (offsets.front() > 0) {
    const char error[] =
        "not enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  // The default location of created ops fuses the locations of all ops
  // collected in `tblgen_ops` during matching.
  os << "auto odsLoc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)odsLoc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall() &&
        resultTree.getNumReturnsOfNativeCode() == 0)
      os << val << ";\n";
  }

  // Supplemental patterns get negative result indices
  // (-numSupplementalPatterns .. -1). Like auxiliary patterns, a void
  // NativeCodeCall is emitted as a statement for its side effect.
  auto processSupplementalPatterns = [&]() {
    int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
    for (int i = 0, offset = -numSupplementalPatterns;
         i < numSupplementalPatterns; ++i) {
      DagNode resultTree = pattern.getSupplementalPattern(i);
      auto val = handleResultPattern(resultTree, offset++, 0);
      if (resultTree.isNativeCodeCall() &&
          resultTree.getNumReturnsOfNativeCode() == 0)
        os << val << ";\n";
    }
  };

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    processSupplementalPatterns();
    // No result to replace. Just erase the op.
    os << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      // We need an explicit cast to `SmallVector` to capture the cases where
      // `{0}` resolves to an `Operation::result_range` as well as cases that
      // are not iterable (e.g. vector that gets wrapped in additional braces by
      // RewriterGen).
      // TODO: Revisit the need for materializing a vector.
      os << symbolInfoMap.getAllRangeUse(
          val,
          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
          "  tblgen_repl_values.push_back(v);\n}\n",
          "\n");
    }
    processSupplementalPatterns();
    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
1231
1232std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1233 return std::string(
1234 formatv(Fmt: "tblgen_{0}_{1}", Vals: op->getCppClassName(), Vals: nextValueId++));
1235}
1236
1237std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1238 int resultIndex, int depth) {
1239 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1240 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1241 LLVM_DEBUG(llvm::dbgs() << '\n');
1242
1243 if (resultTree.isLocationDirective()) {
1244 PrintFatalError(ErrorLoc: loc,
1245 Msg: "location directive can only be used with op creation");
1246 }
1247
1248 if (resultTree.isNativeCodeCall())
1249 return handleReplaceWithNativeCodeCall(resultTree, depth);
1250
1251 if (resultTree.isReplaceWithValue())
1252 return handleReplaceWithValue(tree: resultTree).str();
1253
1254 if (resultTree.isVariadic())
1255 return handleVariadic(tree: resultTree, depth);
1256
1257 // Normal op creation.
1258 auto symbol = handleOpCreation(tree: resultTree, resultIndex, depth);
1259 if (resultTree.getSymbol().empty()) {
1260 // This is an op not explicitly bound to a symbol in the rewrite rule.
1261 // Register the auto-generated symbol for it.
1262 symbolInfoMap.bindOpResult(symbol, op: pattern.getDialectOp(node: resultTree));
1263 }
1264 return symbol;
1265}
1266
1267std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
1268 assert(tree.isVariadic());
1269
1270 std::string output;
1271 llvm::raw_string_ostream oss(output);
1272 auto name = std::string(formatv(Fmt: "tblgen_variadic_values_{0}", Vals: nextValueId++));
1273 symbolInfoMap.bindValue(symbol: name);
1274 oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
1275 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1276 if (auto child = tree.getArgAsNestedDag(index: i)) {
1277 oss << name << ".push_back(" << handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1)
1278 << ");\n";
1279 } else {
1280 oss << name << ".push_back("
1281 << handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i))
1282 << ");\n";
1283 }
1284 }
1285
1286 os << oss.str();
1287 return name;
1288}
1289
1290StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1291 assert(tree.isReplaceWithValue());
1292
1293 if (tree.getNumArgs() != 1) {
1294 PrintFatalError(
1295 ErrorLoc: loc, Msg: "replaceWithValue directive must take exactly one argument");
1296 }
1297
1298 if (!tree.getSymbol().empty()) {
1299 PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to replaceWithValue");
1300 }
1301
1302 return tree.getArgName(index: 0);
1303}
1304
1305std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1306 assert(tree.isLocationDirective());
1307 auto lookUpArgLoc = [this, &tree](int idx) {
1308 const auto *const lookupFmt = "{0}.getLoc()";
1309 return symbolInfoMap.getValueAndRangeUse(symbol: tree.getArgName(index: idx), fmt: lookupFmt);
1310 };
1311
1312 if (tree.getNumArgs() == 0)
1313 llvm::PrintFatalError(
1314 Msg: "At least one argument to location directive required");
1315
1316 if (!tree.getSymbol().empty())
1317 PrintFatalError(ErrorLoc: loc, Msg: "cannot bind symbol to location");
1318
1319 if (tree.getNumArgs() == 1) {
1320 DagLeaf leaf = tree.getArgAsLeaf(index: 0);
1321 if (leaf.isStringAttr())
1322 return formatv(Fmt: "::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1323 Vals: leaf.getStringAttr())
1324 .str();
1325 return lookUpArgLoc(0);
1326 }
1327
1328 std::string ret;
1329 llvm::raw_string_ostream os(ret);
1330 std::string strAttr;
1331 os << "rewriter.getFusedLoc({";
1332 bool first = true;
1333 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1334 DagLeaf leaf = tree.getArgAsLeaf(index: i);
1335 // Handle the optional string value.
1336 if (leaf.isStringAttr()) {
1337 if (!strAttr.empty())
1338 llvm::PrintFatalError(Msg: "Only one string attribute may be specified");
1339 strAttr = leaf.getStringAttr();
1340 continue;
1341 }
1342 os << (first ? "" : ", ") << lookUpArgLoc(i);
1343 first = false;
1344 }
1345 os << "}";
1346 if (!strAttr.empty()) {
1347 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1348 }
1349 os << ")";
1350 return os.str();
1351}
1352
1353std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1354 int depth) {
1355 // Nested NativeCodeCall.
1356 if (auto dagNode = returnType.getArgAsNestedDag(index: i)) {
1357 if (!dagNode.isNativeCodeCall())
1358 PrintFatalError(ErrorLoc: loc, Msg: "nested DAG in `returnType` must be a native code "
1359 "call");
1360 return handleReplaceWithNativeCodeCall(resultTree: dagNode, depth);
1361 }
1362 // String literal.
1363 auto dagLeaf = returnType.getArgAsLeaf(index: i);
1364 if (dagLeaf.isStringAttr())
1365 return tgfmt(fmt: dagLeaf.getStringAttr(), ctx: &fmtCtx);
1366 return tgfmt(
1367 fmt: "$0.getType()", ctx: &fmtCtx,
1368 vals: handleOpArgument(leaf: returnType.getArgAsLeaf(index: i), patArgName: returnType.getArgName(index: i)));
1369}
1370
1371std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1372 StringRef patArgName) {
1373 if (leaf.isStringAttr())
1374 PrintFatalError(ErrorLoc: loc, Msg: "raw string not supported as argument");
1375 if (leaf.isConstantAttr()) {
1376 auto constAttr = leaf.getAsConstantAttr();
1377 return handleConstantAttr(attr: constAttr.getAttribute(),
1378 value: constAttr.getConstantValue());
1379 }
1380 if (leaf.isEnumCase()) {
1381 auto enumCase = leaf.getAsEnumCase();
1382 // This is an enum case backed by an IntegerAttr. We need to get its value
1383 // to build the constant.
1384 std::string val = std::to_string(val: enumCase.getValue());
1385 return handleConstantAttr(attr: Attribute(&enumCase.getDef()), value: val);
1386 }
1387
1388 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1389 auto argName = symbolInfoMap.getValueAndRangeUse(symbol: patArgName);
1390 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1391 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1392 << "' (via symbol ref)\n");
1393 return argName;
1394 }
1395 if (leaf.isNativeCodeCall()) {
1396 auto repl = tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: argName));
1397 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1398 << "' (via NativeCodeCall)\n");
1399 return std::string(repl);
1400 }
1401 PrintFatalError(ErrorLoc: loc, Msg: "unhandled case when rewriting op");
1402}
1403
1404std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1405 int depth) {
1406 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1407 LLVM_DEBUG(tree.print(llvm::dbgs()));
1408 LLVM_DEBUG(llvm::dbgs() << '\n');
1409
1410 auto fmt = tree.getNativeCodeTemplate();
1411
1412 SmallVector<std::string, 16> attrs;
1413
1414 auto tail = getTrailingDirectives(tree);
1415 if (tail.returnType)
1416 PrintFatalError(ErrorLoc: loc, Msg: "`NativeCodeCall` cannot have return type specifier");
1417 auto locToUse = getLocation(tail);
1418
1419 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1420 if (tree.isNestedDagArg(index: i)) {
1421 attrs.push_back(
1422 Elt: handleResultPattern(resultTree: tree.getArgAsNestedDag(index: i), resultIndex: i, depth: depth + 1));
1423 } else {
1424 attrs.push_back(
1425 Elt: handleOpArgument(leaf: tree.getArgAsLeaf(index: i), patArgName: tree.getArgName(index: i)));
1426 }
1427 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1428 << " replacement: " << attrs[i] << "\n");
1429 }
1430
1431 std::string symbol = tgfmt(fmt, ctx: &fmtCtx.addSubst(placeholder: "_loc", subst: locToUse),
1432 params: static_cast<ArrayRef<std::string>>(attrs));
1433
1434 // In general, NativeCodeCall without naming binding don't need this. To
1435 // ensure void helper function has been correctly labeled, i.e., use
1436 // NativeCodeCallVoid, we cache the result to a local variable so that we will
1437 // get a compilation error in the auto-generated file.
1438 // Example.
1439 // // In the td file
1440 // Pat<(...), (NativeCodeCall<Foo> ...)>
1441 //
1442 // ---
1443 //
1444 // // In the auto-generated .cpp
1445 // ...
1446 // // Causes compilation error if Foo() returns void.
1447 // auto nativeVar = Foo();
1448 // ...
1449 if (tree.getNumReturnsOfNativeCode() != 0) {
1450 // Determine the local variable name for return value.
1451 std::string varName =
1452 SymbolInfoMap::getValuePackName(symbol: tree.getSymbol()).str();
1453 if (varName.empty()) {
1454 varName = formatv(Fmt: "nativeVar_{0}", Vals: nextValueId++);
1455 // Register the local variable for later uses.
1456 symbolInfoMap.bindValues(symbol: varName, numValues: tree.getNumReturnsOfNativeCode());
1457 }
1458
1459 // Catch the return value of helper function.
1460 os << formatv(Fmt: "auto {0} = {1}; (void){0};\n", Vals&: varName, Vals&: symbol);
1461
1462 if (!tree.getSymbol().empty())
1463 symbol = tree.getSymbol().str();
1464 else
1465 symbol = varName;
1466 }
1467
1468 return symbol;
1469}
1470
1471int PatternEmitter::getNodeValueCount(DagNode node) {
1472 if (node.isOperation()) {
1473 // If the op is bound to a symbol in the rewrite rule, query its result
1474 // count from the symbol info map.
1475 auto symbol = node.getSymbol();
1476 if (!symbol.empty()) {
1477 return symbolInfoMap.getStaticValueCount(symbol);
1478 }
1479 // Otherwise this is an unbound op; we will use all its results.
1480 return pattern.getDialectOp(node).getNumResults();
1481 }
1482
1483 if (node.isNativeCodeCall())
1484 return node.getNumReturnsOfNativeCode();
1485
1486 return 1;
1487}
1488
1489PatternEmitter::TrailingDirectives
1490PatternEmitter::getTrailingDirectives(DagNode tree) {
1491 TrailingDirectives tail = {.location: DagNode(nullptr), .returnType: DagNode(nullptr), .numDirectives: 0};
1492
1493 // Look backwards through the arguments.
1494 auto numPatArgs = tree.getNumArgs();
1495 for (int i = numPatArgs - 1; i >= 0; --i) {
1496 auto dagArg = tree.getArgAsNestedDag(index: i);
1497 // A leaf is not a directive. Stop looking.
1498 if (!dagArg)
1499 break;
1500
1501 auto isLocation = dagArg.isLocationDirective();
1502 auto isReturnType = dagArg.isReturnTypeDirective();
1503 // If encountered a DAG node that isn't a trailing directive, stop looking.
1504 if (!(isLocation || isReturnType))
1505 break;
1506 // Save the directive, but error if one of the same type was already
1507 // found.
1508 ++tail.numDirectives;
1509 if (isLocation) {
1510 if (tail.location)
1511 PrintFatalError(ErrorLoc: loc, Msg: "`location` directive can only be specified "
1512 "once");
1513 tail.location = dagArg;
1514 } else if (isReturnType) {
1515 if (tail.returnType)
1516 PrintFatalError(ErrorLoc: loc, Msg: "`returnType` directive can only be specified "
1517 "once");
1518 tail.returnType = dagArg;
1519 }
1520 }
1521
1522 return tail;
1523}
1524
1525std::string
1526PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1527 if (tail.location)
1528 return handleLocationDirective(tree: tail.location);
1529
1530 // If no explicit location is given, use the default, all fused, location.
1531 return "odsLoc";
1532}
1533
1534std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1535 int depth) {
1536 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1537 LLVM_DEBUG(tree.print(llvm::dbgs()));
1538 LLVM_DEBUG(llvm::dbgs() << '\n');
1539
1540 Operator &resultOp = tree.getDialectOp(mapper: opMap);
1541 bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
1542 auto numOpArgs = resultOp.getNumArgs();
1543 auto numPatArgs = tree.getNumArgs();
1544
1545 auto tail = getTrailingDirectives(tree);
1546 auto locToUse = getLocation(tail);
1547
1548 auto inPattern = numPatArgs - tail.numDirectives;
1549 if (numOpArgs != inPattern) {
1550 PrintFatalError(ErrorLoc: loc,
1551 Msg: formatv(Fmt: "resultant op '{0}' argument number mismatch: "
1552 "{1} in pattern vs. {2} in definition",
1553 Vals: resultOp.getOperationName(), Vals&: inPattern, Vals&: numOpArgs));
1554 }
1555
1556 // A map to collect all nested DAG child nodes' names, with operand index as
1557 // the key. This includes both bound and unbound child nodes.
1558 ChildNodeIndexNameMap childNodeNames;
1559
1560 // If the argument is a type constraint, then its an operand. Check if the
1561 // op's argument is variadic that the argument in the pattern is too.
1562 auto checkIfMatchedVariadic = [&](int i) {
1563 // FIXME: This does not yet check for variable/leaf case.
1564 // FIXME: Change so that native code call can be handled.
1565 const auto *operand =
1566 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: i));
1567 if (!operand || !operand->isVariadic())
1568 return;
1569
1570 auto child = tree.getArgAsNestedDag(index: i);
1571 if (!child)
1572 return;
1573
1574 // Skip over replaceWithValues.
1575 while (child.isReplaceWithValue()) {
1576 if (!(child = child.getArgAsNestedDag(index: 0)))
1577 return;
1578 }
1579 if (!child.isNativeCodeCall() && !child.isVariadic())
1580 PrintFatalError(ErrorLoc: loc, Msg: formatv(Fmt: "op expects variadic operand `{0}`, while "
1581 "provided is non-variadic",
1582 Vals: resultOp.getArgName(index: i)));
1583 };
1584
1585 // First go through all the child nodes who are nested DAG constructs to
1586 // create ops for them and remember the symbol names for them, so that we can
1587 // use the results in the current node. This happens in a recursive manner.
1588 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1589 checkIfMatchedVariadic(i);
1590 if (auto child = tree.getArgAsNestedDag(index: i))
1591 childNodeNames[i] = handleResultPattern(resultTree: child, resultIndex: i, depth: depth + 1);
1592 }
1593
1594 // The name of the local variable holding this op.
1595 std::string valuePackName;
1596 // The symbol for holding the result of this pattern. Note that the result of
1597 // this pattern is not necessarily the same as the variable created by this
1598 // pattern because we can use `__N` suffix to refer only a specific result if
1599 // the generated op is a multi-result op.
1600 std::string resultValue;
1601 if (tree.getSymbol().empty()) {
1602 // No symbol is explicitly bound to this op in the pattern. Generate a
1603 // unique name.
1604 valuePackName = resultValue = getUniqueSymbol(op: &resultOp);
1605 } else {
1606 resultValue = std::string(tree.getSymbol());
1607 // Strip the index to get the name for the value pack and use it to name the
1608 // local variable for the op.
1609 valuePackName = std::string(SymbolInfoMap::getValuePackName(symbol: resultValue));
1610 }
1611
1612 // Create the local variable for this op.
1613 os << formatv(Fmt: "{0} {1};\n{{\n", Vals: resultOp.getQualCppClassName(),
1614 Vals&: valuePackName);
1615
1616 // Right now ODS don't have general type inference support. Except a few
1617 // special cases listed below, DRR needs to supply types for all results
1618 // when building an op.
1619 bool isSameOperandsAndResultType =
1620 resultOp.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType");
1621 bool useFirstAttr =
1622 resultOp.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType");
1623
1624 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1625 // We know how to deduce the result type for ops with these traits and we've
1626 // generated builders taking aggregate parameters. Use those builders to
1627 // create the ops.
1628
1629 // First prepare local variables for op arguments used in builder call.
1630 createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);
1631
1632 // Then create the op.
1633 os.scope(open: "", close: "\n}\n").os
1634 << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
1635 Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse,
1636 Vals: useProperties ? "tblgen_props" : "tblgen_attrs");
1637 return resultValue;
1638 }
1639
1640 bool usePartialResults = valuePackName != resultValue;
1641
1642 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1643 // For these cases (broadcastable ops, op results used both as auxiliary
1644 // values and replacement values, ops in nested patterns, auxiliary ops), we
1645 // still need to supply the result types when building the op. But because
1646 // we don't generate a builder automatically with ODS for them, it's the
1647 // developer's responsibility to make sure such a builder (with result type
1648 // deduction ability) exists. We go through the separate-parameter builder
1649 // here given that it's easier for developers to write compared to
1650 // aggregate-parameter builders.
1651 createSeparateLocalVarsForOpArgs(node: tree, childNodeNames);
1652
1653 os.scope().os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}", Vals&: valuePackName,
1654 Vals: resultOp.getQualCppClassName(), Vals&: locToUse);
1655 supplyValuesForOpArgs(node: tree, childNodeNames, depth);
1656 os << "\n );\n}\n";
1657 return resultValue;
1658 }
1659
1660 // If we are provided explicit return types, use them to build the op.
1661 // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1662 // the values generated from the source pattern root op. Then we must use the
1663 // source pattern's value types to determine the value type of the generated
1664 // op here.
1665 if (depth == 0 && resultIndex >= 0 && tail.returnType)
1666 PrintFatalError(ErrorLoc: loc, Msg: "Cannot specify explicit return types in an op whose "
1667 "return values replace the source pattern's root op");
1668
1669 // First prepare local variables for op arguments used in builder call.
1670 createAggregateLocalVarsForOpArgs(node: tree, childNodeNames, depth);
1671
1672 // Then prepare the result types. We need to specify the types for all
1673 // results.
1674 os.indent() << formatv(Fmt: "::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1675 "(void)tblgen_types;\n");
1676 int numResults = resultOp.getNumResults();
1677 if (tail.returnType) {
1678 auto numRetTys = tail.returnType.getNumArgs();
1679 for (int i = 0; i < numRetTys; ++i) {
1680 auto varName = handleReturnTypeArg(returnType: tail.returnType, i, depth: depth + 1);
1681 os << "tblgen_types.push_back(" << varName << ");\n";
1682 }
1683 } else {
1684 if (numResults != 0) {
1685 // Copy the result types from the source pattern.
1686 for (int i = 0; i < numResults; ++i)
1687 os << formatv(Fmt: "for (auto v: castedOp0.getODSResults({0})) {{\n"
1688 " tblgen_types.push_back(v.getType());\n}\n",
1689 Vals: resultIndex + i);
1690 }
1691 }
1692 os << formatv(Fmt: "{0} = rewriter.create<{1}>({2}, tblgen_types, "
1693 "tblgen_values, {3});\n",
1694 Vals&: valuePackName, Vals: resultOp.getQualCppClassName(), Vals&: locToUse,
1695 Vals: useProperties ? "tblgen_props" : "tblgen_attrs");
1696 os.unindent() << "}\n";
1697 return resultValue;
1698}
1699
1700void PatternEmitter::createSeparateLocalVarsForOpArgs(
1701 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1702 Operator &resultOp = node.getDialectOp(mapper: opMap);
1703
1704 // Now prepare operands used for building this op:
1705 // * If the operand is non-variadic, we create a `Value` local variable.
1706 // * If the operand is variadic, we create a `SmallVector<Value>` local
1707 // variable.
1708
1709 int valueIndex = 0; // An index for uniquing local variable names.
1710 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1711 const auto *operand =
1712 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex));
1713 // We do not need special handling for attributes.
1714 if (!operand)
1715 continue;
1716
1717 raw_indented_ostream::DelimitedScope scope(os);
1718 std::string varName;
1719 if (operand->isVariadic()) {
1720 varName = std::string(formatv(Fmt: "tblgen_values_{0}", Vals: valueIndex++));
1721 os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> {0};\n", Vals&: varName);
1722 std::string range;
1723 if (node.isNestedDagArg(index: argIndex)) {
1724 range = childNodeNames[argIndex];
1725 } else {
1726 range = std::string(node.getArgName(index: argIndex));
1727 }
1728 // Resolve the symbol for all range use so that we have a uniform way of
1729 // capturing the values.
1730 range = symbolInfoMap.getValueAndRangeUse(symbol: range);
1731 os << formatv(Fmt: "for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", Vals&: range,
1732 Vals&: varName);
1733 } else {
1734 varName = std::string(formatv(Fmt: "tblgen_value_{0}", Vals: valueIndex++));
1735 os << formatv(Fmt: "::mlir::Value {0} = ", Vals&: varName);
1736 if (node.isNestedDagArg(index: argIndex)) {
1737 os << symbolInfoMap.getValueAndRangeUse(symbol: childNodeNames[argIndex]);
1738 } else {
1739 DagLeaf leaf = node.getArgAsLeaf(index: argIndex);
1740 auto symbol =
1741 symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex));
1742 if (leaf.isNativeCodeCall()) {
1743 os << std::string(
1744 tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol)));
1745 } else {
1746 os << symbol;
1747 }
1748 }
1749 os << ";\n";
1750 }
1751
1752 // Update to use the newly created local variable for building the op later.
1753 childNodeNames[argIndex] = varName;
1754 }
1755}
1756
1757void PatternEmitter::supplyValuesForOpArgs(
1758 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1759 Operator &resultOp = node.getDialectOp(mapper: opMap);
1760 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1761 argIndex != numOpArgs; ++argIndex) {
1762 // Start each argument on its own line.
1763 os << ",\n ";
1764
1765 Argument opArg = resultOp.getArg(index: argIndex);
1766 // Handle the case of operand first.
1767 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: opArg)) {
1768 if (!operand->name.empty())
1769 os << "/*" << operand->name << "=*/";
1770 os << childNodeNames.lookup(Val: argIndex);
1771 continue;
1772 }
1773
1774 // The argument in the op definition.
1775 auto opArgName = resultOp.getArgName(index: argIndex);
1776 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1777 if (!subTree.isNativeCodeCall())
1778 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1779 "for creating attribute");
1780 os << formatv(Fmt: "/*{0}=*/{1}", Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex));
1781 } else {
1782 auto leaf = node.getArgAsLeaf(index: argIndex);
1783 // The argument in the result DAG pattern.
1784 auto patArgName = node.getArgName(index: argIndex);
1785 if (leaf.isConstantAttr() || leaf.isEnumCase()) {
1786 // TODO: Refactor out into map to avoid recomputing these.
1787 if (!isa<NamedAttribute *>(Val: opArg))
1788 PrintFatalError(ErrorLoc: loc, Msg: Twine("expected attribute ") + Twine(argIndex));
1789 if (!patArgName.empty())
1790 os << "/*" << patArgName << "=*/";
1791 } else {
1792 os << "/*" << opArgName << "=*/";
1793 }
1794 os << handleOpArgument(leaf, patArgName);
1795 }
1796 }
1797}
1798
1799void PatternEmitter::createAggregateLocalVarsForOpArgs(
1800 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1801 Operator &resultOp = node.getDialectOp(mapper: opMap);
1802
1803 bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
1804 auto scope = os.scope();
1805 os << formatv(Fmt: "::llvm::SmallVector<::mlir::Value, 4> "
1806 "tblgen_values; (void)tblgen_values;\n");
1807 if (useProperties) {
1808 os << formatv(Fmt: "{0}::Properties tblgen_props; (void)tblgen_props;\n",
1809 Vals: resultOp.getQualCppClassName());
1810 } else {
1811 os << formatv(Fmt: "::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1812 "tblgen_attrs; (void)tblgen_attrs;\n");
1813 }
1814
1815 const char *setPropCmd =
1816 "tblgen_props.{0} = "
1817 "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
1818 const char *addAttrCmd =
1819 "if (auto tmpAttr = {1}) {\n"
1820 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1821 "tmpAttr);\n}\n";
1822 const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
1823
1824 int numVariadic = 0;
1825 bool hasOperandSegmentSizes = false;
1826 std::vector<std::string> sizes;
1827 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1828 if (isa<NamedAttribute *>(Val: resultOp.getArg(index: argIndex))) {
1829 // The argument in the op definition.
1830 auto opArgName = resultOp.getArgName(index: argIndex);
1831 hasOperandSegmentSizes =
1832 hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
1833 if (auto subTree = node.getArgAsNestedDag(index: argIndex)) {
1834 if (!subTree.isNativeCodeCall())
1835 PrintFatalError(ErrorLoc: loc, Msg: "only NativeCodeCall allowed in nested dag node "
1836 "for creating attribute");
1837
1838 os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: childNodeNames.lookup(Val: argIndex));
1839 } else {
1840 auto leaf = node.getArgAsLeaf(index: argIndex);
1841 // The argument in the result DAG pattern.
1842 auto patArgName = node.getArgName(index: argIndex);
1843 os << formatv(Fmt: setterCmd, Vals&: opArgName, Vals: handleOpArgument(leaf, patArgName));
1844 }
1845 continue;
1846 }
1847
1848 const auto *operand =
1849 cast<NamedTypeConstraint *>(Val: resultOp.getArg(index: argIndex));
1850 if (operand->isVariadic()) {
1851 ++numVariadic;
1852 std::string range;
1853 if (node.isNestedDagArg(index: argIndex)) {
1854 range = childNodeNames.lookup(Val: argIndex);
1855 } else {
1856 range = std::string(node.getArgName(index: argIndex));
1857 }
1858 // Resolve the symbol for all range use so that we have a uniform way of
1859 // capturing the values.
1860 range = symbolInfoMap.getValueAndRangeUse(symbol: range);
1861 os << formatv(Fmt: "for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1862 Vals&: range);
1863 sizes.push_back(x: formatv(Fmt: "static_cast<int32_t>({0}.size())", Vals&: range));
1864 } else {
1865 sizes.emplace_back(args: "1");
1866 os << formatv(Fmt: "tblgen_values.push_back(");
1867 if (node.isNestedDagArg(index: argIndex)) {
1868 os << symbolInfoMap.getValueAndRangeUse(
1869 symbol: childNodeNames.lookup(Val: argIndex));
1870 } else {
1871 DagLeaf leaf = node.getArgAsLeaf(index: argIndex);
1872 if (leaf.isConstantAttr())
1873 // TODO: Use better location
1874 PrintFatalError(
1875 ErrorLoc: loc,
1876 Msg: "attribute found where value was expected, if attempting to use "
1877 "constant value, construct a constant op with given attribute "
1878 "instead");
1879
1880 auto symbol =
1881 symbolInfoMap.getValueAndRangeUse(symbol: node.getArgName(index: argIndex));
1882 if (leaf.isNativeCodeCall()) {
1883 os << std::string(
1884 tgfmt(fmt: leaf.getNativeCodeTemplate(), ctx: &fmtCtx.withSelf(subst: symbol)));
1885 } else {
1886 os << symbol;
1887 }
1888 }
1889 os << ");\n";
1890 }
1891 }
1892
1893 if (numVariadic > 1 && !hasOperandSegmentSizes) {
1894 // Only set size if it can't be computed.
1895 const auto *sameVariadicSize =
1896 resultOp.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize");
1897 if (!sameVariadicSize) {
1898 if (useProperties) {
1899 const char *setSizes = R"(
1900 tblgen_props.operandSegmentSizes = {{ {0} };
1901 )";
1902 os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", ")).str());
1903 } else {
1904 const char *setSizes = R"(
1905 tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1906 rewriter.getDenseI32ArrayAttr({{ {0} }));
1907 )";
1908 os.printReindented(str: formatv(Fmt: setSizes, Vals: llvm::join(R&: sizes, Separator: ", ")).str());
1909 }
1910 }
1911 }
1912}
1913
1914StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1915 const RecordKeeper &records,
1916 RecordOperatorMap &mapper)
1917 : opMap(mapper), staticVerifierEmitter(os, records) {}
1918
1919void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1920 // PatternEmitter will use the static matcher if there's one generated. To
1921 // ensure that all the dependent static matchers are generated before emitting
1922 // the matching logic of the DagNode, we use topological order to achieve it.
1923 for (auto &dagInfo : topologicalOrder) {
1924 DagNode node = dagInfo.first;
1925 if (!useStaticMatcher(node))
1926 continue;
1927
1928 std::string funcName =
1929 formatv(Fmt: "static_dag_matcher_{0}", Vals: staticMatcherCounter++);
1930 assert(!matcherNames.contains(node));
1931 PatternEmitter(dagInfo.second, &opMap, os, *this)
1932 .emitStaticMatcher(tree: node, funcName);
1933 matcherNames[node] = funcName;
1934 }
1935}
1936
1937void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1938 staticVerifierEmitter.emitPatternConstraints(constraints: constraints.getArrayRef());
1939}
1940
1941void StaticMatcherHelper::addPattern(const Record *record) {
1942 Pattern pat(record, &opMap);
1943
1944 // While generating the function body of the DAG matcher, it may depends on
1945 // other DAG matchers. To ensure the dependent matchers are ready, we compute
1946 // the topological order for all the DAGs and emit the DAG matchers in this
1947 // order.
1948 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1949 ++refStats[node];
1950
1951 if (refStats[node] != 1)
1952 return;
1953
1954 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1955 if (DagNode sibling = node.getArgAsNestedDag(index: i))
1956 dfs(sibling);
1957 else {
1958 DagLeaf leaf = node.getArgAsLeaf(index: i);
1959 if (!leaf.isUnspecified())
1960 constraints.insert(X: leaf);
1961 }
1962
1963 topologicalOrder.push_back(Elt: std::make_pair(x&: node, y&: record));
1964 };
1965
1966 dfs(pat.getSourcePattern());
1967}
1968
1969StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1970 if (leaf.isAttrMatcher()) {
1971 std::optional<StringRef> constraint =
1972 staticVerifierEmitter.getAttrConstraintFn(constraint: leaf.getAsConstraint());
1973 assert(constraint && "attribute constraint was not uniqued");
1974 return *constraint;
1975 }
1976 assert(leaf.isOperandMatcher());
1977 return staticVerifierEmitter.getTypeConstraintFn(constraint: leaf.getAsConstraint());
1978}
1979
1980static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
1981 emitSourceFileHeader(Desc: "Rewriters", OS&: os, Record: records);
1982
1983 auto patterns = records.getAllDerivedDefinitions(ClassName: "Pattern");
1984
1985 // We put the map here because it can be shared among multiple patterns.
1986 RecordOperatorMap recordOpMap;
1987
1988 // Exam all the patterns and generate static matcher for the duplicated
1989 // DagNode.
1990 StaticMatcherHelper staticMatcher(os, records, recordOpMap);
1991 for (const Record *p : patterns)
1992 staticMatcher.addPattern(record: p);
1993 staticMatcher.populateStaticConstraintFunctions(os);
1994 staticMatcher.populateStaticMatchers(os);
1995
1996 std::vector<std::string> rewriterNames;
1997 rewriterNames.reserve(n: patterns.size());
1998
1999 std::string baseRewriterName = "GeneratedConvert";
2000 int rewriterIndex = 0;
2001
2002 for (const Record *p : patterns) {
2003 std::string name;
2004 if (p->isAnonymous()) {
2005 // If no name is provided, ensure unique rewriter names simply by
2006 // appending unique suffix.
2007 name = baseRewriterName + llvm::utostr(X: rewriterIndex++);
2008 } else {
2009 name = std::string(p->getName());
2010 }
2011 LLVM_DEBUG(llvm::dbgs()
2012 << "=== start generating pattern '" << name << "' ===\n");
2013 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(rewriteName: name);
2014 LLVM_DEBUG(llvm::dbgs()
2015 << "=== done generating pattern '" << name << "' ===\n");
2016 rewriterNames.push_back(x: std::move(name));
2017 }
2018
2019 // Emit function to add the generated matchers to the pattern list.
2020 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
2021 "::mlir::RewritePatternSet &patterns) {\n";
2022 for (const auto &name : rewriterNames) {
2023 os << " patterns.add<" << name << ">(patterns.getContext());\n";
2024 }
2025 os << "}\n";
2026}
2027
2028static mlir::GenRegistration
2029 genRewriters("gen-rewriters", "Generate pattern rewriters",
2030 [](const RecordKeeper &records, raw_ostream &os) {
2031 emitRewriters(records, os);
2032 return false;
2033 });
2034

source code of mlir/tools/mlir-tblgen/RewriterGen.cpp