1//===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PredicateTree.h"
10#include "RootOrdering.h"
11
12#include "mlir/Dialect/PDL/IR/PDL.h"
13#include "mlir/Dialect/PDL/IR/PDLTypes.h"
14#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15#include "mlir/IR/BuiltinOps.h"
16#include "mlir/Interfaces/InferTypeOpInterface.h"
17#include "llvm/ADT/MapVector.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Debug.h"
21#include <queue>
22
23#define DEBUG_TYPE "pdl-predicate-tree"
24
25using namespace mlir;
26using namespace mlir::pdl_to_pdl_interp;
27
28//===----------------------------------------------------------------------===//
29// Predicate List Building
30//===----------------------------------------------------------------------===//
31
32static void getTreePredicates(std::vector<PositionalPredicate> &predList,
33 Value val, PredicateBuilder &builder,
34 DenseMap<Value, Position *> &inputs,
35 Position *pos);
36
37/// Compares the depths of two positions.
38static bool comparePosDepth(Position *lhs, Position *rhs) {
39 return lhs->getOperationDepth() < rhs->getOperationDepth();
40}
41
42/// Returns the number of non-range elements within `values`.
43static unsigned getNumNonRangeValues(ValueRange values) {
44 return llvm::count_if(Range: values.getTypes(),
45 P: [](Type type) { return !isa<pdl::RangeType>(type); });
46}
47
48static void getTreePredicates(std::vector<PositionalPredicate> &predList,
49 Value val, PredicateBuilder &builder,
50 DenseMap<Value, Position *> &inputs,
51 AttributePosition *pos) {
52 assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
53 predList.emplace_back(args&: pos, args: builder.getIsNotNull());
54
55 if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
56 // If the attribute has a type or value, add a constraint.
57 if (Value type = attr.getValueType())
58 getTreePredicates(predList, val: type, builder, inputs, pos: builder.getType(pos));
59 else if (Attribute value = attr.getValueAttr())
60 predList.emplace_back(args&: pos, args: builder.getAttributeConstraint(attr: value));
61 }
62}
63
64/// Collect all of the predicates for the given operand position.
65static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
66 Value val, PredicateBuilder &builder,
67 DenseMap<Value, Position *> &inputs,
68 Position *pos) {
69 Type valueType = val.getType();
70 bool isVariadic = isa<pdl::RangeType>(valueType);
71
72 // If this is a typed operand, add a type constraint.
73 TypeSwitch<Operation *>(val.getDefiningOp())
74 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
75 // Prevent traversal into a null value if the operand has a proper
76 // index.
77 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
78 cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
79 predList.emplace_back(pos, builder.getIsNotNull());
80
81 if (Value type = op.getValueType())
82 getTreePredicates(predList, type, builder, inputs,
83 builder.getType(pos));
84 })
85 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
86 std::optional<unsigned> index = op.getIndex();
87
88 // Prevent traversal into a null value if the result has a proper index.
89 if (index)
90 predList.emplace_back(pos, builder.getIsNotNull());
91
92 // Get the parent operation of this operand.
93 OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
94 predList.emplace_back(parentPos, builder.getIsNotNull());
95
96 // Ensure that the operands match the corresponding results of the
97 // parent operation.
98 Position *resultPos = nullptr;
99 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
100 resultPos = builder.getResult(parentPos, *index);
101 else
102 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
103 predList.emplace_back(resultPos, builder.getEqualTo(pos));
104
105 // Collect the predicates of the parent operation.
106 getTreePredicates(predList, op.getParent(), builder, inputs,
107 (Position *)parentPos);
108 });
109}
110
111static void
112getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
113 PredicateBuilder &builder,
114 DenseMap<Value, Position *> &inputs, OperationPosition *pos,
115 std::optional<unsigned> ignoreOperand = std::nullopt) {
116 assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
117 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
118 OperationPosition *opPos = cast<OperationPosition>(Val: pos);
119
120 // Ensure getDefiningOp returns a non-null operation.
121 if (!opPos->isRoot())
122 predList.emplace_back(args&: pos, args: builder.getIsNotNull());
123
124 // Check that this is the correct root operation.
125 if (std::optional<StringRef> opName = op.getOpName())
126 predList.emplace_back(args&: pos, args: builder.getOperationName(name: *opName));
127
128 // Check that the operation has the proper number of operands. If there are
129 // any variable length operands, we check a minimum instead of an exact count.
130 OperandRange operands = op.getOperandValues();
131 unsigned minOperands = getNumNonRangeValues(values: operands);
132 if (minOperands != operands.size()) {
133 if (minOperands)
134 predList.emplace_back(args&: pos, args: builder.getOperandCountAtLeast(count: minOperands));
135 } else {
136 predList.emplace_back(args&: pos, args: builder.getOperandCount(count: minOperands));
137 }
138
139 // Check that the operation has the proper number of results. If there are
140 // any variable length results, we check a minimum instead of an exact count.
141 OperandRange types = op.getTypeValues();
142 unsigned minResults = getNumNonRangeValues(values: types);
143 if (minResults == types.size())
144 predList.emplace_back(args&: pos, args: builder.getResultCount(count: types.size()));
145 else if (minResults)
146 predList.emplace_back(args&: pos, args: builder.getResultCountAtLeast(count: minResults));
147
148 // Recurse into any attributes, operands, or results.
149 for (auto [attrName, attr] :
150 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
151 getTreePredicates(
152 predList, attr, builder, inputs,
153 builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
154 }
155
156 // Process the operands and results of the operation. For all values up to
157 // the first variable length value, we use the concrete operand/result
158 // number. After that, we use the "group" given that we can't know the
159 // concrete indices until runtime. If there is only one variadic operand
160 // group, we treat it as all of the operands/results of the operation.
161 /// Operands.
162 if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
163 // Ignore the operands if we are performing an upward traversal (in that
164 // case, they have already been visited).
165 if (opPos->isRoot() || opPos->isOperandDefiningOp())
166 getTreePredicates(predList, val: operands.front(), builder, inputs,
167 pos: builder.getAllOperands(p: opPos));
168 } else {
169 bool foundVariableLength = false;
170 for (const auto &operandIt : llvm::enumerate(operands)) {
171 bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
172 foundVariableLength |= isVariadic;
173
174 // Ignore the specified operand, usually because this position was
175 // visited in an upward traversal via an iterative choice.
176 if (ignoreOperand && *ignoreOperand == operandIt.index())
177 continue;
178
179 Position *pos =
180 foundVariableLength
181 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
182 : builder.getOperand(opPos, operandIt.index());
183 getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
184 }
185 }
186 /// Results.
187 if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
188 getTreePredicates(predList, val: types.front(), builder, inputs,
189 pos: builder.getType(p: builder.getAllResults(p: opPos)));
190 return;
191 }
192
193 bool foundVariableLength = false;
194 for (auto [idx, typeValue] : llvm::enumerate(types)) {
195 bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
196 foundVariableLength |= isVariadic;
197
198 auto *resultPos = foundVariableLength
199 ? builder.getResultGroup(pos, idx, isVariadic)
200 : builder.getResult(pos, idx);
201 predList.emplace_back(resultPos, builder.getIsNotNull());
202 getTreePredicates(predList, typeValue, builder, inputs,
203 builder.getType(resultPos));
204 }
205}
206
207static void getTreePredicates(std::vector<PositionalPredicate> &predList,
208 Value val, PredicateBuilder &builder,
209 DenseMap<Value, Position *> &inputs,
210 TypePosition *pos) {
211 // Check for a constraint on a constant type.
212 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
213 if (Attribute type = typeOp.getConstantTypeAttr())
214 predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type));
215 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
216 if (Attribute typeAttr = typeOp.getConstantTypesAttr())
217 predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type: typeAttr));
218 }
219}
220
221/// Collect the tree predicates anchored at the given value.
222static void getTreePredicates(std::vector<PositionalPredicate> &predList,
223 Value val, PredicateBuilder &builder,
224 DenseMap<Value, Position *> &inputs,
225 Position *pos) {
226 // Make sure this input value is accessible to the rewrite.
227 auto it = inputs.try_emplace(Key: val, Args&: pos);
228 if (!it.second) {
229 // If this is an input value that has been visited in the tree, add a
230 // constraint to ensure that both instances refer to the same value.
231 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
232 pdl::TypeOp>(val.getDefiningOp())) {
233 auto minMaxPositions =
234 std::minmax(a: pos, b: it.first->second, comp: comparePosDepth);
235 predList.emplace_back(args: minMaxPositions.second,
236 args: builder.getEqualTo(pos: minMaxPositions.first));
237 }
238 return;
239 }
240
241 TypeSwitch<Position *>(pos)
242 .Case<AttributePosition, OperationPosition, TypePosition>(caseFn: [&](auto *pos) {
243 getTreePredicates(predList, val, builder, inputs, pos);
244 })
245 .Case<OperandPosition, OperandGroupPosition>(caseFn: [&](auto *pos) {
246 getOperandTreePredicates(predList, val, builder, inputs, pos);
247 })
248 .Default(defaultFn: [](auto *) { llvm_unreachable("unexpected position kind"); });
249}
250
251static void getAttributePredicates(pdl::AttributeOp op,
252 std::vector<PositionalPredicate> &predList,
253 PredicateBuilder &builder,
254 DenseMap<Value, Position *> &inputs) {
255 Position *&attrPos = inputs[op];
256 if (attrPos)
257 return;
258 Attribute value = op.getValueAttr();
259 assert(value && "expected non-tree `pdl.attribute` to contain a value");
260 attrPos = builder.getAttributeLiteral(attr: value);
261}
262
263static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
264 std::vector<PositionalPredicate> &predList,
265 PredicateBuilder &builder,
266 DenseMap<Value, Position *> &inputs) {
267 OperandRange arguments = op.getArgs();
268
269 std::vector<Position *> allPositions;
270 allPositions.reserve(n: arguments.size());
271 for (Value arg : arguments)
272 allPositions.push_back(inputs.lookup(arg));
273
274 // Push the constraint to the furthest position.
275 Position *pos = *llvm::max_element(Range&: allPositions, C: comparePosDepth);
276 ResultRange results = op.getResults();
277 PredicateBuilder::Predicate pred = builder.getConstraint(
278 name: op.getName(), args: allPositions, resultTypes: SmallVector<Type>(results.getTypes()),
279 isNegated: op.getIsNegated());
280
281 // For each result register a position so it can be used later
282 for (auto [i, result] : llvm::enumerate(results)) {
283 ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
284 ConstraintPosition *pos = builder.getConstraintPosition(q, i);
285 auto [it, inserted] = inputs.try_emplace(result, pos);
286 // If this is an input value that has been visited in the tree, add a
287 // constraint to ensure that both instances refer to the same value.
288 if (!inserted) {
289 Position *first = pos;
290 Position *second = it->second;
291 if (comparePosDepth(second, first))
292 std::tie(second, first) = std::make_pair(first, second);
293
294 predList.emplace_back(second, builder.getEqualTo(first));
295 }
296 }
297 predList.emplace_back(args&: pos, args&: pred);
298}
299
300static void getResultPredicates(pdl::ResultOp op,
301 std::vector<PositionalPredicate> &predList,
302 PredicateBuilder &builder,
303 DenseMap<Value, Position *> &inputs) {
304 Position *&resultPos = inputs[op];
305 if (resultPos)
306 return;
307
308 // Ensure that the result isn't null.
309 auto *parentPos = cast<OperationPosition>(inputs.lookup(Val: op.getParent()));
310 resultPos = builder.getResult(p: parentPos, result: op.getIndex());
311 predList.emplace_back(args&: resultPos, args: builder.getIsNotNull());
312}
313
314static void getResultPredicates(pdl::ResultsOp op,
315 std::vector<PositionalPredicate> &predList,
316 PredicateBuilder &builder,
317 DenseMap<Value, Position *> &inputs) {
318 Position *&resultPos = inputs[op];
319 if (resultPos)
320 return;
321
322 // Ensure that the result isn't null if the result has an index.
323 auto *parentPos = cast<OperationPosition>(inputs.lookup(Val: op.getParent()));
324 bool isVariadic = isa<pdl::RangeType>(op.getType());
325 std::optional<unsigned> index = op.getIndex();
326 resultPos = builder.getResultGroup(p: parentPos, group: index, isVariadic);
327 if (index)
328 predList.emplace_back(args&: resultPos, args: builder.getIsNotNull());
329}
330
331static void getTypePredicates(Value typeValue,
332 function_ref<Attribute()> typeAttrFn,
333 PredicateBuilder &builder,
334 DenseMap<Value, Position *> &inputs) {
335 Position *&typePos = inputs[typeValue];
336 if (typePos)
337 return;
338 Attribute typeAttr = typeAttrFn();
339 assert(typeAttr &&
340 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
341 typePos = builder.getTypeLiteral(attr: typeAttr);
342}
343
344/// Collect all of the predicates that cannot be determined via walking the
345/// tree.
346static void getNonTreePredicates(pdl::PatternOp pattern,
347 std::vector<PositionalPredicate> &predList,
348 PredicateBuilder &builder,
349 DenseMap<Value, Position *> &inputs) {
350 for (Operation &op : pattern.getBodyRegion().getOps()) {
351 TypeSwitch<Operation *>(&op)
352 .Case([&](pdl::AttributeOp attrOp) {
353 getAttributePredicates(attrOp, predList, builder, inputs);
354 })
355 .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
356 getConstraintPredicates(constraintOp, predList, builder, inputs);
357 })
358 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
359 getResultPredicates(resultOp, predList, builder, inputs);
360 })
361 .Case([&](pdl::TypeOp typeOp) {
362 getTypePredicates(
363 typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
364 inputs);
365 })
366 .Case([&](pdl::TypesOp typeOp) {
367 getTypePredicates(
368 typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
369 inputs);
370 });
371 }
372}
373
374namespace {
375
376/// An op accepting a value at an optional index.
377struct OpIndex {
378 Value parent;
379 std::optional<unsigned> index;
380};
381
382/// The parent and operand index of each operation for each root, stored
383/// as a nested map [root][operation].
384using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
385
386} // namespace
387
388/// Given a pattern, determines the set of roots present in this pattern.
389/// These are the operations whose results are not consumed by other operations.
390static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
391 // First, collect all the operations that are used as operands
392 // to other operations. These are not roots by default.
393 DenseSet<Value> used;
394 for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
395 for (Value operand : operationOp.getOperandValues())
396 TypeSwitch<Operation *>(operand.getDefiningOp())
397 .Case<pdl::ResultOp, pdl::ResultsOp>(
398 [&used](auto resultOp) { used.insert(resultOp.getParent()); });
399 }
400
401 // Remove the specified root from the use set, so that we can
402 // always select it as a root, even if it is used by other operations.
403 if (Value root = pattern.getRewriter().getRoot())
404 used.erase(V: root);
405
406 // Finally, collect all the unused operations.
407 SmallVector<Value> roots;
408 for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
409 if (!used.contains(operationOp))
410 roots.push_back(operationOp);
411
412 return roots;
413}
414
415/// Given a list of candidate roots, builds the cost graph for connecting them.
416/// The graph is formed by traversing the DAG of operations starting from each
417/// root and marking the depth of each connector value (operand). Then we join
418/// the candidate roots based on the common connector values, taking the one
419/// with the minimum depth. Along the way, we compute, for each candidate root,
420/// a mapping from each operation (in the DAG underneath this root) to its
421/// parent operation and the corresponding operand index.
422static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
423 ParentMaps &parentMaps) {
424
425 // The entry of a queue. The entry consists of the following items:
426 // * the value in the DAG underneath the root;
427 // * the parent of the value;
428 // * the operand index of the value in its parent;
429 // * the depth of the visited value.
430 struct Entry {
431 Entry(Value value, Value parent, std::optional<unsigned> index,
432 unsigned depth)
433 : value(value), parent(parent), index(index), depth(depth) {}
434
435 Value value;
436 Value parent;
437 std::optional<unsigned> index;
438 unsigned depth;
439 };
440
441 // A root of a value and its depth (distance from root to the value).
442 struct RootDepth {
443 Value root;
444 unsigned depth = 0;
445 };
446
447 // Map from candidate connector values to their roots and depths. Using a
448 // small vector with 1 entry because most values belong to a single root.
449 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
450
451 // Perform a breadth-first traversal of the op DAG rooted at each root.
452 for (Value root : roots) {
453 // The queue of visited values. A value may be present multiple times in
454 // the queue, for multiple parents. We only accept the first occurrence,
455 // which is guaranteed to have the lowest depth.
456 std::queue<Entry> toVisit;
457 toVisit.emplace(args&: root, args: Value(), args: 0, args: 0);
458
459 // The map from value to its parent for the current root.
460 DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
461
462 while (!toVisit.empty()) {
463 Entry entry = toVisit.front();
464 toVisit.pop();
465 // Skip if already visited.
466 if (!parentMap.insert(KV: {entry.value, {.parent: entry.parent, .index: entry.index}}).second)
467 continue;
468
469 // Mark the root and depth of the value.
470 connectorsRootsDepths[entry.value].push_back(Elt: {.root: root, .depth: entry.depth});
471
472 // Traverse the operands of an operation and result ops.
473 // We intentionally do not traverse attributes and types, because those
474 // are expensive to join on.
475 TypeSwitch<Operation *>(entry.value.getDefiningOp())
476 .Case<pdl::OperationOp>([&](auto operationOp) {
477 OperandRange operands = operationOp.getOperandValues();
478 // Special case when we pass all the operands in one range.
479 // For those, the index is empty.
480 if (operands.size() == 1 &&
481 isa<pdl::RangeType>(operands[0].getType())) {
482 toVisit.emplace(operands[0], entry.value, std::nullopt,
483 entry.depth + 1);
484 return;
485 }
486
487 // Default case: visit all the operands.
488 for (const auto &p :
489 llvm::enumerate(operationOp.getOperandValues()))
490 toVisit.emplace(p.value(), entry.value, p.index(),
491 entry.depth + 1);
492 })
493 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
494 toVisit.emplace(resultOp.getParent(), entry.value,
495 resultOp.getIndex(), entry.depth);
496 });
497 }
498 }
499
500 // Now build the cost graph.
501 // This is simply a minimum over all depths for the target root.
502 unsigned nextID = 0;
503 for (const auto &connectorRootsDepths : connectorsRootsDepths) {
504 Value value = connectorRootsDepths.first;
505 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
506 // If there is only one root for this value, this will not trigger
507 // any edges in the cost graph (a perf optimization).
508 if (rootsDepths.size() == 1)
509 continue;
510
511 for (const RootDepth &p : rootsDepths) {
512 for (const RootDepth &q : rootsDepths) {
513 if (&p == &q)
514 continue;
515 // Insert or retrieve the property of edge from p to q.
516 RootOrderingEntry &entry = graph[q.root][p.root];
517 if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
518 if (!entry.connector)
519 entry.cost.second = nextID++;
520 entry.cost.first = q.depth;
521 entry.connector = value;
522 }
523 }
524 }
525 }
526
527 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
528 "the pattern contains a candidate root disconnected from the others");
529}
530
531/// Returns true if the operand at the given index needs to be queried using an
532/// operand group, i.e., if it is variadic itself or follows a variadic operand.
533static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
534 OperandRange operands = op.getOperandValues();
535 assert(index < operands.size() && "operand index out of range");
536 for (unsigned i = 0; i <= index; ++i)
537 if (isa<pdl::RangeType>(operands[i].getType()))
538 return true;
539 return false;
540}
541
542/// Visit a node during upward traversal.
543static void visitUpward(std::vector<PositionalPredicate> &predList,
544 OpIndex opIndex, PredicateBuilder &builder,
545 DenseMap<Value, Position *> &valueToPosition,
546 Position *&pos, unsigned rootID) {
547 Value value = opIndex.parent;
548 TypeSwitch<Operation *>(value.getDefiningOp())
549 .Case<pdl::OperationOp>([&](auto operationOp) {
550 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
551
552 // Get users and iterate over them.
553 Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
554 Position *foreachPos = builder.getForEach(usersPos, rootID);
555 OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
556
557 // Compare the operand(s) of the user against the input value(s).
558 Position *operandPos;
559 if (!opIndex.index) {
560 // We are querying all the operands of the operation.
561 operandPos = builder.getAllOperands(opPos);
562 } else if (useOperandGroup(operationOp, *opIndex.index)) {
563 // We are querying an operand group.
564 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
565 bool variadic = isa<pdl::RangeType>(type);
566 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
567 } else {
568 // We are querying an individual operand.
569 operandPos = builder.getOperand(opPos, *opIndex.index);
570 }
571 predList.emplace_back(operandPos, builder.getEqualTo(pos));
572
573 // Guard against duplicate upward visits. These are not possible,
574 // because if this value was already visited, it would have been
575 // cheaper to start the traversal at this value rather than at the
576 // `connector`, violating the optimality of our spanning tree.
577 bool inserted = valueToPosition.try_emplace(value, opPos).second;
578 (void)inserted;
579 assert(inserted && "duplicate upward visit");
580
581 // Obtain the tree predicates at the current value.
582 getTreePredicates(predList, value, builder, valueToPosition, opPos,
583 opIndex.index);
584
585 // Update the position
586 pos = opPos;
587 })
588 .Case<pdl::ResultOp>([&](auto resultOp) {
589 // Traverse up an individual result.
590 auto *opPos = dyn_cast<OperationPosition>(pos);
591 assert(opPos && "operations and results must be interleaved");
592 pos = builder.getResult(opPos, *opIndex.index);
593
594 // Insert the result position in case we have not visited it yet.
595 valueToPosition.try_emplace(value, pos);
596 })
597 .Case<pdl::ResultsOp>([&](auto resultOp) {
598 // Traverse up a group of results.
599 auto *opPos = dyn_cast<OperationPosition>(pos);
600 assert(opPos && "operations and results must be interleaved");
601 bool isVariadic = isa<pdl::RangeType>(value.getType());
602 if (opIndex.index)
603 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
604 else
605 pos = builder.getAllResults(opPos);
606
607 // Insert the result position in case we have not visited it yet.
608 valueToPosition.try_emplace(value, pos);
609 });
610}
611
612/// Given a pattern operation, build the set of matcher predicates necessary to
613/// match this pattern.
614static Value buildPredicateList(pdl::PatternOp pattern,
615 PredicateBuilder &builder,
616 std::vector<PositionalPredicate> &predList,
617 DenseMap<Value, Position *> &valueToPosition) {
618 SmallVector<Value> roots = detectRoots(pattern);
619
620 // Build the root ordering graph and compute the parent maps.
621 RootOrderingGraph graph;
622 ParentMaps parentMaps;
623 buildCostGraph(roots, graph, parentMaps);
624 LLVM_DEBUG({
625 llvm::dbgs() << "Graph:\n";
626 for (auto &target : graph) {
627 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
628 << "\n";
629 for (auto &source : target.second) {
630 RootOrderingEntry &entry = source.second;
631 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
632 << ":" << entry.cost.second << " via "
633 << entry.connector.getLoc() << "\n";
634 }
635 }
636 });
637
638 // Solve the optimal branching problem for each candidate root, or use the
639 // provided one.
640 Value bestRoot = pattern.getRewriter().getRoot();
641 OptimalBranching::EdgeList bestEdges;
642 if (!bestRoot) {
643 unsigned bestCost = 0;
644 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
645 for (Value root : roots) {
646 OptimalBranching solver(graph, root);
647 unsigned cost = solver.solve();
648 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
649 if (!bestRoot || bestCost > cost) {
650 bestCost = cost;
651 bestRoot = root;
652 bestEdges = solver.preOrderTraversal(nodes: roots);
653 }
654 }
655 } else {
656 OptimalBranching solver(graph, bestRoot);
657 solver.solve();
658 bestEdges = solver.preOrderTraversal(nodes: roots);
659 }
660
661 // Print the best solution.
662 LLVM_DEBUG({
663 llvm::dbgs() << "Best tree:\n";
664 for (const std::pair<Value, Value> &edge : bestEdges) {
665 llvm::dbgs() << " * " << edge.first;
666 if (edge.second)
667 llvm::dbgs() << " <- " << edge.second;
668 llvm::dbgs() << "\n";
669 }
670 });
671
672 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
673 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
674
675 // The best root is the starting point for the traversal. Get the tree
676 // predicates for the DAG rooted at bestRoot.
677 getTreePredicates(predList, val: bestRoot, builder, inputs&: valueToPosition,
678 pos: builder.getRoot());
679
680 // Traverse the selected optimal branching. For all edges in order, traverse
681 // up starting from the connector, until the candidate root is reached, and
682 // call getTreePredicates at every node along the way.
683 for (const auto &it : llvm::enumerate(First&: bestEdges)) {
684 Value target = it.value().first;
685 Value source = it.value().second;
686
687 // Check if we already visited the target root. This happens in two cases:
688 // 1) the initial root (bestRoot);
689 // 2) a root that is dominated by (contained in the subtree rooted at) an
690 // already visited root.
691 if (valueToPosition.count(Val: target))
692 continue;
693
694 // Determine the connector.
695 Value connector = graph[target][source].connector;
696 assert(connector && "invalid edge");
697 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
698 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(Val: target);
699 Position *pos = valueToPosition.lookup(Val: connector);
700 assert(pos && "connector has not been traversed yet");
701
702 // Traverse from the connector upwards towards the target root.
703 for (Value value = connector; value != target;) {
704 OpIndex opIndex = parentMap.lookup(Val: value);
705 assert(opIndex.parent && "missing parent");
706 visitUpward(predList, opIndex, builder, valueToPosition, pos, rootID: it.index());
707 value = opIndex.parent;
708 }
709 }
710
711 getNonTreePredicates(pattern, predList, builder, valueToPosition);
712
713 return bestRoot;
714}
715
716//===----------------------------------------------------------------------===//
717// Pattern Predicate Tree Merging
718//===----------------------------------------------------------------------===//
719
720namespace {
721
722/// This class represents a specific predicate applied to a position, and
723/// provides hashing and ordering operators. This class allows for computing a
724/// frequence sum and ordering predicates based on a cost model.
725struct OrderedPredicate {
726 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
727 : position(ip.first), question(ip.second) {}
728 OrderedPredicate(const PositionalPredicate &ip)
729 : position(ip.position), question(ip.question) {}
730
731 /// The position this predicate is applied to.
732 Position *position;
733
734 /// The question that is applied by this predicate onto the position.
735 Qualifier *question;
736
737 /// The first and second order benefit sums.
738 /// The primary sum is the number of occurrences of this predicate among all
739 /// of the patterns.
740 unsigned primary = 0;
741 /// The secondary sum is a squared summation of the primary sum of all of the
742 /// predicates within each pattern that contains this predicate. This allows
743 /// for favoring predicates that are more commonly shared within a pattern, as
744 /// opposed to those shared across patterns.
745 unsigned secondary = 0;
746
747 /// The tie breaking ID, used to preserve a deterministic (insertion) order
748 /// among all the predicates with the same priority, depth, and position /
749 /// predicate dependency.
750 unsigned id = 0;
751
752 /// A map between a pattern operation and the answer to the predicate question
753 /// within that pattern.
754 DenseMap<Operation *, Qualifier *> patternToAnswer;
755
756 /// Returns true if this predicate is ordered before `rhs`, based on the cost
757 /// model.
758 bool operator<(const OrderedPredicate &rhs) const {
759 // Sort by:
760 // * higher first and secondary order sums
761 // * lower depth
762 // * lower position dependency
763 // * lower predicate dependency
764 // * lower tie breaking ID
765 auto *rhsPos = rhs.position;
766 return std::make_tuple(args: primary, args: secondary, args: rhsPos->getOperationDepth(),
767 args: rhsPos->getKind(), args: rhs.question->getKind(), args: rhs.id) >
768 std::make_tuple(args: rhs.primary, args: rhs.secondary,
769 args: position->getOperationDepth(), args: position->getKind(),
770 args: question->getKind(), args: id);
771 }
772};
773
774/// A DenseMapInfo for OrderedPredicate based solely on the position and
775/// question.
776struct OrderedPredicateDenseInfo {
777 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
778
779 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
780 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
781 static bool isEqual(const OrderedPredicate &lhs,
782 const OrderedPredicate &rhs) {
783 return lhs.position == rhs.position && lhs.question == rhs.question;
784 }
785 static unsigned getHashValue(const OrderedPredicate &p) {
786 return llvm::hash_combine(args: p.position, args: p.question);
787 }
788};
789
790/// This class wraps a set of ordered predicates that are used within a specific
791/// pattern operation.
792struct OrderedPredicateList {
793 OrderedPredicateList(pdl::PatternOp pattern, Value root)
794 : pattern(pattern), root(root) {}
795
796 pdl::PatternOp pattern;
797 Value root;
798 DenseSet<OrderedPredicate *> predicates;
799};
800} // namespace
801
802/// Returns true if the given matcher refers to the same predicate as the given
803/// ordered predicate. This means that the position and questions of the two
804/// match.
805static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
806 return node->getPosition() == predicate->position &&
807 node->getQuestion() == predicate->question;
808}
809
810/// Get or insert a child matcher for the given parent switch node, given a
811/// predicate and parent pattern.
812std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
813 OrderedPredicate *predicate,
814 pdl::PatternOp pattern) {
815 assert(isSamePredicate(node, predicate) &&
816 "expected matcher to equal the given predicate");
817
818 auto it = predicate->patternToAnswer.find(pattern);
819 assert(it != predicate->patternToAnswer.end() &&
820 "expected pattern to exist in predicate");
821 return node->getChildren().insert({it->second, nullptr}).first->second;
822}
823
824/// Build the matcher CFG by "pushing" patterns through by sorted predicate
825/// order. A pattern will traverse as far as possible using common predicates
826/// and then either diverge from the CFG or reach the end of a branch and start
827/// creating new nodes.
828static void propagatePattern(std::unique_ptr<MatcherNode> &node,
829 OrderedPredicateList &list,
830 std::vector<OrderedPredicate *>::iterator current,
831 std::vector<OrderedPredicate *>::iterator end) {
832 if (current == end) {
833 // We've hit the end of a pattern, so create a successful result node.
834 node =
835 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
836
837 // If the pattern doesn't contain this predicate, ignore it.
838 } else if (!list.predicates.contains(V: *current)) {
839 propagatePattern(node, list, current: std::next(x: current), end);
840
841 // If the current matcher node is invalid, create a new one for this
842 // position and continue propagation.
843 } else if (!node) {
844 // Create a new node at this position and continue
845 node = std::make_unique<SwitchNode>(args&: (*current)->position,
846 args&: (*current)->question);
847 propagatePattern(
848 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
849 list, std::next(current), end);
850
851 // If the matcher has already been created, and it is for this predicate we
852 // continue propagation to the child.
853 } else if (isSamePredicate(node: node.get(), predicate: *current)) {
854 propagatePattern(
855 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
856 list, std::next(current), end);
857
858 // If the matcher doesn't match the current predicate, insert a branch as
859 // the common set of matchers has diverged.
860 } else {
861 propagatePattern(node&: node->getFailureNode(), list, current, end);
862 }
863}
864
865/// Fold any switch nodes nested under `node` to boolean nodes when possible.
866/// `node` is updated in-place if it is a switch.
867static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
868 if (!node)
869 return;
870
871 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(Val: &*node)) {
872 SwitchNode::ChildMapT &children = switchNode->getChildren();
873 for (auto &it : children)
874 foldSwitchToBool(node&: it.second);
875
876 // If the node only contains one child, collapse it into a boolean predicate
877 // node.
878 if (children.size() == 1) {
879 auto *childIt = children.begin();
880 node = std::make_unique<BoolNode>(
881 args: node->getPosition(), args: node->getQuestion(), args&: childIt->first,
882 args: std::move(childIt->second), args: std::move(node->getFailureNode()));
883 }
884 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(Val: &*node)) {
885 foldSwitchToBool(node&: boolNode->getSuccessNode());
886 }
887
888 foldSwitchToBool(node&: node->getFailureNode());
889}
890
891/// Insert an exit node at the end of the failure path of the `root`.
892static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
893 while (*root)
894 root = &(*root)->getFailureNode();
895 *root = std::make_unique<ExitNode>();
896}
897
898/// Sorts the range begin/end with the partial order given by cmp.
899template <typename Iterator, typename Compare>
900static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
901 while (begin != end) {
902 // Cannot compute sortBeforeOthers in the predicate of stable_partition
903 // because stable_partition will not keep the [begin, end) range intact
904 // while it runs.
905 llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
906 for (auto i = begin; i != end; ++i) {
907 if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
908 sortBeforeOthers.insert(*i);
909 }
910
911 auto const next = std::stable_partition(begin, end, [&](auto const &a) {
912 return sortBeforeOthers.contains(a);
913 });
914 assert(next != begin && "not a partial ordering");
915 begin = next;
916 }
917}
918
919/// Returns true if 'b' depends on a result of 'a'.
920static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
921 auto *cqa = dyn_cast<ConstraintQuestion>(Val: a->question);
922 if (!cqa)
923 return false;
924
925 auto positionDependsOnA = [&](Position *p) {
926 auto *cp = dyn_cast<ConstraintPosition>(Val: p);
927 return cp && cp->getQuestion() == cqa;
928 };
929
930 if (auto *cqb = dyn_cast<ConstraintQuestion>(Val: b->question)) {
931 // Does any argument of b use a?
932 return llvm::any_of(Range: cqb->getArgs(), P: positionDependsOnA);
933 }
934 if (auto *equalTo = dyn_cast<EqualToQuestion>(Val: b->question)) {
935 return positionDependsOnA(b->position) ||
936 positionDependsOnA(equalTo->getValue());
937 }
938 return positionDependsOnA(b->position);
939}
940
941/// Given a module containing PDL pattern operations, generate a matcher tree
942/// using the patterns within the given module and return the root matcher node.
943std::unique_ptr<MatcherNode>
944MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
945 DenseMap<Value, Position *> &valueToPosition) {
946 // The set of predicates contained within the pattern operations of the
947 // module.
948 struct PatternPredicates {
949 PatternPredicates(pdl::PatternOp pattern, Value root,
950 std::vector<PositionalPredicate> predicates)
951 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
952
953 /// A pattern.
954 pdl::PatternOp pattern;
955
956 /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
957 Value root;
958
959 /// The extracted predicates for this pattern and root.
960 std::vector<PositionalPredicate> predicates;
961 };
962
963 SmallVector<PatternPredicates, 16> patternsAndPredicates;
964 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
965 std::vector<PositionalPredicate> predicateList;
966 Value root =
967 buildPredicateList(pattern, builder, predicateList, valueToPosition);
968 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
969 }
970
971 // Associate a pattern result with each unique predicate.
972 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
973 for (auto &patternAndPredList : patternsAndPredicates) {
974 for (auto &predicate : patternAndPredList.predicates) {
975 auto it = uniqued.insert(V: predicate);
976 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
977 predicate.answer);
978 // Mark the insertion order (0-based indexing).
979 if (it.second)
980 it.first->id = uniqued.size() - 1;
981 }
982 }
983
984 // Associate each pattern to a set of its ordered predicates for later lookup.
985 std::vector<OrderedPredicateList> lists;
986 lists.reserve(n: patternsAndPredicates.size());
987 for (auto &patternAndPredList : patternsAndPredicates) {
988 OrderedPredicateList list(patternAndPredList.pattern,
989 patternAndPredList.root);
990 for (auto &predicate : patternAndPredList.predicates) {
991 OrderedPredicate *orderedPredicate = &*uniqued.find(V: predicate);
992 list.predicates.insert(V: orderedPredicate);
993
994 // Increment the primary sum for each reference to a particular predicate.
995 ++orderedPredicate->primary;
996 }
997 lists.push_back(x: std::move(list));
998 }
999
1000 // For a particular pattern, get the total primary sum and add it to the
1001 // secondary sum of each predicate. Square the primary sums to emphasize
1002 // shared predicates within rather than across patterns.
1003 for (auto &list : lists) {
1004 unsigned total = 0;
1005 for (auto *predicate : list.predicates)
1006 total += predicate->primary * predicate->primary;
1007 for (auto *predicate : list.predicates)
1008 predicate->secondary += total;
1009 }
1010
1011 // Sort the set of predicates now that the cost primary and secondary sums
1012 // have been computed.
1013 std::vector<OrderedPredicate *> ordered;
1014 ordered.reserve(n: uniqued.size());
1015 for (auto &ip : uniqued)
1016 ordered.push_back(x: &ip);
1017 llvm::sort(C&: ordered, Comp: [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1018 return *lhs < *rhs;
1019 });
1020
1021 // Mostly keep the now established order, but also ensure that
1022 // ConstraintQuestions come after the results they use.
1023 stableTopologicalSort(begin: ordered.begin(), end: ordered.end(), cmp: dependsOn);
1024
1025 // Build the matchers for each of the pattern predicate lists.
1026 std::unique_ptr<MatcherNode> root;
1027 for (OrderedPredicateList &list : lists)
1028 propagatePattern(node&: root, list, current: ordered.begin(), end: ordered.end());
1029
1030 // Collapse the graph and insert the exit node.
1031 foldSwitchToBool(node&: root);
1032 insertExitNode(root: &root);
1033 return root;
1034}
1035
1036//===----------------------------------------------------------------------===//
1037// MatcherNode
1038//===----------------------------------------------------------------------===//
1039
1040MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
1041 std::unique_ptr<MatcherNode> failureNode)
1042 : position(p), question(q), failureNode(std::move(failureNode)),
1043 matcherTypeID(matcherTypeID) {}
1044
1045//===----------------------------------------------------------------------===//
1046// BoolNode
1047//===----------------------------------------------------------------------===//
1048
1049BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1050 std::unique_ptr<MatcherNode> successNode,
1051 std::unique_ptr<MatcherNode> failureNode)
1052 : MatcherNode(TypeID::get<BoolNode>(), position, question,
1053 std::move(failureNode)),
1054 answer(answer), successNode(std::move(successNode)) {}
1055
1056//===----------------------------------------------------------------------===//
1057// SuccessNode
1058//===----------------------------------------------------------------------===//
1059
1060SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1061 std::unique_ptr<MatcherNode> failureNode)
1062 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1063 /*question=*/nullptr, std::move(failureNode)),
1064 pattern(pattern), root(root) {}
1065
1066//===----------------------------------------------------------------------===//
1067// SwitchNode
1068//===----------------------------------------------------------------------===//
1069
1070SwitchNode::SwitchNode(Position *position, Qualifier *question)
1071 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1072

source code of mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp