1 | //===- PredicateTree.cpp - Predicate tree merging -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "PredicateTree.h" |
10 | #include "RootOrdering.h" |
11 | |
12 | #include "mlir/Dialect/PDL/IR/PDL.h" |
13 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
14 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
15 | #include "mlir/IR/BuiltinOps.h" |
16 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
17 | #include "llvm/ADT/MapVector.h" |
18 | #include "llvm/ADT/SmallPtrSet.h" |
19 | #include "llvm/ADT/TypeSwitch.h" |
20 | #include "llvm/Support/Debug.h" |
21 | #include <queue> |
22 | |
23 | #define DEBUG_TYPE "pdl-predicate-tree" |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::pdl_to_pdl_interp; |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // Predicate List Building |
30 | //===----------------------------------------------------------------------===// |
31 | |
/// Forward declaration of the generic, position-dispatching overload. It is
/// defined later in this file; declaring it here lets the position-specific
/// overloads below recurse through it.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos);
36 | |
37 | /// Compares the depths of two positions. |
38 | static bool comparePosDepth(Position *lhs, Position *rhs) { |
39 | return lhs->getOperationDepth() < rhs->getOperationDepth(); |
40 | } |
41 | |
42 | /// Returns the number of non-range elements within `values`. |
43 | static unsigned getNumNonRangeValues(ValueRange values) { |
44 | return llvm::count_if(Range: values.getTypes(), |
45 | P: [](Type type) { return !isa<pdl::RangeType>(type); }); |
46 | } |
47 | |
48 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
49 | Value val, PredicateBuilder &builder, |
50 | DenseMap<Value, Position *> &inputs, |
51 | AttributePosition *pos) { |
52 | assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type" ); |
53 | predList.emplace_back(args&: pos, args: builder.getIsNotNull()); |
54 | |
55 | if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { |
56 | // If the attribute has a type or value, add a constraint. |
57 | if (Value type = attr.getValueType()) |
58 | getTreePredicates(predList, val: type, builder, inputs, pos: builder.getType(pos)); |
59 | else if (Attribute value = attr.getValueAttr()) |
60 | predList.emplace_back(args&: pos, args: builder.getAttributeConstraint(attr: value)); |
61 | } |
62 | } |
63 | |
64 | /// Collect all of the predicates for the given operand position. |
65 | static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList, |
66 | Value val, PredicateBuilder &builder, |
67 | DenseMap<Value, Position *> &inputs, |
68 | Position *pos) { |
69 | Type valueType = val.getType(); |
70 | bool isVariadic = isa<pdl::RangeType>(valueType); |
71 | |
72 | // If this is a typed operand, add a type constraint. |
73 | TypeSwitch<Operation *>(val.getDefiningOp()) |
74 | .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) { |
75 | // Prevent traversal into a null value if the operand has a proper |
76 | // index. |
77 | if (std::is_same<pdl::OperandOp, decltype(op)>::value || |
78 | cast<OperandGroupPosition>(pos)->getOperandGroupNumber()) |
79 | predList.emplace_back(pos, builder.getIsNotNull()); |
80 | |
81 | if (Value type = op.getValueType()) |
82 | getTreePredicates(predList, type, builder, inputs, |
83 | builder.getType(pos)); |
84 | }) |
85 | .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) { |
86 | std::optional<unsigned> index = op.getIndex(); |
87 | |
88 | // Prevent traversal into a null value if the result has a proper index. |
89 | if (index) |
90 | predList.emplace_back(pos, builder.getIsNotNull()); |
91 | |
92 | // Get the parent operation of this operand. |
93 | OperationPosition *parentPos = builder.getOperandDefiningOp(pos); |
94 | predList.emplace_back(parentPos, builder.getIsNotNull()); |
95 | |
96 | // Ensure that the operands match the corresponding results of the |
97 | // parent operation. |
98 | Position *resultPos = nullptr; |
99 | if (std::is_same<pdl::ResultOp, decltype(op)>::value) |
100 | resultPos = builder.getResult(parentPos, *index); |
101 | else |
102 | resultPos = builder.getResultGroup(parentPos, index, isVariadic); |
103 | predList.emplace_back(resultPos, builder.getEqualTo(pos)); |
104 | |
105 | // Collect the predicates of the parent operation. |
106 | getTreePredicates(predList, op.getParent(), builder, inputs, |
107 | (Position *)parentPos); |
108 | }); |
109 | } |
110 | |
111 | static void |
112 | getTreePredicates(std::vector<PositionalPredicate> &predList, Value val, |
113 | PredicateBuilder &builder, |
114 | DenseMap<Value, Position *> &inputs, OperationPosition *pos, |
115 | std::optional<unsigned> ignoreOperand = std::nullopt) { |
116 | assert(isa<pdl::OperationType>(val.getType()) && "expected operation" ); |
117 | pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); |
118 | OperationPosition *opPos = cast<OperationPosition>(Val: pos); |
119 | |
120 | // Ensure getDefiningOp returns a non-null operation. |
121 | if (!opPos->isRoot()) |
122 | predList.emplace_back(args&: pos, args: builder.getIsNotNull()); |
123 | |
124 | // Check that this is the correct root operation. |
125 | if (std::optional<StringRef> opName = op.getOpName()) |
126 | predList.emplace_back(args&: pos, args: builder.getOperationName(name: *opName)); |
127 | |
128 | // Check that the operation has the proper number of operands. If there are |
129 | // any variable length operands, we check a minimum instead of an exact count. |
130 | OperandRange operands = op.getOperandValues(); |
131 | unsigned minOperands = getNumNonRangeValues(values: operands); |
132 | if (minOperands != operands.size()) { |
133 | if (minOperands) |
134 | predList.emplace_back(args&: pos, args: builder.getOperandCountAtLeast(count: minOperands)); |
135 | } else { |
136 | predList.emplace_back(args&: pos, args: builder.getOperandCount(count: minOperands)); |
137 | } |
138 | |
139 | // Check that the operation has the proper number of results. If there are |
140 | // any variable length results, we check a minimum instead of an exact count. |
141 | OperandRange types = op.getTypeValues(); |
142 | unsigned minResults = getNumNonRangeValues(values: types); |
143 | if (minResults == types.size()) |
144 | predList.emplace_back(args&: pos, args: builder.getResultCount(count: types.size())); |
145 | else if (minResults) |
146 | predList.emplace_back(args&: pos, args: builder.getResultCountAtLeast(count: minResults)); |
147 | |
148 | // Recurse into any attributes, operands, or results. |
149 | for (auto [attrName, attr] : |
150 | llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) { |
151 | getTreePredicates( |
152 | predList, attr, builder, inputs, |
153 | builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue())); |
154 | } |
155 | |
156 | // Process the operands and results of the operation. For all values up to |
157 | // the first variable length value, we use the concrete operand/result |
158 | // number. After that, we use the "group" given that we can't know the |
159 | // concrete indices until runtime. If there is only one variadic operand |
160 | // group, we treat it as all of the operands/results of the operation. |
161 | /// Operands. |
162 | if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) { |
163 | // Ignore the operands if we are performing an upward traversal (in that |
164 | // case, they have already been visited). |
165 | if (opPos->isRoot() || opPos->isOperandDefiningOp()) |
166 | getTreePredicates(predList, val: operands.front(), builder, inputs, |
167 | pos: builder.getAllOperands(p: opPos)); |
168 | } else { |
169 | bool foundVariableLength = false; |
170 | for (const auto &operandIt : llvm::enumerate(operands)) { |
171 | bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType()); |
172 | foundVariableLength |= isVariadic; |
173 | |
174 | // Ignore the specified operand, usually because this position was |
175 | // visited in an upward traversal via an iterative choice. |
176 | if (ignoreOperand && *ignoreOperand == operandIt.index()) |
177 | continue; |
178 | |
179 | Position *pos = |
180 | foundVariableLength |
181 | ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) |
182 | : builder.getOperand(opPos, operandIt.index()); |
183 | getTreePredicates(predList, operandIt.value(), builder, inputs, pos); |
184 | } |
185 | } |
186 | /// Results. |
187 | if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) { |
188 | getTreePredicates(predList, val: types.front(), builder, inputs, |
189 | pos: builder.getType(p: builder.getAllResults(p: opPos))); |
190 | return; |
191 | } |
192 | |
193 | bool foundVariableLength = false; |
194 | for (auto [idx, typeValue] : llvm::enumerate(types)) { |
195 | bool isVariadic = isa<pdl::RangeType>(typeValue.getType()); |
196 | foundVariableLength |= isVariadic; |
197 | |
198 | auto *resultPos = foundVariableLength |
199 | ? builder.getResultGroup(pos, idx, isVariadic) |
200 | : builder.getResult(pos, idx); |
201 | predList.emplace_back(resultPos, builder.getIsNotNull()); |
202 | getTreePredicates(predList, typeValue, builder, inputs, |
203 | builder.getType(resultPos)); |
204 | } |
205 | } |
206 | |
207 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
208 | Value val, PredicateBuilder &builder, |
209 | DenseMap<Value, Position *> &inputs, |
210 | TypePosition *pos) { |
211 | // Check for a constraint on a constant type. |
212 | if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) { |
213 | if (Attribute type = typeOp.getConstantTypeAttr()) |
214 | predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type)); |
215 | } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) { |
216 | if (Attribute typeAttr = typeOp.getConstantTypesAttr()) |
217 | predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type: typeAttr)); |
218 | } |
219 | } |
220 | |
221 | /// Collect the tree predicates anchored at the given value. |
222 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
223 | Value val, PredicateBuilder &builder, |
224 | DenseMap<Value, Position *> &inputs, |
225 | Position *pos) { |
226 | // Make sure this input value is accessible to the rewrite. |
227 | auto it = inputs.try_emplace(Key: val, Args&: pos); |
228 | if (!it.second) { |
229 | // If this is an input value that has been visited in the tree, add a |
230 | // constraint to ensure that both instances refer to the same value. |
231 | if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp, |
232 | pdl::TypeOp>(val.getDefiningOp())) { |
233 | auto minMaxPositions = |
234 | std::minmax(a: pos, b: it.first->second, comp: comparePosDepth); |
235 | predList.emplace_back(args: minMaxPositions.second, |
236 | args: builder.getEqualTo(pos: minMaxPositions.first)); |
237 | } |
238 | return; |
239 | } |
240 | |
241 | TypeSwitch<Position *>(pos) |
242 | .Case<AttributePosition, OperationPosition, TypePosition>(caseFn: [&](auto *pos) { |
243 | getTreePredicates(predList, val, builder, inputs, pos); |
244 | }) |
245 | .Case<OperandPosition, OperandGroupPosition>(caseFn: [&](auto *pos) { |
246 | getOperandTreePredicates(predList, val, builder, inputs, pos); |
247 | }) |
248 | .Default(defaultFn: [](auto *) { llvm_unreachable("unexpected position kind" ); }); |
249 | } |
250 | |
251 | static void getAttributePredicates(pdl::AttributeOp op, |
252 | std::vector<PositionalPredicate> &predList, |
253 | PredicateBuilder &builder, |
254 | DenseMap<Value, Position *> &inputs) { |
255 | Position *&attrPos = inputs[op]; |
256 | if (attrPos) |
257 | return; |
258 | Attribute value = op.getValueAttr(); |
259 | assert(value && "expected non-tree `pdl.attribute` to contain a value" ); |
260 | attrPos = builder.getAttributeLiteral(attr: value); |
261 | } |
262 | |
263 | static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, |
264 | std::vector<PositionalPredicate> &predList, |
265 | PredicateBuilder &builder, |
266 | DenseMap<Value, Position *> &inputs) { |
267 | OperandRange arguments = op.getArgs(); |
268 | |
269 | std::vector<Position *> allPositions; |
270 | allPositions.reserve(n: arguments.size()); |
271 | for (Value arg : arguments) |
272 | allPositions.push_back(inputs.lookup(arg)); |
273 | |
274 | // Push the constraint to the furthest position. |
275 | Position *pos = *llvm::max_element(Range&: allPositions, C: comparePosDepth); |
276 | ResultRange results = op.getResults(); |
277 | PredicateBuilder::Predicate pred = builder.getConstraint( |
278 | name: op.getName(), args: allPositions, resultTypes: SmallVector<Type>(results.getTypes()), |
279 | isNegated: op.getIsNegated()); |
280 | |
281 | // For each result register a position so it can be used later |
282 | for (auto [i, result] : llvm::enumerate(results)) { |
283 | ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first); |
284 | ConstraintPosition *pos = builder.getConstraintPosition(q, i); |
285 | auto [it, inserted] = inputs.try_emplace(result, pos); |
286 | // If this is an input value that has been visited in the tree, add a |
287 | // constraint to ensure that both instances refer to the same value. |
288 | if (!inserted) { |
289 | Position *first = pos; |
290 | Position *second = it->second; |
291 | if (comparePosDepth(second, first)) |
292 | std::tie(second, first) = std::make_pair(first, second); |
293 | |
294 | predList.emplace_back(second, builder.getEqualTo(first)); |
295 | } |
296 | } |
297 | predList.emplace_back(args&: pos, args&: pred); |
298 | } |
299 | |
300 | static void getResultPredicates(pdl::ResultOp op, |
301 | std::vector<PositionalPredicate> &predList, |
302 | PredicateBuilder &builder, |
303 | DenseMap<Value, Position *> &inputs) { |
304 | Position *&resultPos = inputs[op]; |
305 | if (resultPos) |
306 | return; |
307 | |
308 | // Ensure that the result isn't null. |
309 | auto *parentPos = cast<OperationPosition>(inputs.lookup(Val: op.getParent())); |
310 | resultPos = builder.getResult(p: parentPos, result: op.getIndex()); |
311 | predList.emplace_back(args&: resultPos, args: builder.getIsNotNull()); |
312 | } |
313 | |
314 | static void getResultPredicates(pdl::ResultsOp op, |
315 | std::vector<PositionalPredicate> &predList, |
316 | PredicateBuilder &builder, |
317 | DenseMap<Value, Position *> &inputs) { |
318 | Position *&resultPos = inputs[op]; |
319 | if (resultPos) |
320 | return; |
321 | |
322 | // Ensure that the result isn't null if the result has an index. |
323 | auto *parentPos = cast<OperationPosition>(inputs.lookup(Val: op.getParent())); |
324 | bool isVariadic = isa<pdl::RangeType>(op.getType()); |
325 | std::optional<unsigned> index = op.getIndex(); |
326 | resultPos = builder.getResultGroup(p: parentPos, group: index, isVariadic); |
327 | if (index) |
328 | predList.emplace_back(args&: resultPos, args: builder.getIsNotNull()); |
329 | } |
330 | |
331 | static void getTypePredicates(Value typeValue, |
332 | function_ref<Attribute()> typeAttrFn, |
333 | PredicateBuilder &builder, |
334 | DenseMap<Value, Position *> &inputs) { |
335 | Position *&typePos = inputs[typeValue]; |
336 | if (typePos) |
337 | return; |
338 | Attribute typeAttr = typeAttrFn(); |
339 | assert(typeAttr && |
340 | "expected non-tree `pdl.type`/`pdl.types` to contain a value" ); |
341 | typePos = builder.getTypeLiteral(attr: typeAttr); |
342 | } |
343 | |
344 | /// Collect all of the predicates that cannot be determined via walking the |
345 | /// tree. |
346 | static void getNonTreePredicates(pdl::PatternOp pattern, |
347 | std::vector<PositionalPredicate> &predList, |
348 | PredicateBuilder &builder, |
349 | DenseMap<Value, Position *> &inputs) { |
350 | for (Operation &op : pattern.getBodyRegion().getOps()) { |
351 | TypeSwitch<Operation *>(&op) |
352 | .Case([&](pdl::AttributeOp attrOp) { |
353 | getAttributePredicates(attrOp, predList, builder, inputs); |
354 | }) |
355 | .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) { |
356 | getConstraintPredicates(constraintOp, predList, builder, inputs); |
357 | }) |
358 | .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { |
359 | getResultPredicates(resultOp, predList, builder, inputs); |
360 | }) |
361 | .Case([&](pdl::TypeOp typeOp) { |
362 | getTypePredicates( |
363 | typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder, |
364 | inputs); |
365 | }) |
366 | .Case([&](pdl::TypesOp typeOp) { |
367 | getTypePredicates( |
368 | typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder, |
369 | inputs); |
370 | }); |
371 | } |
372 | } |
373 | |
namespace {

/// An op accepting a value at an optional index.
/// When used as a parent-map entry, `parent` is the value one step closer to
/// the root and `index` is the operand/result index linking the two (empty
/// when a single range stands for all operands/results).
struct OpIndex {
  Value parent;
  std::optional<unsigned> index;
};

/// The parent and operand index of each operation for each root, stored
/// as a nested map [root][operation].
using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;

} // namespace
387 | |
388 | /// Given a pattern, determines the set of roots present in this pattern. |
389 | /// These are the operations whose results are not consumed by other operations. |
390 | static SmallVector<Value> detectRoots(pdl::PatternOp pattern) { |
391 | // First, collect all the operations that are used as operands |
392 | // to other operations. These are not roots by default. |
393 | DenseSet<Value> used; |
394 | for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) { |
395 | for (Value operand : operationOp.getOperandValues()) |
396 | TypeSwitch<Operation *>(operand.getDefiningOp()) |
397 | .Case<pdl::ResultOp, pdl::ResultsOp>( |
398 | [&used](auto resultOp) { used.insert(resultOp.getParent()); }); |
399 | } |
400 | |
401 | // Remove the specified root from the use set, so that we can |
402 | // always select it as a root, even if it is used by other operations. |
403 | if (Value root = pattern.getRewriter().getRoot()) |
404 | used.erase(V: root); |
405 | |
406 | // Finally, collect all the unused operations. |
407 | SmallVector<Value> roots; |
408 | for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) |
409 | if (!used.contains(operationOp)) |
410 | roots.push_back(operationOp); |
411 | |
412 | return roots; |
413 | } |
414 | |
415 | /// Given a list of candidate roots, builds the cost graph for connecting them. |
416 | /// The graph is formed by traversing the DAG of operations starting from each |
417 | /// root and marking the depth of each connector value (operand). Then we join |
418 | /// the candidate roots based on the common connector values, taking the one |
419 | /// with the minimum depth. Along the way, we compute, for each candidate root, |
420 | /// a mapping from each operation (in the DAG underneath this root) to its |
421 | /// parent operation and the corresponding operand index. |
422 | static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph, |
423 | ParentMaps &parentMaps) { |
424 | |
425 | // The entry of a queue. The entry consists of the following items: |
426 | // * the value in the DAG underneath the root; |
427 | // * the parent of the value; |
428 | // * the operand index of the value in its parent; |
429 | // * the depth of the visited value. |
430 | struct Entry { |
431 | Entry(Value value, Value parent, std::optional<unsigned> index, |
432 | unsigned depth) |
433 | : value(value), parent(parent), index(index), depth(depth) {} |
434 | |
435 | Value value; |
436 | Value parent; |
437 | std::optional<unsigned> index; |
438 | unsigned depth; |
439 | }; |
440 | |
441 | // A root of a value and its depth (distance from root to the value). |
442 | struct RootDepth { |
443 | Value root; |
444 | unsigned depth = 0; |
445 | }; |
446 | |
447 | // Map from candidate connector values to their roots and depths. Using a |
448 | // small vector with 1 entry because most values belong to a single root. |
449 | llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths; |
450 | |
451 | // Perform a breadth-first traversal of the op DAG rooted at each root. |
452 | for (Value root : roots) { |
453 | // The queue of visited values. A value may be present multiple times in |
454 | // the queue, for multiple parents. We only accept the first occurrence, |
455 | // which is guaranteed to have the lowest depth. |
456 | std::queue<Entry> toVisit; |
457 | toVisit.emplace(args&: root, args: Value(), args: 0, args: 0); |
458 | |
459 | // The map from value to its parent for the current root. |
460 | DenseMap<Value, OpIndex> &parentMap = parentMaps[root]; |
461 | |
462 | while (!toVisit.empty()) { |
463 | Entry entry = toVisit.front(); |
464 | toVisit.pop(); |
465 | // Skip if already visited. |
466 | if (!parentMap.insert(KV: {entry.value, {.parent: entry.parent, .index: entry.index}}).second) |
467 | continue; |
468 | |
469 | // Mark the root and depth of the value. |
470 | connectorsRootsDepths[entry.value].push_back(Elt: {.root: root, .depth: entry.depth}); |
471 | |
472 | // Traverse the operands of an operation and result ops. |
473 | // We intentionally do not traverse attributes and types, because those |
474 | // are expensive to join on. |
475 | TypeSwitch<Operation *>(entry.value.getDefiningOp()) |
476 | .Case<pdl::OperationOp>([&](auto operationOp) { |
477 | OperandRange operands = operationOp.getOperandValues(); |
478 | // Special case when we pass all the operands in one range. |
479 | // For those, the index is empty. |
480 | if (operands.size() == 1 && |
481 | isa<pdl::RangeType>(operands[0].getType())) { |
482 | toVisit.emplace(operands[0], entry.value, std::nullopt, |
483 | entry.depth + 1); |
484 | return; |
485 | } |
486 | |
487 | // Default case: visit all the operands. |
488 | for (const auto &p : |
489 | llvm::enumerate(operationOp.getOperandValues())) |
490 | toVisit.emplace(p.value(), entry.value, p.index(), |
491 | entry.depth + 1); |
492 | }) |
493 | .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) { |
494 | toVisit.emplace(resultOp.getParent(), entry.value, |
495 | resultOp.getIndex(), entry.depth); |
496 | }); |
497 | } |
498 | } |
499 | |
500 | // Now build the cost graph. |
501 | // This is simply a minimum over all depths for the target root. |
502 | unsigned nextID = 0; |
503 | for (const auto &connectorRootsDepths : connectorsRootsDepths) { |
504 | Value value = connectorRootsDepths.first; |
505 | ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second; |
506 | // If there is only one root for this value, this will not trigger |
507 | // any edges in the cost graph (a perf optimization). |
508 | if (rootsDepths.size() == 1) |
509 | continue; |
510 | |
511 | for (const RootDepth &p : rootsDepths) { |
512 | for (const RootDepth &q : rootsDepths) { |
513 | if (&p == &q) |
514 | continue; |
515 | // Insert or retrieve the property of edge from p to q. |
516 | RootOrderingEntry &entry = graph[q.root][p.root]; |
517 | if (!entry.connector /* new edge */ || entry.cost.first > q.depth) { |
518 | if (!entry.connector) |
519 | entry.cost.second = nextID++; |
520 | entry.cost.first = q.depth; |
521 | entry.connector = value; |
522 | } |
523 | } |
524 | } |
525 | } |
526 | |
527 | assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && |
528 | "the pattern contains a candidate root disconnected from the others" ); |
529 | } |
530 | |
531 | /// Returns true if the operand at the given index needs to be queried using an |
532 | /// operand group, i.e., if it is variadic itself or follows a variadic operand. |
533 | static bool useOperandGroup(pdl::OperationOp op, unsigned index) { |
534 | OperandRange operands = op.getOperandValues(); |
535 | assert(index < operands.size() && "operand index out of range" ); |
536 | for (unsigned i = 0; i <= index; ++i) |
537 | if (isa<pdl::RangeType>(operands[i].getType())) |
538 | return true; |
539 | return false; |
540 | } |
541 | |
542 | /// Visit a node during upward traversal. |
543 | static void visitUpward(std::vector<PositionalPredicate> &predList, |
544 | OpIndex opIndex, PredicateBuilder &builder, |
545 | DenseMap<Value, Position *> &valueToPosition, |
546 | Position *&pos, unsigned rootID) { |
547 | Value value = opIndex.parent; |
548 | TypeSwitch<Operation *>(value.getDefiningOp()) |
549 | .Case<pdl::OperationOp>([&](auto operationOp) { |
550 | LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" ); |
551 | |
552 | // Get users and iterate over them. |
553 | Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true); |
554 | Position *foreachPos = builder.getForEach(usersPos, rootID); |
555 | OperationPosition *opPos = builder.getPassthroughOp(foreachPos); |
556 | |
557 | // Compare the operand(s) of the user against the input value(s). |
558 | Position *operandPos; |
559 | if (!opIndex.index) { |
560 | // We are querying all the operands of the operation. |
561 | operandPos = builder.getAllOperands(opPos); |
562 | } else if (useOperandGroup(operationOp, *opIndex.index)) { |
563 | // We are querying an operand group. |
564 | Type type = operationOp.getOperandValues()[*opIndex.index].getType(); |
565 | bool variadic = isa<pdl::RangeType>(type); |
566 | operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic); |
567 | } else { |
568 | // We are querying an individual operand. |
569 | operandPos = builder.getOperand(opPos, *opIndex.index); |
570 | } |
571 | predList.emplace_back(operandPos, builder.getEqualTo(pos)); |
572 | |
573 | // Guard against duplicate upward visits. These are not possible, |
574 | // because if this value was already visited, it would have been |
575 | // cheaper to start the traversal at this value rather than at the |
576 | // `connector`, violating the optimality of our spanning tree. |
577 | bool inserted = valueToPosition.try_emplace(value, opPos).second; |
578 | (void)inserted; |
579 | assert(inserted && "duplicate upward visit" ); |
580 | |
581 | // Obtain the tree predicates at the current value. |
582 | getTreePredicates(predList, value, builder, valueToPosition, opPos, |
583 | opIndex.index); |
584 | |
585 | // Update the position |
586 | pos = opPos; |
587 | }) |
588 | .Case<pdl::ResultOp>([&](auto resultOp) { |
589 | // Traverse up an individual result. |
590 | auto *opPos = dyn_cast<OperationPosition>(pos); |
591 | assert(opPos && "operations and results must be interleaved" ); |
592 | pos = builder.getResult(opPos, *opIndex.index); |
593 | |
594 | // Insert the result position in case we have not visited it yet. |
595 | valueToPosition.try_emplace(value, pos); |
596 | }) |
597 | .Case<pdl::ResultsOp>([&](auto resultOp) { |
598 | // Traverse up a group of results. |
599 | auto *opPos = dyn_cast<OperationPosition>(pos); |
600 | assert(opPos && "operations and results must be interleaved" ); |
601 | bool isVariadic = isa<pdl::RangeType>(value.getType()); |
602 | if (opIndex.index) |
603 | pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); |
604 | else |
605 | pos = builder.getAllResults(opPos); |
606 | |
607 | // Insert the result position in case we have not visited it yet. |
608 | valueToPosition.try_emplace(value, pos); |
609 | }); |
610 | } |
611 | |
612 | /// Given a pattern operation, build the set of matcher predicates necessary to |
613 | /// match this pattern. |
614 | static Value buildPredicateList(pdl::PatternOp pattern, |
615 | PredicateBuilder &builder, |
616 | std::vector<PositionalPredicate> &predList, |
617 | DenseMap<Value, Position *> &valueToPosition) { |
618 | SmallVector<Value> roots = detectRoots(pattern); |
619 | |
620 | // Build the root ordering graph and compute the parent maps. |
621 | RootOrderingGraph graph; |
622 | ParentMaps parentMaps; |
623 | buildCostGraph(roots, graph, parentMaps); |
624 | LLVM_DEBUG({ |
625 | llvm::dbgs() << "Graph:\n" ; |
626 | for (auto &target : graph) { |
627 | llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first |
628 | << "\n" ; |
629 | for (auto &source : target.second) { |
630 | RootOrderingEntry &entry = source.second; |
631 | llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first |
632 | << ":" << entry.cost.second << " via " |
633 | << entry.connector.getLoc() << "\n" ; |
634 | } |
635 | } |
636 | }); |
637 | |
638 | // Solve the optimal branching problem for each candidate root, or use the |
639 | // provided one. |
640 | Value bestRoot = pattern.getRewriter().getRoot(); |
641 | OptimalBranching::EdgeList bestEdges; |
642 | if (!bestRoot) { |
643 | unsigned bestCost = 0; |
644 | LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n" ); |
645 | for (Value root : roots) { |
646 | OptimalBranching solver(graph, root); |
647 | unsigned cost = solver.solve(); |
648 | LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n" ); |
649 | if (!bestRoot || bestCost > cost) { |
650 | bestCost = cost; |
651 | bestRoot = root; |
652 | bestEdges = solver.preOrderTraversal(nodes: roots); |
653 | } |
654 | } |
655 | } else { |
656 | OptimalBranching solver(graph, bestRoot); |
657 | solver.solve(); |
658 | bestEdges = solver.preOrderTraversal(nodes: roots); |
659 | } |
660 | |
661 | // Print the best solution. |
662 | LLVM_DEBUG({ |
663 | llvm::dbgs() << "Best tree:\n" ; |
664 | for (const std::pair<Value, Value> &edge : bestEdges) { |
665 | llvm::dbgs() << " * " << edge.first; |
666 | if (edge.second) |
667 | llvm::dbgs() << " <- " << edge.second; |
668 | llvm::dbgs() << "\n" ; |
669 | } |
670 | }); |
671 | |
672 | LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n" ); |
673 | LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n" ); |
674 | |
675 | // The best root is the starting point for the traversal. Get the tree |
676 | // predicates for the DAG rooted at bestRoot. |
677 | getTreePredicates(predList, val: bestRoot, builder, inputs&: valueToPosition, |
678 | pos: builder.getRoot()); |
679 | |
680 | // Traverse the selected optimal branching. For all edges in order, traverse |
681 | // up starting from the connector, until the candidate root is reached, and |
682 | // call getTreePredicates at every node along the way. |
683 | for (const auto &it : llvm::enumerate(First&: bestEdges)) { |
684 | Value target = it.value().first; |
685 | Value source = it.value().second; |
686 | |
687 | // Check if we already visited the target root. This happens in two cases: |
688 | // 1) the initial root (bestRoot); |
689 | // 2) a root that is dominated by (contained in the subtree rooted at) an |
690 | // already visited root. |
691 | if (valueToPosition.count(Val: target)) |
692 | continue; |
693 | |
694 | // Determine the connector. |
695 | Value connector = graph[target][source].connector; |
696 | assert(connector && "invalid edge" ); |
697 | LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n" ); |
698 | DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(Val: target); |
699 | Position *pos = valueToPosition.lookup(Val: connector); |
700 | assert(pos && "connector has not been traversed yet" ); |
701 | |
702 | // Traverse from the connector upwards towards the target root. |
703 | for (Value value = connector; value != target;) { |
704 | OpIndex opIndex = parentMap.lookup(Val: value); |
705 | assert(opIndex.parent && "missing parent" ); |
706 | visitUpward(predList, opIndex, builder, valueToPosition, pos, rootID: it.index()); |
707 | value = opIndex.parent; |
708 | } |
709 | } |
710 | |
711 | getNonTreePredicates(pattern, predList, builder, valueToPosition); |
712 | |
713 | return bestRoot; |
714 | } |
715 | |
716 | //===----------------------------------------------------------------------===// |
717 | // Pattern Predicate Tree Merging |
718 | //===----------------------------------------------------------------------===// |
719 | |
720 | namespace { |
721 | |
722 | /// This class represents a specific predicate applied to a position, and |
723 | /// provides hashing and ordering operators. This class allows for computing a |
724 | /// frequence sum and ordering predicates based on a cost model. |
725 | struct OrderedPredicate { |
726 | OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) |
727 | : position(ip.first), question(ip.second) {} |
728 | OrderedPredicate(const PositionalPredicate &ip) |
729 | : position(ip.position), question(ip.question) {} |
730 | |
731 | /// The position this predicate is applied to. |
732 | Position *position; |
733 | |
734 | /// The question that is applied by this predicate onto the position. |
735 | Qualifier *question; |
736 | |
737 | /// The first and second order benefit sums. |
738 | /// The primary sum is the number of occurrences of this predicate among all |
739 | /// of the patterns. |
740 | unsigned primary = 0; |
741 | /// The secondary sum is a squared summation of the primary sum of all of the |
742 | /// predicates within each pattern that contains this predicate. This allows |
743 | /// for favoring predicates that are more commonly shared within a pattern, as |
744 | /// opposed to those shared across patterns. |
745 | unsigned secondary = 0; |
746 | |
747 | /// The tie breaking ID, used to preserve a deterministic (insertion) order |
748 | /// among all the predicates with the same priority, depth, and position / |
749 | /// predicate dependency. |
750 | unsigned id = 0; |
751 | |
752 | /// A map between a pattern operation and the answer to the predicate question |
753 | /// within that pattern. |
754 | DenseMap<Operation *, Qualifier *> patternToAnswer; |
755 | |
756 | /// Returns true if this predicate is ordered before `rhs`, based on the cost |
757 | /// model. |
758 | bool operator<(const OrderedPredicate &rhs) const { |
759 | // Sort by: |
760 | // * higher first and secondary order sums |
761 | // * lower depth |
762 | // * lower position dependency |
763 | // * lower predicate dependency |
764 | // * lower tie breaking ID |
765 | auto *rhsPos = rhs.position; |
766 | return std::make_tuple(args: primary, args: secondary, args: rhsPos->getOperationDepth(), |
767 | args: rhsPos->getKind(), args: rhs.question->getKind(), args: rhs.id) > |
768 | std::make_tuple(args: rhs.primary, args: rhs.secondary, |
769 | args: position->getOperationDepth(), args: position->getKind(), |
770 | args: question->getKind(), args: id); |
771 | } |
772 | }; |
773 | |
774 | /// A DenseMapInfo for OrderedPredicate based solely on the position and |
775 | /// question. |
776 | struct OrderedPredicateDenseInfo { |
777 | using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; |
778 | |
779 | static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } |
780 | static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } |
781 | static bool isEqual(const OrderedPredicate &lhs, |
782 | const OrderedPredicate &rhs) { |
783 | return lhs.position == rhs.position && lhs.question == rhs.question; |
784 | } |
785 | static unsigned getHashValue(const OrderedPredicate &p) { |
786 | return llvm::hash_combine(args: p.position, args: p.question); |
787 | } |
788 | }; |
789 | |
790 | /// This class wraps a set of ordered predicates that are used within a specific |
791 | /// pattern operation. |
792 | struct OrderedPredicateList { |
793 | OrderedPredicateList(pdl::PatternOp pattern, Value root) |
794 | : pattern(pattern), root(root) {} |
795 | |
796 | pdl::PatternOp pattern; |
797 | Value root; |
798 | DenseSet<OrderedPredicate *> predicates; |
799 | }; |
800 | } // namespace |
801 | |
802 | /// Returns true if the given matcher refers to the same predicate as the given |
803 | /// ordered predicate. This means that the position and questions of the two |
804 | /// match. |
805 | static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { |
806 | return node->getPosition() == predicate->position && |
807 | node->getQuestion() == predicate->question; |
808 | } |
809 | |
810 | /// Get or insert a child matcher for the given parent switch node, given a |
811 | /// predicate and parent pattern. |
812 | std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, |
813 | OrderedPredicate *predicate, |
814 | pdl::PatternOp pattern) { |
815 | assert(isSamePredicate(node, predicate) && |
816 | "expected matcher to equal the given predicate" ); |
817 | |
818 | auto it = predicate->patternToAnswer.find(pattern); |
819 | assert(it != predicate->patternToAnswer.end() && |
820 | "expected pattern to exist in predicate" ); |
821 | return node->getChildren().insert({it->second, nullptr}).first->second; |
822 | } |
823 | |
824 | /// Build the matcher CFG by "pushing" patterns through by sorted predicate |
825 | /// order. A pattern will traverse as far as possible using common predicates |
826 | /// and then either diverge from the CFG or reach the end of a branch and start |
827 | /// creating new nodes. |
828 | static void propagatePattern(std::unique_ptr<MatcherNode> &node, |
829 | OrderedPredicateList &list, |
830 | std::vector<OrderedPredicate *>::iterator current, |
831 | std::vector<OrderedPredicate *>::iterator end) { |
832 | if (current == end) { |
833 | // We've hit the end of a pattern, so create a successful result node. |
834 | node = |
835 | std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node)); |
836 | |
837 | // If the pattern doesn't contain this predicate, ignore it. |
838 | } else if (!list.predicates.contains(V: *current)) { |
839 | propagatePattern(node, list, current: std::next(x: current), end); |
840 | |
841 | // If the current matcher node is invalid, create a new one for this |
842 | // position and continue propagation. |
843 | } else if (!node) { |
844 | // Create a new node at this position and continue |
845 | node = std::make_unique<SwitchNode>(args&: (*current)->position, |
846 | args&: (*current)->question); |
847 | propagatePattern( |
848 | getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), |
849 | list, std::next(current), end); |
850 | |
851 | // If the matcher has already been created, and it is for this predicate we |
852 | // continue propagation to the child. |
853 | } else if (isSamePredicate(node: node.get(), predicate: *current)) { |
854 | propagatePattern( |
855 | getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), |
856 | list, std::next(current), end); |
857 | |
858 | // If the matcher doesn't match the current predicate, insert a branch as |
859 | // the common set of matchers has diverged. |
860 | } else { |
861 | propagatePattern(node&: node->getFailureNode(), list, current, end); |
862 | } |
863 | } |
864 | |
865 | /// Fold any switch nodes nested under `node` to boolean nodes when possible. |
866 | /// `node` is updated in-place if it is a switch. |
867 | static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { |
868 | if (!node) |
869 | return; |
870 | |
871 | if (SwitchNode *switchNode = dyn_cast<SwitchNode>(Val: &*node)) { |
872 | SwitchNode::ChildMapT &children = switchNode->getChildren(); |
873 | for (auto &it : children) |
874 | foldSwitchToBool(node&: it.second); |
875 | |
876 | // If the node only contains one child, collapse it into a boolean predicate |
877 | // node. |
878 | if (children.size() == 1) { |
879 | auto *childIt = children.begin(); |
880 | node = std::make_unique<BoolNode>( |
881 | args: node->getPosition(), args: node->getQuestion(), args&: childIt->first, |
882 | args: std::move(childIt->second), args: std::move(node->getFailureNode())); |
883 | } |
884 | } else if (BoolNode *boolNode = dyn_cast<BoolNode>(Val: &*node)) { |
885 | foldSwitchToBool(node&: boolNode->getSuccessNode()); |
886 | } |
887 | |
888 | foldSwitchToBool(node&: node->getFailureNode()); |
889 | } |
890 | |
891 | /// Insert an exit node at the end of the failure path of the `root`. |
892 | static void insertExitNode(std::unique_ptr<MatcherNode> *root) { |
893 | while (*root) |
894 | root = &(*root)->getFailureNode(); |
895 | *root = std::make_unique<ExitNode>(); |
896 | } |
897 | |
898 | /// Sorts the range begin/end with the partial order given by cmp. |
899 | template <typename Iterator, typename Compare> |
900 | static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { |
901 | while (begin != end) { |
902 | // Cannot compute sortBeforeOthers in the predicate of stable_partition |
903 | // because stable_partition will not keep the [begin, end) range intact |
904 | // while it runs. |
905 | llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers; |
906 | for (auto i = begin; i != end; ++i) { |
907 | if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); })) |
908 | sortBeforeOthers.insert(*i); |
909 | } |
910 | |
911 | auto const next = std::stable_partition(begin, end, [&](auto const &a) { |
912 | return sortBeforeOthers.contains(a); |
913 | }); |
914 | assert(next != begin && "not a partial ordering" ); |
915 | begin = next; |
916 | } |
917 | } |
918 | |
919 | /// Returns true if 'b' depends on a result of 'a'. |
920 | static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) { |
921 | auto *cqa = dyn_cast<ConstraintQuestion>(Val: a->question); |
922 | if (!cqa) |
923 | return false; |
924 | |
925 | auto positionDependsOnA = [&](Position *p) { |
926 | auto *cp = dyn_cast<ConstraintPosition>(Val: p); |
927 | return cp && cp->getQuestion() == cqa; |
928 | }; |
929 | |
930 | if (auto *cqb = dyn_cast<ConstraintQuestion>(Val: b->question)) { |
931 | // Does any argument of b use a? |
932 | return llvm::any_of(Range: cqb->getArgs(), P: positionDependsOnA); |
933 | } |
934 | if (auto *equalTo = dyn_cast<EqualToQuestion>(Val: b->question)) { |
935 | return positionDependsOnA(b->position) || |
936 | positionDependsOnA(equalTo->getValue()); |
937 | } |
938 | return positionDependsOnA(b->position); |
939 | } |
940 | |
941 | /// Given a module containing PDL pattern operations, generate a matcher tree |
942 | /// using the patterns within the given module and return the root matcher node. |
943 | std::unique_ptr<MatcherNode> |
944 | MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, |
945 | DenseMap<Value, Position *> &valueToPosition) { |
946 | // The set of predicates contained within the pattern operations of the |
947 | // module. |
948 | struct PatternPredicates { |
949 | PatternPredicates(pdl::PatternOp pattern, Value root, |
950 | std::vector<PositionalPredicate> predicates) |
951 | : pattern(pattern), root(root), predicates(std::move(predicates)) {} |
952 | |
953 | /// A pattern. |
954 | pdl::PatternOp pattern; |
955 | |
956 | /// A root of the pattern chosen among the candidate roots in pdl.rewrite. |
957 | Value root; |
958 | |
959 | /// The extracted predicates for this pattern and root. |
960 | std::vector<PositionalPredicate> predicates; |
961 | }; |
962 | |
963 | SmallVector<PatternPredicates, 16> patternsAndPredicates; |
964 | for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { |
965 | std::vector<PositionalPredicate> predicateList; |
966 | Value root = |
967 | buildPredicateList(pattern, builder, predicateList, valueToPosition); |
968 | patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); |
969 | } |
970 | |
971 | // Associate a pattern result with each unique predicate. |
972 | DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; |
973 | for (auto &patternAndPredList : patternsAndPredicates) { |
974 | for (auto &predicate : patternAndPredList.predicates) { |
975 | auto it = uniqued.insert(V: predicate); |
976 | it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, |
977 | predicate.answer); |
978 | // Mark the insertion order (0-based indexing). |
979 | if (it.second) |
980 | it.first->id = uniqued.size() - 1; |
981 | } |
982 | } |
983 | |
984 | // Associate each pattern to a set of its ordered predicates for later lookup. |
985 | std::vector<OrderedPredicateList> lists; |
986 | lists.reserve(n: patternsAndPredicates.size()); |
987 | for (auto &patternAndPredList : patternsAndPredicates) { |
988 | OrderedPredicateList list(patternAndPredList.pattern, |
989 | patternAndPredList.root); |
990 | for (auto &predicate : patternAndPredList.predicates) { |
991 | OrderedPredicate *orderedPredicate = &*uniqued.find(V: predicate); |
992 | list.predicates.insert(V: orderedPredicate); |
993 | |
994 | // Increment the primary sum for each reference to a particular predicate. |
995 | ++orderedPredicate->primary; |
996 | } |
997 | lists.push_back(x: std::move(list)); |
998 | } |
999 | |
1000 | // For a particular pattern, get the total primary sum and add it to the |
1001 | // secondary sum of each predicate. Square the primary sums to emphasize |
1002 | // shared predicates within rather than across patterns. |
1003 | for (auto &list : lists) { |
1004 | unsigned total = 0; |
1005 | for (auto *predicate : list.predicates) |
1006 | total += predicate->primary * predicate->primary; |
1007 | for (auto *predicate : list.predicates) |
1008 | predicate->secondary += total; |
1009 | } |
1010 | |
1011 | // Sort the set of predicates now that the cost primary and secondary sums |
1012 | // have been computed. |
1013 | std::vector<OrderedPredicate *> ordered; |
1014 | ordered.reserve(n: uniqued.size()); |
1015 | for (auto &ip : uniqued) |
1016 | ordered.push_back(x: &ip); |
1017 | llvm::sort(C&: ordered, Comp: [](OrderedPredicate *lhs, OrderedPredicate *rhs) { |
1018 | return *lhs < *rhs; |
1019 | }); |
1020 | |
1021 | // Mostly keep the now established order, but also ensure that |
1022 | // ConstraintQuestions come after the results they use. |
1023 | stableTopologicalSort(begin: ordered.begin(), end: ordered.end(), cmp: dependsOn); |
1024 | |
1025 | // Build the matchers for each of the pattern predicate lists. |
1026 | std::unique_ptr<MatcherNode> root; |
1027 | for (OrderedPredicateList &list : lists) |
1028 | propagatePattern(node&: root, list, current: ordered.begin(), end: ordered.end()); |
1029 | |
1030 | // Collapse the graph and insert the exit node. |
1031 | foldSwitchToBool(node&: root); |
1032 | insertExitNode(root: &root); |
1033 | return root; |
1034 | } |
1035 | |
1036 | //===----------------------------------------------------------------------===// |
1037 | // MatcherNode |
1038 | //===----------------------------------------------------------------------===// |
1039 | |
1040 | MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, |
1041 | std::unique_ptr<MatcherNode> failureNode) |
1042 | : position(p), question(q), failureNode(std::move(failureNode)), |
1043 | matcherTypeID(matcherTypeID) {} |
1044 | |
1045 | //===----------------------------------------------------------------------===// |
1046 | // BoolNode |
1047 | //===----------------------------------------------------------------------===// |
1048 | |
1049 | BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, |
1050 | std::unique_ptr<MatcherNode> successNode, |
1051 | std::unique_ptr<MatcherNode> failureNode) |
1052 | : MatcherNode(TypeID::get<BoolNode>(), position, question, |
1053 | std::move(failureNode)), |
1054 | answer(answer), successNode(std::move(successNode)) {} |
1055 | |
1056 | //===----------------------------------------------------------------------===// |
1057 | // SuccessNode |
1058 | //===----------------------------------------------------------------------===// |
1059 | |
1060 | SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, |
1061 | std::unique_ptr<MatcherNode> failureNode) |
1062 | : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, |
1063 | /*question=*/nullptr, std::move(failureNode)), |
1064 | pattern(pattern), root(root) {} |
1065 | |
1066 | //===----------------------------------------------------------------------===// |
1067 | // SwitchNode |
1068 | //===----------------------------------------------------------------------===// |
1069 | |
1070 | SwitchNode::SwitchNode(Position *position, Qualifier *question) |
1071 | : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} |
1072 | |