1//===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PredicateTree.h"
10#include "RootOrdering.h"
11
12#include "mlir/Dialect/PDL/IR/PDLTypes.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "llvm/ADT/MapVector.h"
15#include "llvm/ADT/SmallPtrSet.h"
16#include "llvm/ADT/TypeSwitch.h"
17#include "llvm/Support/Debug.h"
18#include <queue>
19
20#define DEBUG_TYPE "pdl-predicate-tree"
21
22using namespace mlir;
23using namespace mlir::pdl_to_pdl_interp;
24
25//===----------------------------------------------------------------------===//
26// Predicate List Building
27//===----------------------------------------------------------------------===//
28
29static void getTreePredicates(std::vector<PositionalPredicate> &predList,
30 Value val, PredicateBuilder &builder,
31 DenseMap<Value, Position *> &inputs,
32 Position *pos);
33
34/// Compares the depths of two positions.
35static bool comparePosDepth(Position *lhs, Position *rhs) {
36 return lhs->getOperationDepth() < rhs->getOperationDepth();
37}
38
39/// Returns the number of non-range elements within `values`.
40static unsigned getNumNonRangeValues(ValueRange values) {
41 return llvm::count_if(Range: values.getTypes(),
42 P: [](Type type) { return !isa<pdl::RangeType>(Val: type); });
43}
44
45static void getTreePredicates(std::vector<PositionalPredicate> &predList,
46 Value val, PredicateBuilder &builder,
47 DenseMap<Value, Position *> &inputs,
48 AttributePosition *pos) {
49 assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
50 predList.emplace_back(args&: pos, args: builder.getIsNotNull());
51
52 if (auto attr = dyn_cast<pdl::AttributeOp>(Val: val.getDefiningOp())) {
53 // If the attribute has a type or value, add a constraint.
54 if (Value type = attr.getValueType())
55 getTreePredicates(predList, val: type, builder, inputs, pos: builder.getType(p: pos));
56 else if (Attribute value = attr.getValueAttr())
57 predList.emplace_back(args&: pos, args: builder.getAttributeConstraint(attr: value));
58 }
59}
60
61/// Collect all of the predicates for the given operand position.
62static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
63 Value val, PredicateBuilder &builder,
64 DenseMap<Value, Position *> &inputs,
65 Position *pos) {
66 Type valueType = val.getType();
67 bool isVariadic = isa<pdl::RangeType>(Val: valueType);
68
69 // If this is a typed operand, add a type constraint.
70 TypeSwitch<Operation *>(val.getDefiningOp())
71 .Case<pdl::OperandOp, pdl::OperandsOp>(caseFn: [&](auto op) {
72 // Prevent traversal into a null value if the operand has a proper
73 // index.
74 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
75 cast<OperandGroupPosition>(Val: pos)->getOperandGroupNumber())
76 predList.emplace_back(args&: pos, args: builder.getIsNotNull());
77
78 if (Value type = op.getValueType())
79 getTreePredicates(predList, val: type, builder, inputs,
80 pos: builder.getType(p: pos));
81 })
82 .Case<pdl::ResultOp, pdl::ResultsOp>(caseFn: [&](auto op) {
83 std::optional<unsigned> index = op.getIndex();
84
85 // Prevent traversal into a null value if the result has a proper index.
86 if (index)
87 predList.emplace_back(args&: pos, args: builder.getIsNotNull());
88
89 // Get the parent operation of this operand.
90 OperationPosition *parentPos = builder.getOperandDefiningOp(p: pos);
91 predList.emplace_back(args&: parentPos, args: builder.getIsNotNull());
92
93 // Ensure that the operands match the corresponding results of the
94 // parent operation.
95 Position *resultPos = nullptr;
96 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
97 resultPos = builder.getResult(p: parentPos, result: *index);
98 else
99 resultPos = builder.getResultGroup(p: parentPos, group: index, isVariadic);
100 predList.emplace_back(args&: resultPos, args: builder.getEqualTo(pos));
101
102 // Collect the predicates of the parent operation.
103 getTreePredicates(predList, op.getParent(), builder, inputs,
104 (Position *)parentPos);
105 });
106}
107
108static void
109getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
110 PredicateBuilder &builder,
111 DenseMap<Value, Position *> &inputs, OperationPosition *pos,
112 std::optional<unsigned> ignoreOperand = std::nullopt) {
113 assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
114 pdl::OperationOp op = cast<pdl::OperationOp>(Val: val.getDefiningOp());
115 OperationPosition *opPos = cast<OperationPosition>(Val: pos);
116
117 // Ensure getDefiningOp returns a non-null operation.
118 if (!opPos->isRoot())
119 predList.emplace_back(args&: pos, args: builder.getIsNotNull());
120
121 // Check that this is the correct root operation.
122 if (std::optional<StringRef> opName = op.getOpName())
123 predList.emplace_back(args&: pos, args: builder.getOperationName(name: *opName));
124
125 // Check that the operation has the proper number of operands. If there are
126 // any variable length operands, we check a minimum instead of an exact count.
127 OperandRange operands = op.getOperandValues();
128 unsigned minOperands = getNumNonRangeValues(values: operands);
129 if (minOperands != operands.size()) {
130 if (minOperands)
131 predList.emplace_back(args&: pos, args: builder.getOperandCountAtLeast(count: minOperands));
132 } else {
133 predList.emplace_back(args&: pos, args: builder.getOperandCount(count: minOperands));
134 }
135
136 // Check that the operation has the proper number of results. If there are
137 // any variable length results, we check a minimum instead of an exact count.
138 OperandRange types = op.getTypeValues();
139 unsigned minResults = getNumNonRangeValues(values: types);
140 if (minResults == types.size())
141 predList.emplace_back(args&: pos, args: builder.getResultCount(count: types.size()));
142 else if (minResults)
143 predList.emplace_back(args&: pos, args: builder.getResultCountAtLeast(count: minResults));
144
145 // Recurse into any attributes, operands, or results.
146 for (auto [attrName, attr] :
147 llvm::zip(t: op.getAttributeValueNames(), u: op.getAttributeValues())) {
148 getTreePredicates(
149 predList, val: attr, builder, inputs,
150 pos: builder.getAttribute(p: opPos, name: cast<StringAttr>(Val: attrName).getValue()));
151 }
152
153 // Process the operands and results of the operation. For all values up to
154 // the first variable length value, we use the concrete operand/result
155 // number. After that, we use the "group" given that we can't know the
156 // concrete indices until runtime. If there is only one variadic operand
157 // group, we treat it as all of the operands/results of the operation.
158 /// Operands.
159 if (operands.size() == 1 && isa<pdl::RangeType>(Val: operands[0].getType())) {
160 // Ignore the operands if we are performing an upward traversal (in that
161 // case, they have already been visited).
162 if (opPos->isRoot() || opPos->isOperandDefiningOp())
163 getTreePredicates(predList, val: operands.front(), builder, inputs,
164 pos: builder.getAllOperands(p: opPos));
165 } else {
166 bool foundVariableLength = false;
167 for (const auto &operandIt : llvm::enumerate(First&: operands)) {
168 bool isVariadic = isa<pdl::RangeType>(Val: operandIt.value().getType());
169 foundVariableLength |= isVariadic;
170
171 // Ignore the specified operand, usually because this position was
172 // visited in an upward traversal via an iterative choice.
173 if (ignoreOperand == operandIt.index())
174 continue;
175
176 Position *pos =
177 foundVariableLength
178 ? builder.getOperandGroup(p: opPos, group: operandIt.index(), isVariadic)
179 : builder.getOperand(p: opPos, operand: operandIt.index());
180 getTreePredicates(predList, val: operandIt.value(), builder, inputs, pos);
181 }
182 }
183 /// Results.
184 if (types.size() == 1 && isa<pdl::RangeType>(Val: types[0].getType())) {
185 getTreePredicates(predList, val: types.front(), builder, inputs,
186 pos: builder.getType(p: builder.getAllResults(p: opPos)));
187 return;
188 }
189
190 bool foundVariableLength = false;
191 for (auto [idx, typeValue] : llvm::enumerate(First&: types)) {
192 bool isVariadic = isa<pdl::RangeType>(Val: typeValue.getType());
193 foundVariableLength |= isVariadic;
194
195 auto *resultPos = foundVariableLength
196 ? builder.getResultGroup(p: pos, group: idx, isVariadic)
197 : builder.getResult(p: pos, result: idx);
198 predList.emplace_back(args&: resultPos, args: builder.getIsNotNull());
199 getTreePredicates(predList, val: typeValue, builder, inputs,
200 pos: builder.getType(p: resultPos));
201 }
202}
203
204static void getTreePredicates(std::vector<PositionalPredicate> &predList,
205 Value val, PredicateBuilder &builder,
206 DenseMap<Value, Position *> &inputs,
207 TypePosition *pos) {
208 // Check for a constraint on a constant type.
209 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
210 if (Attribute type = typeOp.getConstantTypeAttr())
211 predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type));
212 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
213 if (Attribute typeAttr = typeOp.getConstantTypesAttr())
214 predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type: typeAttr));
215 }
216}
217
218/// Collect the tree predicates anchored at the given value.
219static void getTreePredicates(std::vector<PositionalPredicate> &predList,
220 Value val, PredicateBuilder &builder,
221 DenseMap<Value, Position *> &inputs,
222 Position *pos) {
223 // Make sure this input value is accessible to the rewrite.
224 auto it = inputs.try_emplace(Key: val, Args&: pos);
225 if (!it.second) {
226 // If this is an input value that has been visited in the tree, add a
227 // constraint to ensure that both instances refer to the same value.
228 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
229 pdl::TypeOp>(Val: val.getDefiningOp())) {
230 auto minMaxPositions =
231 std::minmax(a: pos, b: it.first->second, comp: comparePosDepth);
232 predList.emplace_back(args: minMaxPositions.second,
233 args: builder.getEqualTo(pos: minMaxPositions.first));
234 }
235 return;
236 }
237
238 TypeSwitch<Position *>(pos)
239 .Case<AttributePosition, OperationPosition, TypePosition>(caseFn: [&](auto *pos) {
240 getTreePredicates(predList, val, builder, inputs, pos);
241 })
242 .Case<OperandPosition, OperandGroupPosition>(caseFn: [&](auto *pos) {
243 getOperandTreePredicates(predList, val, builder, inputs, pos);
244 })
245 .Default(defaultFn: [](auto *) { llvm_unreachable("unexpected position kind"); });
246}
247
248static void getAttributePredicates(pdl::AttributeOp op,
249 std::vector<PositionalPredicate> &predList,
250 PredicateBuilder &builder,
251 DenseMap<Value, Position *> &inputs) {
252 Position *&attrPos = inputs[op];
253 if (attrPos)
254 return;
255 Attribute value = op.getValueAttr();
256 assert(value && "expected non-tree `pdl.attribute` to contain a value");
257 attrPos = builder.getAttributeLiteral(attr: value);
258}
259
260static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
261 std::vector<PositionalPredicate> &predList,
262 PredicateBuilder &builder,
263 DenseMap<Value, Position *> &inputs) {
264 OperandRange arguments = op.getArgs();
265
266 std::vector<Position *> allPositions;
267 allPositions.reserve(n: arguments.size());
268 for (Value arg : arguments)
269 allPositions.push_back(x: inputs.lookup(Val: arg));
270
271 // Push the constraint to the furthest position.
272 Position *pos = *llvm::max_element(Range&: allPositions, C: comparePosDepth);
273 ResultRange results = op.getResults();
274 PredicateBuilder::Predicate pred = builder.getConstraint(
275 name: op.getName(), args: allPositions, resultTypes: SmallVector<Type>(results.getTypes()),
276 isNegated: op.getIsNegated());
277
278 // For each result register a position so it can be used later
279 for (auto [i, result] : llvm::enumerate(First&: results)) {
280 ConstraintQuestion *q = cast<ConstraintQuestion>(Val: pred.first);
281 ConstraintPosition *pos = builder.getConstraintPosition(q, index: i);
282 auto [it, inserted] = inputs.try_emplace(Key: result, Args&: pos);
283 // If this is an input value that has been visited in the tree, add a
284 // constraint to ensure that both instances refer to the same value.
285 if (!inserted) {
286 Position *first = pos;
287 Position *second = it->second;
288 if (comparePosDepth(lhs: second, rhs: first))
289 std::tie(args&: second, args&: first) = std::make_pair(x&: first, y&: second);
290
291 predList.emplace_back(args&: second, args: builder.getEqualTo(pos: first));
292 }
293 }
294 predList.emplace_back(args&: pos, args&: pred);
295}
296
297static void getResultPredicates(pdl::ResultOp op,
298 std::vector<PositionalPredicate> &predList,
299 PredicateBuilder &builder,
300 DenseMap<Value, Position *> &inputs) {
301 Position *&resultPos = inputs[op];
302 if (resultPos)
303 return;
304
305 // Ensure that the result isn't null.
306 auto *parentPos = cast<OperationPosition>(Val: inputs.lookup(Val: op.getParent()));
307 resultPos = builder.getResult(p: parentPos, result: op.getIndex());
308 predList.emplace_back(args&: resultPos, args: builder.getIsNotNull());
309}
310
311static void getResultPredicates(pdl::ResultsOp op,
312 std::vector<PositionalPredicate> &predList,
313 PredicateBuilder &builder,
314 DenseMap<Value, Position *> &inputs) {
315 Position *&resultPos = inputs[op];
316 if (resultPos)
317 return;
318
319 // Ensure that the result isn't null if the result has an index.
320 auto *parentPos = cast<OperationPosition>(Val: inputs.lookup(Val: op.getParent()));
321 bool isVariadic = isa<pdl::RangeType>(Val: op.getType());
322 std::optional<unsigned> index = op.getIndex();
323 resultPos = builder.getResultGroup(p: parentPos, group: index, isVariadic);
324 if (index)
325 predList.emplace_back(args&: resultPos, args: builder.getIsNotNull());
326}
327
328static void getTypePredicates(Value typeValue,
329 function_ref<Attribute()> typeAttrFn,
330 PredicateBuilder &builder,
331 DenseMap<Value, Position *> &inputs) {
332 Position *&typePos = inputs[typeValue];
333 if (typePos)
334 return;
335 Attribute typeAttr = typeAttrFn();
336 assert(typeAttr &&
337 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
338 typePos = builder.getTypeLiteral(attr: typeAttr);
339}
340
341/// Collect all of the predicates that cannot be determined via walking the
342/// tree.
343static void getNonTreePredicates(pdl::PatternOp pattern,
344 std::vector<PositionalPredicate> &predList,
345 PredicateBuilder &builder,
346 DenseMap<Value, Position *> &inputs) {
347 for (Operation &op : pattern.getBodyRegion().getOps()) {
348 TypeSwitch<Operation *>(&op)
349 .Case(caseFn: [&](pdl::AttributeOp attrOp) {
350 getAttributePredicates(op: attrOp, predList, builder, inputs);
351 })
352 .Case<pdl::ApplyNativeConstraintOp>(caseFn: [&](auto constraintOp) {
353 getConstraintPredicates(constraintOp, predList, builder, inputs);
354 })
355 .Case<pdl::ResultOp, pdl::ResultsOp>(caseFn: [&](auto resultOp) {
356 getResultPredicates(resultOp, predList, builder, inputs);
357 })
358 .Case(caseFn: [&](pdl::TypeOp typeOp) {
359 getTypePredicates(
360 typeValue: typeOp, typeAttrFn: [&] { return typeOp.getConstantTypeAttr(); }, builder,
361 inputs);
362 })
363 .Case(caseFn: [&](pdl::TypesOp typeOp) {
364 getTypePredicates(
365 typeValue: typeOp, typeAttrFn: [&] { return typeOp.getConstantTypesAttr(); }, builder,
366 inputs);
367 });
368 }
369}
370
371namespace {
372
373/// An op accepting a value at an optional index.
374struct OpIndex {
375 Value parent;
376 std::optional<unsigned> index;
377};
378
379/// The parent and operand index of each operation for each root, stored
380/// as a nested map [root][operation].
381using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
382
383} // namespace
384
385/// Given a pattern, determines the set of roots present in this pattern.
386/// These are the operations whose results are not consumed by other operations.
387static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
388 // First, collect all the operations that are used as operands
389 // to other operations. These are not roots by default.
390 DenseSet<Value> used;
391 for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
392 for (Value operand : operationOp.getOperandValues())
393 TypeSwitch<Operation *>(operand.getDefiningOp())
394 .Case<pdl::ResultOp, pdl::ResultsOp>(
395 caseFn: [&used](auto resultOp) { used.insert(resultOp.getParent()); });
396 }
397
398 // Remove the specified root from the use set, so that we can
399 // always select it as a root, even if it is used by other operations.
400 if (Value root = pattern.getRewriter().getRoot())
401 used.erase(V: root);
402
403 // Finally, collect all the unused operations.
404 SmallVector<Value> roots;
405 for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
406 if (!used.contains(V: operationOp))
407 roots.push_back(Elt: operationOp);
408
409 return roots;
410}
411
412/// Given a list of candidate roots, builds the cost graph for connecting them.
413/// The graph is formed by traversing the DAG of operations starting from each
414/// root and marking the depth of each connector value (operand). Then we join
415/// the candidate roots based on the common connector values, taking the one
416/// with the minimum depth. Along the way, we compute, for each candidate root,
417/// a mapping from each operation (in the DAG underneath this root) to its
418/// parent operation and the corresponding operand index.
419static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
420 ParentMaps &parentMaps) {
421
422 // The entry of a queue. The entry consists of the following items:
423 // * the value in the DAG underneath the root;
424 // * the parent of the value;
425 // * the operand index of the value in its parent;
426 // * the depth of the visited value.
427 struct Entry {
428 Entry(Value value, Value parent, std::optional<unsigned> index,
429 unsigned depth)
430 : value(value), parent(parent), index(index), depth(depth) {}
431
432 Value value;
433 Value parent;
434 std::optional<unsigned> index;
435 unsigned depth;
436 };
437
438 // A root of a value and its depth (distance from root to the value).
439 struct RootDepth {
440 Value root;
441 unsigned depth = 0;
442 };
443
444 // Map from candidate connector values to their roots and depths. Using a
445 // small vector with 1 entry because most values belong to a single root.
446 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
447
448 // Perform a breadth-first traversal of the op DAG rooted at each root.
449 for (Value root : roots) {
450 // The queue of visited values. A value may be present multiple times in
451 // the queue, for multiple parents. We only accept the first occurrence,
452 // which is guaranteed to have the lowest depth.
453 std::queue<Entry> toVisit;
454 toVisit.emplace(args&: root, args: Value(), args: 0, args: 0);
455
456 // The map from value to its parent for the current root.
457 DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
458
459 while (!toVisit.empty()) {
460 Entry entry = toVisit.front();
461 toVisit.pop();
462 // Skip if already visited.
463 if (!parentMap.insert(KV: {entry.value, {.parent: entry.parent, .index: entry.index}}).second)
464 continue;
465
466 // Mark the root and depth of the value.
467 connectorsRootsDepths[entry.value].push_back(Elt: {.root: root, .depth: entry.depth});
468
469 // Traverse the operands of an operation and result ops.
470 // We intentionally do not traverse attributes and types, because those
471 // are expensive to join on.
472 TypeSwitch<Operation *>(entry.value.getDefiningOp())
473 .Case<pdl::OperationOp>(caseFn: [&](auto operationOp) {
474 OperandRange operands = operationOp.getOperandValues();
475 // Special case when we pass all the operands in one range.
476 // For those, the index is empty.
477 if (operands.size() == 1 &&
478 isa<pdl::RangeType>(Val: operands[0].getType())) {
479 toVisit.emplace(args: operands[0], args&: entry.value, args: std::nullopt,
480 args: entry.depth + 1);
481 return;
482 }
483
484 // Default case: visit all the operands.
485 for (const auto &p :
486 llvm::enumerate(operationOp.getOperandValues()))
487 toVisit.emplace(p.value(), entry.value, p.index(),
488 entry.depth + 1);
489 })
490 .Case<pdl::ResultOp, pdl::ResultsOp>(caseFn: [&](auto resultOp) {
491 toVisit.emplace(resultOp.getParent(), entry.value,
492 resultOp.getIndex(), entry.depth);
493 });
494 }
495 }
496
497 // Now build the cost graph.
498 // This is simply a minimum over all depths for the target root.
499 unsigned nextID = 0;
500 for (const auto &connectorRootsDepths : connectorsRootsDepths) {
501 Value value = connectorRootsDepths.first;
502 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
503 // If there is only one root for this value, this will not trigger
504 // any edges in the cost graph (a perf optimization).
505 if (rootsDepths.size() == 1)
506 continue;
507
508 for (const RootDepth &p : rootsDepths) {
509 for (const RootDepth &q : rootsDepths) {
510 if (&p == &q)
511 continue;
512 // Insert or retrieve the property of edge from p to q.
513 RootOrderingEntry &entry = graph[q.root][p.root];
514 if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
515 if (!entry.connector)
516 entry.cost.second = nextID++;
517 entry.cost.first = q.depth;
518 entry.connector = value;
519 }
520 }
521 }
522 }
523
524 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
525 "the pattern contains a candidate root disconnected from the others");
526}
527
528/// Returns true if the operand at the given index needs to be queried using an
529/// operand group, i.e., if it is variadic itself or follows a variadic operand.
530static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
531 OperandRange operands = op.getOperandValues();
532 assert(index < operands.size() && "operand index out of range");
533 for (unsigned i = 0; i <= index; ++i)
534 if (isa<pdl::RangeType>(Val: operands[i].getType()))
535 return true;
536 return false;
537}
538
539/// Visit a node during upward traversal.
540static void visitUpward(std::vector<PositionalPredicate> &predList,
541 OpIndex opIndex, PredicateBuilder &builder,
542 DenseMap<Value, Position *> &valueToPosition,
543 Position *&pos, unsigned rootID) {
544 Value value = opIndex.parent;
545 TypeSwitch<Operation *>(value.getDefiningOp())
546 .Case<pdl::OperationOp>(caseFn: [&](auto operationOp) {
547 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
548
549 // Get users and iterate over them.
550 Position *usersPos = builder.getUsers(p: pos, /*useRepresentative=*/true);
551 Position *foreachPos = builder.getForEach(p: usersPos, id: rootID);
552 OperationPosition *opPos = builder.getPassthroughOp(p: foreachPos);
553
554 // Compare the operand(s) of the user against the input value(s).
555 Position *operandPos;
556 if (!opIndex.index) {
557 // We are querying all the operands of the operation.
558 operandPos = builder.getAllOperands(p: opPos);
559 } else if (useOperandGroup(operationOp, *opIndex.index)) {
560 // We are querying an operand group.
561 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
562 bool variadic = isa<pdl::RangeType>(Val: type);
563 operandPos = builder.getOperandGroup(p: opPos, group: opIndex.index, isVariadic: variadic);
564 } else {
565 // We are querying an individual operand.
566 operandPos = builder.getOperand(p: opPos, operand: *opIndex.index);
567 }
568 predList.emplace_back(args&: operandPos, args: builder.getEqualTo(pos));
569
570 // Guard against duplicate upward visits. These are not possible,
571 // because if this value was already visited, it would have been
572 // cheaper to start the traversal at this value rather than at the
573 // `connector`, violating the optimality of our spanning tree.
574 bool inserted = valueToPosition.try_emplace(Key: value, Args&: opPos).second;
575 (void)inserted;
576 assert(inserted && "duplicate upward visit");
577
578 // Obtain the tree predicates at the current value.
579 getTreePredicates(predList, val: value, builder, inputs&: valueToPosition, pos: opPos,
580 ignoreOperand: opIndex.index);
581
582 // Update the position
583 pos = opPos;
584 })
585 .Case<pdl::ResultOp>(caseFn: [&](auto resultOp) {
586 // Traverse up an individual result.
587 auto *opPos = dyn_cast<OperationPosition>(Val: pos);
588 assert(opPos && "operations and results must be interleaved");
589 pos = builder.getResult(p: opPos, result: *opIndex.index);
590
591 // Insert the result position in case we have not visited it yet.
592 valueToPosition.try_emplace(Key: value, Args&: pos);
593 })
594 .Case<pdl::ResultsOp>(caseFn: [&](auto resultOp) {
595 // Traverse up a group of results.
596 auto *opPos = dyn_cast<OperationPosition>(Val: pos);
597 assert(opPos && "operations and results must be interleaved");
598 bool isVariadic = isa<pdl::RangeType>(Val: value.getType());
599 if (opIndex.index)
600 pos = builder.getResultGroup(p: opPos, group: opIndex.index, isVariadic);
601 else
602 pos = builder.getAllResults(p: opPos);
603
604 // Insert the result position in case we have not visited it yet.
605 valueToPosition.try_emplace(Key: value, Args&: pos);
606 });
607}
608
609/// Given a pattern operation, build the set of matcher predicates necessary to
610/// match this pattern.
611static Value buildPredicateList(pdl::PatternOp pattern,
612 PredicateBuilder &builder,
613 std::vector<PositionalPredicate> &predList,
614 DenseMap<Value, Position *> &valueToPosition) {
615 SmallVector<Value> roots = detectRoots(pattern);
616
617 // Build the root ordering graph and compute the parent maps.
618 RootOrderingGraph graph;
619 ParentMaps parentMaps;
620 buildCostGraph(roots, graph, parentMaps);
621 LLVM_DEBUG({
622 llvm::dbgs() << "Graph:\n";
623 for (auto &target : graph) {
624 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
625 << "\n";
626 for (auto &source : target.second) {
627 RootOrderingEntry &entry = source.second;
628 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
629 << ":" << entry.cost.second << " via "
630 << entry.connector.getLoc() << "\n";
631 }
632 }
633 });
634
635 // Solve the optimal branching problem for each candidate root, or use the
636 // provided one.
637 Value bestRoot = pattern.getRewriter().getRoot();
638 OptimalBranching::EdgeList bestEdges;
639 if (!bestRoot) {
640 unsigned bestCost = 0;
641 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
642 for (Value root : roots) {
643 OptimalBranching solver(graph, root);
644 unsigned cost = solver.solve();
645 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
646 if (!bestRoot || bestCost > cost) {
647 bestCost = cost;
648 bestRoot = root;
649 bestEdges = solver.preOrderTraversal(nodes: roots);
650 }
651 }
652 } else {
653 OptimalBranching solver(graph, bestRoot);
654 solver.solve();
655 bestEdges = solver.preOrderTraversal(nodes: roots);
656 }
657
658 // Print the best solution.
659 LLVM_DEBUG({
660 llvm::dbgs() << "Best tree:\n";
661 for (const std::pair<Value, Value> &edge : bestEdges) {
662 llvm::dbgs() << " * " << edge.first;
663 if (edge.second)
664 llvm::dbgs() << " <- " << edge.second;
665 llvm::dbgs() << "\n";
666 }
667 });
668
669 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
670 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
671
672 // The best root is the starting point for the traversal. Get the tree
673 // predicates for the DAG rooted at bestRoot.
674 getTreePredicates(predList, val: bestRoot, builder, inputs&: valueToPosition,
675 pos: builder.getRoot());
676
677 // Traverse the selected optimal branching. For all edges in order, traverse
678 // up starting from the connector, until the candidate root is reached, and
679 // call getTreePredicates at every node along the way.
680 for (const auto &it : llvm::enumerate(First&: bestEdges)) {
681 Value target = it.value().first;
682 Value source = it.value().second;
683
684 // Check if we already visited the target root. This happens in two cases:
685 // 1) the initial root (bestRoot);
686 // 2) a root that is dominated by (contained in the subtree rooted at) an
687 // already visited root.
688 if (valueToPosition.count(Val: target))
689 continue;
690
691 // Determine the connector.
692 Value connector = graph[target][source].connector;
693 assert(connector && "invalid edge");
694 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
695 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(Val: target);
696 Position *pos = valueToPosition.lookup(Val: connector);
697 assert(pos && "connector has not been traversed yet");
698
699 // Traverse from the connector upwards towards the target root.
700 for (Value value = connector; value != target;) {
701 OpIndex opIndex = parentMap.lookup(Val: value);
702 assert(opIndex.parent && "missing parent");
703 visitUpward(predList, opIndex, builder, valueToPosition, pos, rootID: it.index());
704 value = opIndex.parent;
705 }
706 }
707
708 getNonTreePredicates(pattern, predList, builder, inputs&: valueToPosition);
709
710 return bestRoot;
711}
712
713//===----------------------------------------------------------------------===//
714// Pattern Predicate Tree Merging
715//===----------------------------------------------------------------------===//
716
717namespace {
718
719/// This class represents a specific predicate applied to a position, and
720/// provides hashing and ordering operators. This class allows for computing a
721/// frequence sum and ordering predicates based on a cost model.
722struct OrderedPredicate {
723 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
724 : position(ip.first), question(ip.second) {}
725 OrderedPredicate(const PositionalPredicate &ip)
726 : position(ip.position), question(ip.question) {}
727
728 /// The position this predicate is applied to.
729 Position *position;
730
731 /// The question that is applied by this predicate onto the position.
732 Qualifier *question;
733
734 /// The first and second order benefit sums.
735 /// The primary sum is the number of occurrences of this predicate among all
736 /// of the patterns.
737 unsigned primary = 0;
738 /// The secondary sum is a squared summation of the primary sum of all of the
739 /// predicates within each pattern that contains this predicate. This allows
740 /// for favoring predicates that are more commonly shared within a pattern, as
741 /// opposed to those shared across patterns.
742 unsigned secondary = 0;
743
744 /// The tie breaking ID, used to preserve a deterministic (insertion) order
745 /// among all the predicates with the same priority, depth, and position /
746 /// predicate dependency.
747 unsigned id = 0;
748
749 /// A map between a pattern operation and the answer to the predicate question
750 /// within that pattern.
751 DenseMap<Operation *, Qualifier *> patternToAnswer;
752
753 /// Returns true if this predicate is ordered before `rhs`, based on the cost
754 /// model.
755 bool operator<(const OrderedPredicate &rhs) const {
756 // Sort by:
757 // * higher first and secondary order sums
758 // * lower depth
759 // * lower position dependency
760 // * lower predicate dependency
761 // * lower tie breaking ID
762 auto *rhsPos = rhs.position;
763 return std::make_tuple(args: primary, args: secondary, args: rhsPos->getOperationDepth(),
764 args: rhsPos->getKind(), args: rhs.question->getKind(), args: rhs.id) >
765 std::make_tuple(args: rhs.primary, args: rhs.secondary,
766 args: position->getOperationDepth(), args: position->getKind(),
767 args: question->getKind(), args: id);
768 }
769};
770
771/// A DenseMapInfo for OrderedPredicate based solely on the position and
772/// question.
773struct OrderedPredicateDenseInfo {
774 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
775
776 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
777 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
778 static bool isEqual(const OrderedPredicate &lhs,
779 const OrderedPredicate &rhs) {
780 return lhs.position == rhs.position && lhs.question == rhs.question;
781 }
782 static unsigned getHashValue(const OrderedPredicate &p) {
783 return llvm::hash_combine(args: p.position, args: p.question);
784 }
785};
786
787/// This class wraps a set of ordered predicates that are used within a specific
788/// pattern operation.
789struct OrderedPredicateList {
790 OrderedPredicateList(pdl::PatternOp pattern, Value root)
791 : pattern(pattern), root(root) {}
792
793 pdl::PatternOp pattern;
794 Value root;
795 DenseSet<OrderedPredicate *> predicates;
796};
797} // namespace
798
799/// Returns true if the given matcher refers to the same predicate as the given
800/// ordered predicate. This means that the position and questions of the two
801/// match.
802static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
803 return node->getPosition() == predicate->position &&
804 node->getQuestion() == predicate->question;
805}
806
807/// Get or insert a child matcher for the given parent switch node, given a
808/// predicate and parent pattern.
809std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
810 OrderedPredicate *predicate,
811 pdl::PatternOp pattern) {
812 assert(isSamePredicate(node, predicate) &&
813 "expected matcher to equal the given predicate");
814
815 auto it = predicate->patternToAnswer.find(Val: pattern);
816 assert(it != predicate->patternToAnswer.end() &&
817 "expected pattern to exist in predicate");
818 return node->getChildren()[it->second];
819}
820
821/// Build the matcher CFG by "pushing" patterns through by sorted predicate
822/// order. A pattern will traverse as far as possible using common predicates
823/// and then either diverge from the CFG or reach the end of a branch and start
824/// creating new nodes.
825static void propagatePattern(std::unique_ptr<MatcherNode> &node,
826 OrderedPredicateList &list,
827 std::vector<OrderedPredicate *>::iterator current,
828 std::vector<OrderedPredicate *>::iterator end) {
829 if (current == end) {
830 // We've hit the end of a pattern, so create a successful result node.
831 node =
832 std::make_unique<SuccessNode>(args&: list.pattern, args&: list.root, args: std::move(node));
833
834 // If the pattern doesn't contain this predicate, ignore it.
835 } else if (!list.predicates.contains(V: *current)) {
836 propagatePattern(node, list, current: std::next(x: current), end);
837
838 // If the current matcher node is invalid, create a new one for this
839 // position and continue propagation.
840 } else if (!node) {
841 // Create a new node at this position and continue
842 node = std::make_unique<SwitchNode>(args&: (*current)->position,
843 args&: (*current)->question);
844 propagatePattern(
845 node&: getOrCreateChild(node: cast<SwitchNode>(Val: &*node), predicate: *current, pattern: list.pattern),
846 list, current: std::next(x: current), end);
847
848 // If the matcher has already been created, and it is for this predicate we
849 // continue propagation to the child.
850 } else if (isSamePredicate(node: node.get(), predicate: *current)) {
851 propagatePattern(
852 node&: getOrCreateChild(node: cast<SwitchNode>(Val: &*node), predicate: *current, pattern: list.pattern),
853 list, current: std::next(x: current), end);
854
855 // If the matcher doesn't match the current predicate, insert a branch as
856 // the common set of matchers has diverged.
857 } else {
858 propagatePattern(node&: node->getFailureNode(), list, current, end);
859 }
860}
861
862/// Fold any switch nodes nested under `node` to boolean nodes when possible.
863/// `node` is updated in-place if it is a switch.
864static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
865 if (!node)
866 return;
867
868 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(Val: &*node)) {
869 SwitchNode::ChildMapT &children = switchNode->getChildren();
870 for (auto &it : children)
871 foldSwitchToBool(node&: it.second);
872
873 // If the node only contains one child, collapse it into a boolean predicate
874 // node.
875 if (children.size() == 1) {
876 auto *childIt = children.begin();
877 node = std::make_unique<BoolNode>(
878 args: node->getPosition(), args: node->getQuestion(), args&: childIt->first,
879 args: std::move(childIt->second), args: std::move(node->getFailureNode()));
880 }
881 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(Val: &*node)) {
882 foldSwitchToBool(node&: boolNode->getSuccessNode());
883 }
884
885 foldSwitchToBool(node&: node->getFailureNode());
886}
887
888/// Insert an exit node at the end of the failure path of the `root`.
889static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
890 while (*root)
891 root = &(*root)->getFailureNode();
892 *root = std::make_unique<ExitNode>();
893}
894
895/// Sorts the range begin/end with the partial order given by cmp.
896template <typename Iterator, typename Compare>
897static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
898 while (begin != end) {
899 // Cannot compute sortBeforeOthers in the predicate of stable_partition
900 // because stable_partition will not keep the [begin, end) range intact
901 // while it runs.
902 llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
903 for (auto i = begin; i != end; ++i) {
904 if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
905 sortBeforeOthers.insert(*i);
906 }
907
908 auto const next = std::stable_partition(begin, end, [&](auto const &a) {
909 return sortBeforeOthers.contains(a);
910 });
911 assert(next != begin && "not a partial ordering");
912 begin = next;
913 }
914}
915
916/// Returns true if 'b' depends on a result of 'a'.
917static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
918 auto *cqa = dyn_cast<ConstraintQuestion>(Val: a->question);
919 if (!cqa)
920 return false;
921
922 auto positionDependsOnA = [&](Position *p) {
923 auto *cp = dyn_cast<ConstraintPosition>(Val: p);
924 return cp && cp->getQuestion() == cqa;
925 };
926
927 if (auto *cqb = dyn_cast<ConstraintQuestion>(Val: b->question)) {
928 // Does any argument of b use a?
929 return llvm::any_of(Range: cqb->getArgs(), P: positionDependsOnA);
930 }
931 if (auto *equalTo = dyn_cast<EqualToQuestion>(Val: b->question)) {
932 return positionDependsOnA(b->position) ||
933 positionDependsOnA(equalTo->getValue());
934 }
935 return positionDependsOnA(b->position);
936}
937
938/// Given a module containing PDL pattern operations, generate a matcher tree
939/// using the patterns within the given module and return the root matcher node.
940std::unique_ptr<MatcherNode>
941MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
942 DenseMap<Value, Position *> &valueToPosition) {
943 // The set of predicates contained within the pattern operations of the
944 // module.
945 struct PatternPredicates {
946 PatternPredicates(pdl::PatternOp pattern, Value root,
947 std::vector<PositionalPredicate> predicates)
948 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
949
950 /// A pattern.
951 pdl::PatternOp pattern;
952
953 /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
954 Value root;
955
956 /// The extracted predicates for this pattern and root.
957 std::vector<PositionalPredicate> predicates;
958 };
959
960 SmallVector<PatternPredicates, 16> patternsAndPredicates;
961 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
962 std::vector<PositionalPredicate> predicateList;
963 Value root =
964 buildPredicateList(pattern, builder, predList&: predicateList, valueToPosition);
965 patternsAndPredicates.emplace_back(Args&: pattern, Args&: root, Args: std::move(predicateList));
966 }
967
968 // Associate a pattern result with each unique predicate.
969 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
970 for (auto &patternAndPredList : patternsAndPredicates) {
971 for (auto &predicate : patternAndPredList.predicates) {
972 auto it = uniqued.insert(V: predicate);
973 it.first->patternToAnswer.try_emplace(Key: patternAndPredList.pattern,
974 Args&: predicate.answer);
975 // Mark the insertion order (0-based indexing).
976 if (it.second)
977 it.first->id = uniqued.size() - 1;
978 }
979 }
980
981 // Associate each pattern to a set of its ordered predicates for later lookup.
982 std::vector<OrderedPredicateList> lists;
983 lists.reserve(n: patternsAndPredicates.size());
984 for (auto &patternAndPredList : patternsAndPredicates) {
985 OrderedPredicateList list(patternAndPredList.pattern,
986 patternAndPredList.root);
987 for (auto &predicate : patternAndPredList.predicates) {
988 OrderedPredicate *orderedPredicate = &*uniqued.find(V: predicate);
989 list.predicates.insert(V: orderedPredicate);
990
991 // Increment the primary sum for each reference to a particular predicate.
992 ++orderedPredicate->primary;
993 }
994 lists.push_back(x: std::move(list));
995 }
996
997 // For a particular pattern, get the total primary sum and add it to the
998 // secondary sum of each predicate. Square the primary sums to emphasize
999 // shared predicates within rather than across patterns.
1000 for (auto &list : lists) {
1001 unsigned total = 0;
1002 for (auto *predicate : list.predicates)
1003 total += predicate->primary * predicate->primary;
1004 for (auto *predicate : list.predicates)
1005 predicate->secondary += total;
1006 }
1007
1008 // Sort the set of predicates now that the cost primary and secondary sums
1009 // have been computed.
1010 std::vector<OrderedPredicate *> ordered;
1011 ordered.reserve(n: uniqued.size());
1012 for (auto &ip : uniqued)
1013 ordered.push_back(x: &ip);
1014 llvm::sort(C&: ordered, Comp: [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1015 return *lhs < *rhs;
1016 });
1017
1018 // Mostly keep the now established order, but also ensure that
1019 // ConstraintQuestions come after the results they use.
1020 stableTopologicalSort(begin: ordered.begin(), end: ordered.end(), cmp: dependsOn);
1021
1022 // Build the matchers for each of the pattern predicate lists.
1023 std::unique_ptr<MatcherNode> root;
1024 for (OrderedPredicateList &list : lists)
1025 propagatePattern(node&: root, list, current: ordered.begin(), end: ordered.end());
1026
1027 // Collapse the graph and insert the exit node.
1028 foldSwitchToBool(node&: root);
1029 insertExitNode(root: &root);
1030 return root;
1031}
1032
1033//===----------------------------------------------------------------------===//
1034// MatcherNode
1035//===----------------------------------------------------------------------===//
1036
1037MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
1038 std::unique_ptr<MatcherNode> failureNode)
1039 : position(p), question(q), failureNode(std::move(failureNode)),
1040 matcherTypeID(matcherTypeID) {}
1041
1042//===----------------------------------------------------------------------===//
1043// BoolNode
1044//===----------------------------------------------------------------------===//
1045
1046BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1047 std::unique_ptr<MatcherNode> successNode,
1048 std::unique_ptr<MatcherNode> failureNode)
1049 : MatcherNode(TypeID::get<BoolNode>(), position, question,
1050 std::move(failureNode)),
1051 answer(answer), successNode(std::move(successNode)) {}
1052
1053//===----------------------------------------------------------------------===//
1054// SuccessNode
1055//===----------------------------------------------------------------------===//
1056
1057SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1058 std::unique_ptr<MatcherNode> failureNode)
1059 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1060 /*question=*/nullptr, std::move(failureNode)),
1061 pattern(pattern), root(root) {}
1062
1063//===----------------------------------------------------------------------===//
1064// SwitchNode
1065//===----------------------------------------------------------------------===//
1066
1067SwitchNode::SwitchNode(Position *position, Qualifier *question)
1068 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1069

// Source code of mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp