| 1 | //===- PredicateTree.cpp - Predicate tree merging -------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "PredicateTree.h" |
| 10 | #include "RootOrdering.h" |
| 11 | |
| 12 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
| 13 | #include "mlir/IR/BuiltinOps.h" |
| 14 | #include "llvm/ADT/MapVector.h" |
| 15 | #include "llvm/ADT/SmallPtrSet.h" |
| 16 | #include "llvm/ADT/TypeSwitch.h" |
| 17 | #include "llvm/Support/Debug.h" |
| 18 | #include <queue> |
| 19 | |
| 20 | #define DEBUG_TYPE "pdl-predicate-tree" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | using namespace mlir::pdl_to_pdl_interp; |
| 24 | |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | // Predicate List Building |
| 27 | //===----------------------------------------------------------------------===// |
| 28 | |
| 29 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
| 30 | Value val, PredicateBuilder &builder, |
| 31 | DenseMap<Value, Position *> &inputs, |
| 32 | Position *pos); |
| 33 | |
| 34 | /// Compares the depths of two positions. |
| 35 | static bool comparePosDepth(Position *lhs, Position *rhs) { |
| 36 | return lhs->getOperationDepth() < rhs->getOperationDepth(); |
| 37 | } |
| 38 | |
| 39 | /// Returns the number of non-range elements within `values`. |
| 40 | static unsigned getNumNonRangeValues(ValueRange values) { |
| 41 | return llvm::count_if(Range: values.getTypes(), |
| 42 | P: [](Type type) { return !isa<pdl::RangeType>(Val: type); }); |
| 43 | } |
| 44 | |
| 45 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
| 46 | Value val, PredicateBuilder &builder, |
| 47 | DenseMap<Value, Position *> &inputs, |
| 48 | AttributePosition *pos) { |
| 49 | assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type" ); |
| 50 | predList.emplace_back(args&: pos, args: builder.getIsNotNull()); |
| 51 | |
| 52 | if (auto attr = dyn_cast<pdl::AttributeOp>(Val: val.getDefiningOp())) { |
| 53 | // If the attribute has a type or value, add a constraint. |
| 54 | if (Value type = attr.getValueType()) |
| 55 | getTreePredicates(predList, val: type, builder, inputs, pos: builder.getType(p: pos)); |
| 56 | else if (Attribute value = attr.getValueAttr()) |
| 57 | predList.emplace_back(args&: pos, args: builder.getAttributeConstraint(attr: value)); |
| 58 | } |
| 59 | } |
| 60 | |
| 61 | /// Collect all of the predicates for the given operand position. |
| 62 | static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList, |
| 63 | Value val, PredicateBuilder &builder, |
| 64 | DenseMap<Value, Position *> &inputs, |
| 65 | Position *pos) { |
| 66 | Type valueType = val.getType(); |
| 67 | bool isVariadic = isa<pdl::RangeType>(Val: valueType); |
| 68 | |
| 69 | // If this is a typed operand, add a type constraint. |
| 70 | TypeSwitch<Operation *>(val.getDefiningOp()) |
| 71 | .Case<pdl::OperandOp, pdl::OperandsOp>(caseFn: [&](auto op) { |
| 72 | // Prevent traversal into a null value if the operand has a proper |
| 73 | // index. |
| 74 | if (std::is_same<pdl::OperandOp, decltype(op)>::value || |
| 75 | cast<OperandGroupPosition>(Val: pos)->getOperandGroupNumber()) |
| 76 | predList.emplace_back(args&: pos, args: builder.getIsNotNull()); |
| 77 | |
| 78 | if (Value type = op.getValueType()) |
| 79 | getTreePredicates(predList, val: type, builder, inputs, |
| 80 | pos: builder.getType(p: pos)); |
| 81 | }) |
| 82 | .Case<pdl::ResultOp, pdl::ResultsOp>(caseFn: [&](auto op) { |
| 83 | std::optional<unsigned> index = op.getIndex(); |
| 84 | |
| 85 | // Prevent traversal into a null value if the result has a proper index. |
| 86 | if (index) |
| 87 | predList.emplace_back(args&: pos, args: builder.getIsNotNull()); |
| 88 | |
| 89 | // Get the parent operation of this operand. |
| 90 | OperationPosition *parentPos = builder.getOperandDefiningOp(p: pos); |
| 91 | predList.emplace_back(args&: parentPos, args: builder.getIsNotNull()); |
| 92 | |
| 93 | // Ensure that the operands match the corresponding results of the |
| 94 | // parent operation. |
| 95 | Position *resultPos = nullptr; |
| 96 | if (std::is_same<pdl::ResultOp, decltype(op)>::value) |
| 97 | resultPos = builder.getResult(p: parentPos, result: *index); |
| 98 | else |
| 99 | resultPos = builder.getResultGroup(p: parentPos, group: index, isVariadic); |
| 100 | predList.emplace_back(args&: resultPos, args: builder.getEqualTo(pos)); |
| 101 | |
| 102 | // Collect the predicates of the parent operation. |
| 103 | getTreePredicates(predList, op.getParent(), builder, inputs, |
| 104 | (Position *)parentPos); |
| 105 | }); |
| 106 | } |
| 107 | |
| 108 | static void |
| 109 | getTreePredicates(std::vector<PositionalPredicate> &predList, Value val, |
| 110 | PredicateBuilder &builder, |
| 111 | DenseMap<Value, Position *> &inputs, OperationPosition *pos, |
| 112 | std::optional<unsigned> ignoreOperand = std::nullopt) { |
| 113 | assert(isa<pdl::OperationType>(val.getType()) && "expected operation" ); |
| 114 | pdl::OperationOp op = cast<pdl::OperationOp>(Val: val.getDefiningOp()); |
| 115 | OperationPosition *opPos = cast<OperationPosition>(Val: pos); |
| 116 | |
| 117 | // Ensure getDefiningOp returns a non-null operation. |
| 118 | if (!opPos->isRoot()) |
| 119 | predList.emplace_back(args&: pos, args: builder.getIsNotNull()); |
| 120 | |
| 121 | // Check that this is the correct root operation. |
| 122 | if (std::optional<StringRef> opName = op.getOpName()) |
| 123 | predList.emplace_back(args&: pos, args: builder.getOperationName(name: *opName)); |
| 124 | |
| 125 | // Check that the operation has the proper number of operands. If there are |
| 126 | // any variable length operands, we check a minimum instead of an exact count. |
| 127 | OperandRange operands = op.getOperandValues(); |
| 128 | unsigned minOperands = getNumNonRangeValues(values: operands); |
| 129 | if (minOperands != operands.size()) { |
| 130 | if (minOperands) |
| 131 | predList.emplace_back(args&: pos, args: builder.getOperandCountAtLeast(count: minOperands)); |
| 132 | } else { |
| 133 | predList.emplace_back(args&: pos, args: builder.getOperandCount(count: minOperands)); |
| 134 | } |
| 135 | |
| 136 | // Check that the operation has the proper number of results. If there are |
| 137 | // any variable length results, we check a minimum instead of an exact count. |
| 138 | OperandRange types = op.getTypeValues(); |
| 139 | unsigned minResults = getNumNonRangeValues(values: types); |
| 140 | if (minResults == types.size()) |
| 141 | predList.emplace_back(args&: pos, args: builder.getResultCount(count: types.size())); |
| 142 | else if (minResults) |
| 143 | predList.emplace_back(args&: pos, args: builder.getResultCountAtLeast(count: minResults)); |
| 144 | |
| 145 | // Recurse into any attributes, operands, or results. |
| 146 | for (auto [attrName, attr] : |
| 147 | llvm::zip(t: op.getAttributeValueNames(), u: op.getAttributeValues())) { |
| 148 | getTreePredicates( |
| 149 | predList, val: attr, builder, inputs, |
| 150 | pos: builder.getAttribute(p: opPos, name: cast<StringAttr>(Val: attrName).getValue())); |
| 151 | } |
| 152 | |
| 153 | // Process the operands and results of the operation. For all values up to |
| 154 | // the first variable length value, we use the concrete operand/result |
| 155 | // number. After that, we use the "group" given that we can't know the |
| 156 | // concrete indices until runtime. If there is only one variadic operand |
| 157 | // group, we treat it as all of the operands/results of the operation. |
| 158 | /// Operands. |
| 159 | if (operands.size() == 1 && isa<pdl::RangeType>(Val: operands[0].getType())) { |
| 160 | // Ignore the operands if we are performing an upward traversal (in that |
| 161 | // case, they have already been visited). |
| 162 | if (opPos->isRoot() || opPos->isOperandDefiningOp()) |
| 163 | getTreePredicates(predList, val: operands.front(), builder, inputs, |
| 164 | pos: builder.getAllOperands(p: opPos)); |
| 165 | } else { |
| 166 | bool foundVariableLength = false; |
| 167 | for (const auto &operandIt : llvm::enumerate(First&: operands)) { |
| 168 | bool isVariadic = isa<pdl::RangeType>(Val: operandIt.value().getType()); |
| 169 | foundVariableLength |= isVariadic; |
| 170 | |
| 171 | // Ignore the specified operand, usually because this position was |
| 172 | // visited in an upward traversal via an iterative choice. |
| 173 | if (ignoreOperand == operandIt.index()) |
| 174 | continue; |
| 175 | |
| 176 | Position *pos = |
| 177 | foundVariableLength |
| 178 | ? builder.getOperandGroup(p: opPos, group: operandIt.index(), isVariadic) |
| 179 | : builder.getOperand(p: opPos, operand: operandIt.index()); |
| 180 | getTreePredicates(predList, val: operandIt.value(), builder, inputs, pos); |
| 181 | } |
| 182 | } |
| 183 | /// Results. |
| 184 | if (types.size() == 1 && isa<pdl::RangeType>(Val: types[0].getType())) { |
| 185 | getTreePredicates(predList, val: types.front(), builder, inputs, |
| 186 | pos: builder.getType(p: builder.getAllResults(p: opPos))); |
| 187 | return; |
| 188 | } |
| 189 | |
| 190 | bool foundVariableLength = false; |
| 191 | for (auto [idx, typeValue] : llvm::enumerate(First&: types)) { |
| 192 | bool isVariadic = isa<pdl::RangeType>(Val: typeValue.getType()); |
| 193 | foundVariableLength |= isVariadic; |
| 194 | |
| 195 | auto *resultPos = foundVariableLength |
| 196 | ? builder.getResultGroup(p: pos, group: idx, isVariadic) |
| 197 | : builder.getResult(p: pos, result: idx); |
| 198 | predList.emplace_back(args&: resultPos, args: builder.getIsNotNull()); |
| 199 | getTreePredicates(predList, val: typeValue, builder, inputs, |
| 200 | pos: builder.getType(p: resultPos)); |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
| 205 | Value val, PredicateBuilder &builder, |
| 206 | DenseMap<Value, Position *> &inputs, |
| 207 | TypePosition *pos) { |
| 208 | // Check for a constraint on a constant type. |
| 209 | if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) { |
| 210 | if (Attribute type = typeOp.getConstantTypeAttr()) |
| 211 | predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type)); |
| 212 | } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) { |
| 213 | if (Attribute typeAttr = typeOp.getConstantTypesAttr()) |
| 214 | predList.emplace_back(args&: pos, args: builder.getTypeConstraint(type: typeAttr)); |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | /// Collect the tree predicates anchored at the given value. |
| 219 | static void getTreePredicates(std::vector<PositionalPredicate> &predList, |
| 220 | Value val, PredicateBuilder &builder, |
| 221 | DenseMap<Value, Position *> &inputs, |
| 222 | Position *pos) { |
| 223 | // Make sure this input value is accessible to the rewrite. |
| 224 | auto it = inputs.try_emplace(Key: val, Args&: pos); |
| 225 | if (!it.second) { |
| 226 | // If this is an input value that has been visited in the tree, add a |
| 227 | // constraint to ensure that both instances refer to the same value. |
| 228 | if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp, |
| 229 | pdl::TypeOp>(Val: val.getDefiningOp())) { |
| 230 | auto minMaxPositions = |
| 231 | std::minmax(a: pos, b: it.first->second, comp: comparePosDepth); |
| 232 | predList.emplace_back(args: minMaxPositions.second, |
| 233 | args: builder.getEqualTo(pos: minMaxPositions.first)); |
| 234 | } |
| 235 | return; |
| 236 | } |
| 237 | |
| 238 | TypeSwitch<Position *>(pos) |
| 239 | .Case<AttributePosition, OperationPosition, TypePosition>(caseFn: [&](auto *pos) { |
| 240 | getTreePredicates(predList, val, builder, inputs, pos); |
| 241 | }) |
| 242 | .Case<OperandPosition, OperandGroupPosition>(caseFn: [&](auto *pos) { |
| 243 | getOperandTreePredicates(predList, val, builder, inputs, pos); |
| 244 | }) |
| 245 | .Default(defaultFn: [](auto *) { llvm_unreachable("unexpected position kind" ); }); |
| 246 | } |
| 247 | |
| 248 | static void getAttributePredicates(pdl::AttributeOp op, |
| 249 | std::vector<PositionalPredicate> &predList, |
| 250 | PredicateBuilder &builder, |
| 251 | DenseMap<Value, Position *> &inputs) { |
| 252 | Position *&attrPos = inputs[op]; |
| 253 | if (attrPos) |
| 254 | return; |
| 255 | Attribute value = op.getValueAttr(); |
| 256 | assert(value && "expected non-tree `pdl.attribute` to contain a value" ); |
| 257 | attrPos = builder.getAttributeLiteral(attr: value); |
| 258 | } |
| 259 | |
| 260 | static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, |
| 261 | std::vector<PositionalPredicate> &predList, |
| 262 | PredicateBuilder &builder, |
| 263 | DenseMap<Value, Position *> &inputs) { |
| 264 | OperandRange arguments = op.getArgs(); |
| 265 | |
| 266 | std::vector<Position *> allPositions; |
| 267 | allPositions.reserve(n: arguments.size()); |
| 268 | for (Value arg : arguments) |
| 269 | allPositions.push_back(x: inputs.lookup(Val: arg)); |
| 270 | |
| 271 | // Push the constraint to the furthest position. |
| 272 | Position *pos = *llvm::max_element(Range&: allPositions, C: comparePosDepth); |
| 273 | ResultRange results = op.getResults(); |
| 274 | PredicateBuilder::Predicate pred = builder.getConstraint( |
| 275 | name: op.getName(), args: allPositions, resultTypes: SmallVector<Type>(results.getTypes()), |
| 276 | isNegated: op.getIsNegated()); |
| 277 | |
| 278 | // For each result register a position so it can be used later |
| 279 | for (auto [i, result] : llvm::enumerate(First&: results)) { |
| 280 | ConstraintQuestion *q = cast<ConstraintQuestion>(Val: pred.first); |
| 281 | ConstraintPosition *pos = builder.getConstraintPosition(q, index: i); |
| 282 | auto [it, inserted] = inputs.try_emplace(Key: result, Args&: pos); |
| 283 | // If this is an input value that has been visited in the tree, add a |
| 284 | // constraint to ensure that both instances refer to the same value. |
| 285 | if (!inserted) { |
| 286 | Position *first = pos; |
| 287 | Position *second = it->second; |
| 288 | if (comparePosDepth(lhs: second, rhs: first)) |
| 289 | std::tie(args&: second, args&: first) = std::make_pair(x&: first, y&: second); |
| 290 | |
| 291 | predList.emplace_back(args&: second, args: builder.getEqualTo(pos: first)); |
| 292 | } |
| 293 | } |
| 294 | predList.emplace_back(args&: pos, args&: pred); |
| 295 | } |
| 296 | |
| 297 | static void getResultPredicates(pdl::ResultOp op, |
| 298 | std::vector<PositionalPredicate> &predList, |
| 299 | PredicateBuilder &builder, |
| 300 | DenseMap<Value, Position *> &inputs) { |
| 301 | Position *&resultPos = inputs[op]; |
| 302 | if (resultPos) |
| 303 | return; |
| 304 | |
| 305 | // Ensure that the result isn't null. |
| 306 | auto *parentPos = cast<OperationPosition>(Val: inputs.lookup(Val: op.getParent())); |
| 307 | resultPos = builder.getResult(p: parentPos, result: op.getIndex()); |
| 308 | predList.emplace_back(args&: resultPos, args: builder.getIsNotNull()); |
| 309 | } |
| 310 | |
| 311 | static void getResultPredicates(pdl::ResultsOp op, |
| 312 | std::vector<PositionalPredicate> &predList, |
| 313 | PredicateBuilder &builder, |
| 314 | DenseMap<Value, Position *> &inputs) { |
| 315 | Position *&resultPos = inputs[op]; |
| 316 | if (resultPos) |
| 317 | return; |
| 318 | |
| 319 | // Ensure that the result isn't null if the result has an index. |
| 320 | auto *parentPos = cast<OperationPosition>(Val: inputs.lookup(Val: op.getParent())); |
| 321 | bool isVariadic = isa<pdl::RangeType>(Val: op.getType()); |
| 322 | std::optional<unsigned> index = op.getIndex(); |
| 323 | resultPos = builder.getResultGroup(p: parentPos, group: index, isVariadic); |
| 324 | if (index) |
| 325 | predList.emplace_back(args&: resultPos, args: builder.getIsNotNull()); |
| 326 | } |
| 327 | |
| 328 | static void getTypePredicates(Value typeValue, |
| 329 | function_ref<Attribute()> typeAttrFn, |
| 330 | PredicateBuilder &builder, |
| 331 | DenseMap<Value, Position *> &inputs) { |
| 332 | Position *&typePos = inputs[typeValue]; |
| 333 | if (typePos) |
| 334 | return; |
| 335 | Attribute typeAttr = typeAttrFn(); |
| 336 | assert(typeAttr && |
| 337 | "expected non-tree `pdl.type`/`pdl.types` to contain a value" ); |
| 338 | typePos = builder.getTypeLiteral(attr: typeAttr); |
| 339 | } |
| 340 | |
| 341 | /// Collect all of the predicates that cannot be determined via walking the |
| 342 | /// tree. |
| 343 | static void getNonTreePredicates(pdl::PatternOp pattern, |
| 344 | std::vector<PositionalPredicate> &predList, |
| 345 | PredicateBuilder &builder, |
| 346 | DenseMap<Value, Position *> &inputs) { |
| 347 | for (Operation &op : pattern.getBodyRegion().getOps()) { |
| 348 | TypeSwitch<Operation *>(&op) |
| 349 | .Case(caseFn: [&](pdl::AttributeOp attrOp) { |
| 350 | getAttributePredicates(op: attrOp, predList, builder, inputs); |
| 351 | }) |
| 352 | .Case<pdl::ApplyNativeConstraintOp>(caseFn: [&](auto constraintOp) { |
| 353 | getConstraintPredicates(constraintOp, predList, builder, inputs); |
| 354 | }) |
| 355 | .Case<pdl::ResultOp, pdl::ResultsOp>(caseFn: [&](auto resultOp) { |
| 356 | getResultPredicates(resultOp, predList, builder, inputs); |
| 357 | }) |
| 358 | .Case(caseFn: [&](pdl::TypeOp typeOp) { |
| 359 | getTypePredicates( |
| 360 | typeValue: typeOp, typeAttrFn: [&] { return typeOp.getConstantTypeAttr(); }, builder, |
| 361 | inputs); |
| 362 | }) |
| 363 | .Case(caseFn: [&](pdl::TypesOp typeOp) { |
| 364 | getTypePredicates( |
| 365 | typeValue: typeOp, typeAttrFn: [&] { return typeOp.getConstantTypesAttr(); }, builder, |
| 366 | inputs); |
| 367 | }); |
| 368 | } |
| 369 | } |
| 370 | |
| 371 | namespace { |
| 372 | |
| 373 | /// An op accepting a value at an optional index. |
| 374 | struct OpIndex { |
| 375 | Value parent; |
| 376 | std::optional<unsigned> index; |
| 377 | }; |
| 378 | |
| 379 | /// The parent and operand index of each operation for each root, stored |
| 380 | /// as a nested map [root][operation]. |
| 381 | using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>; |
| 382 | |
| 383 | } // namespace |
| 384 | |
| 385 | /// Given a pattern, determines the set of roots present in this pattern. |
| 386 | /// These are the operations whose results are not consumed by other operations. |
| 387 | static SmallVector<Value> detectRoots(pdl::PatternOp pattern) { |
| 388 | // First, collect all the operations that are used as operands |
| 389 | // to other operations. These are not roots by default. |
| 390 | DenseSet<Value> used; |
| 391 | for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) { |
| 392 | for (Value operand : operationOp.getOperandValues()) |
| 393 | TypeSwitch<Operation *>(operand.getDefiningOp()) |
| 394 | .Case<pdl::ResultOp, pdl::ResultsOp>( |
| 395 | caseFn: [&used](auto resultOp) { used.insert(resultOp.getParent()); }); |
| 396 | } |
| 397 | |
| 398 | // Remove the specified root from the use set, so that we can |
| 399 | // always select it as a root, even if it is used by other operations. |
| 400 | if (Value root = pattern.getRewriter().getRoot()) |
| 401 | used.erase(V: root); |
| 402 | |
| 403 | // Finally, collect all the unused operations. |
| 404 | SmallVector<Value> roots; |
| 405 | for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) |
| 406 | if (!used.contains(V: operationOp)) |
| 407 | roots.push_back(Elt: operationOp); |
| 408 | |
| 409 | return roots; |
| 410 | } |
| 411 | |
| 412 | /// Given a list of candidate roots, builds the cost graph for connecting them. |
| 413 | /// The graph is formed by traversing the DAG of operations starting from each |
| 414 | /// root and marking the depth of each connector value (operand). Then we join |
| 415 | /// the candidate roots based on the common connector values, taking the one |
| 416 | /// with the minimum depth. Along the way, we compute, for each candidate root, |
| 417 | /// a mapping from each operation (in the DAG underneath this root) to its |
| 418 | /// parent operation and the corresponding operand index. |
| 419 | static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph, |
| 420 | ParentMaps &parentMaps) { |
| 421 | |
| 422 | // The entry of a queue. The entry consists of the following items: |
| 423 | // * the value in the DAG underneath the root; |
| 424 | // * the parent of the value; |
| 425 | // * the operand index of the value in its parent; |
| 426 | // * the depth of the visited value. |
| 427 | struct Entry { |
| 428 | Entry(Value value, Value parent, std::optional<unsigned> index, |
| 429 | unsigned depth) |
| 430 | : value(value), parent(parent), index(index), depth(depth) {} |
| 431 | |
| 432 | Value value; |
| 433 | Value parent; |
| 434 | std::optional<unsigned> index; |
| 435 | unsigned depth; |
| 436 | }; |
| 437 | |
| 438 | // A root of a value and its depth (distance from root to the value). |
| 439 | struct RootDepth { |
| 440 | Value root; |
| 441 | unsigned depth = 0; |
| 442 | }; |
| 443 | |
| 444 | // Map from candidate connector values to their roots and depths. Using a |
| 445 | // small vector with 1 entry because most values belong to a single root. |
| 446 | llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths; |
| 447 | |
| 448 | // Perform a breadth-first traversal of the op DAG rooted at each root. |
| 449 | for (Value root : roots) { |
| 450 | // The queue of visited values. A value may be present multiple times in |
| 451 | // the queue, for multiple parents. We only accept the first occurrence, |
| 452 | // which is guaranteed to have the lowest depth. |
| 453 | std::queue<Entry> toVisit; |
| 454 | toVisit.emplace(args&: root, args: Value(), args: 0, args: 0); |
| 455 | |
| 456 | // The map from value to its parent for the current root. |
| 457 | DenseMap<Value, OpIndex> &parentMap = parentMaps[root]; |
| 458 | |
| 459 | while (!toVisit.empty()) { |
| 460 | Entry entry = toVisit.front(); |
| 461 | toVisit.pop(); |
| 462 | // Skip if already visited. |
| 463 | if (!parentMap.insert(KV: {entry.value, {.parent: entry.parent, .index: entry.index}}).second) |
| 464 | continue; |
| 465 | |
| 466 | // Mark the root and depth of the value. |
| 467 | connectorsRootsDepths[entry.value].push_back(Elt: {.root: root, .depth: entry.depth}); |
| 468 | |
| 469 | // Traverse the operands of an operation and result ops. |
| 470 | // We intentionally do not traverse attributes and types, because those |
| 471 | // are expensive to join on. |
| 472 | TypeSwitch<Operation *>(entry.value.getDefiningOp()) |
| 473 | .Case<pdl::OperationOp>(caseFn: [&](auto operationOp) { |
| 474 | OperandRange operands = operationOp.getOperandValues(); |
| 475 | // Special case when we pass all the operands in one range. |
| 476 | // For those, the index is empty. |
| 477 | if (operands.size() == 1 && |
| 478 | isa<pdl::RangeType>(Val: operands[0].getType())) { |
| 479 | toVisit.emplace(args: operands[0], args&: entry.value, args: std::nullopt, |
| 480 | args: entry.depth + 1); |
| 481 | return; |
| 482 | } |
| 483 | |
| 484 | // Default case: visit all the operands. |
| 485 | for (const auto &p : |
| 486 | llvm::enumerate(operationOp.getOperandValues())) |
| 487 | toVisit.emplace(p.value(), entry.value, p.index(), |
| 488 | entry.depth + 1); |
| 489 | }) |
| 490 | .Case<pdl::ResultOp, pdl::ResultsOp>(caseFn: [&](auto resultOp) { |
| 491 | toVisit.emplace(resultOp.getParent(), entry.value, |
| 492 | resultOp.getIndex(), entry.depth); |
| 493 | }); |
| 494 | } |
| 495 | } |
| 496 | |
| 497 | // Now build the cost graph. |
| 498 | // This is simply a minimum over all depths for the target root. |
| 499 | unsigned nextID = 0; |
| 500 | for (const auto &connectorRootsDepths : connectorsRootsDepths) { |
| 501 | Value value = connectorRootsDepths.first; |
| 502 | ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second; |
| 503 | // If there is only one root for this value, this will not trigger |
| 504 | // any edges in the cost graph (a perf optimization). |
| 505 | if (rootsDepths.size() == 1) |
| 506 | continue; |
| 507 | |
| 508 | for (const RootDepth &p : rootsDepths) { |
| 509 | for (const RootDepth &q : rootsDepths) { |
| 510 | if (&p == &q) |
| 511 | continue; |
| 512 | // Insert or retrieve the property of edge from p to q. |
| 513 | RootOrderingEntry &entry = graph[q.root][p.root]; |
| 514 | if (!entry.connector /* new edge */ || entry.cost.first > q.depth) { |
| 515 | if (!entry.connector) |
| 516 | entry.cost.second = nextID++; |
| 517 | entry.cost.first = q.depth; |
| 518 | entry.connector = value; |
| 519 | } |
| 520 | } |
| 521 | } |
| 522 | } |
| 523 | |
| 524 | assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && |
| 525 | "the pattern contains a candidate root disconnected from the others" ); |
| 526 | } |
| 527 | |
| 528 | /// Returns true if the operand at the given index needs to be queried using an |
| 529 | /// operand group, i.e., if it is variadic itself or follows a variadic operand. |
| 530 | static bool useOperandGroup(pdl::OperationOp op, unsigned index) { |
| 531 | OperandRange operands = op.getOperandValues(); |
| 532 | assert(index < operands.size() && "operand index out of range" ); |
| 533 | for (unsigned i = 0; i <= index; ++i) |
| 534 | if (isa<pdl::RangeType>(Val: operands[i].getType())) |
| 535 | return true; |
| 536 | return false; |
| 537 | } |
| 538 | |
| 539 | /// Visit a node during upward traversal. |
| 540 | static void visitUpward(std::vector<PositionalPredicate> &predList, |
| 541 | OpIndex opIndex, PredicateBuilder &builder, |
| 542 | DenseMap<Value, Position *> &valueToPosition, |
| 543 | Position *&pos, unsigned rootID) { |
| 544 | Value value = opIndex.parent; |
| 545 | TypeSwitch<Operation *>(value.getDefiningOp()) |
| 546 | .Case<pdl::OperationOp>(caseFn: [&](auto operationOp) { |
| 547 | LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" ); |
| 548 | |
| 549 | // Get users and iterate over them. |
| 550 | Position *usersPos = builder.getUsers(p: pos, /*useRepresentative=*/true); |
| 551 | Position *foreachPos = builder.getForEach(p: usersPos, id: rootID); |
| 552 | OperationPosition *opPos = builder.getPassthroughOp(p: foreachPos); |
| 553 | |
| 554 | // Compare the operand(s) of the user against the input value(s). |
| 555 | Position *operandPos; |
| 556 | if (!opIndex.index) { |
| 557 | // We are querying all the operands of the operation. |
| 558 | operandPos = builder.getAllOperands(p: opPos); |
| 559 | } else if (useOperandGroup(operationOp, *opIndex.index)) { |
| 560 | // We are querying an operand group. |
| 561 | Type type = operationOp.getOperandValues()[*opIndex.index].getType(); |
| 562 | bool variadic = isa<pdl::RangeType>(Val: type); |
| 563 | operandPos = builder.getOperandGroup(p: opPos, group: opIndex.index, isVariadic: variadic); |
| 564 | } else { |
| 565 | // We are querying an individual operand. |
| 566 | operandPos = builder.getOperand(p: opPos, operand: *opIndex.index); |
| 567 | } |
| 568 | predList.emplace_back(args&: operandPos, args: builder.getEqualTo(pos)); |
| 569 | |
| 570 | // Guard against duplicate upward visits. These are not possible, |
| 571 | // because if this value was already visited, it would have been |
| 572 | // cheaper to start the traversal at this value rather than at the |
| 573 | // `connector`, violating the optimality of our spanning tree. |
| 574 | bool inserted = valueToPosition.try_emplace(Key: value, Args&: opPos).second; |
| 575 | (void)inserted; |
| 576 | assert(inserted && "duplicate upward visit" ); |
| 577 | |
| 578 | // Obtain the tree predicates at the current value. |
| 579 | getTreePredicates(predList, val: value, builder, inputs&: valueToPosition, pos: opPos, |
| 580 | ignoreOperand: opIndex.index); |
| 581 | |
| 582 | // Update the position |
| 583 | pos = opPos; |
| 584 | }) |
| 585 | .Case<pdl::ResultOp>(caseFn: [&](auto resultOp) { |
| 586 | // Traverse up an individual result. |
| 587 | auto *opPos = dyn_cast<OperationPosition>(Val: pos); |
| 588 | assert(opPos && "operations and results must be interleaved" ); |
| 589 | pos = builder.getResult(p: opPos, result: *opIndex.index); |
| 590 | |
| 591 | // Insert the result position in case we have not visited it yet. |
| 592 | valueToPosition.try_emplace(Key: value, Args&: pos); |
| 593 | }) |
| 594 | .Case<pdl::ResultsOp>(caseFn: [&](auto resultOp) { |
| 595 | // Traverse up a group of results. |
| 596 | auto *opPos = dyn_cast<OperationPosition>(Val: pos); |
| 597 | assert(opPos && "operations and results must be interleaved" ); |
| 598 | bool isVariadic = isa<pdl::RangeType>(Val: value.getType()); |
| 599 | if (opIndex.index) |
| 600 | pos = builder.getResultGroup(p: opPos, group: opIndex.index, isVariadic); |
| 601 | else |
| 602 | pos = builder.getAllResults(p: opPos); |
| 603 | |
| 604 | // Insert the result position in case we have not visited it yet. |
| 605 | valueToPosition.try_emplace(Key: value, Args&: pos); |
| 606 | }); |
| 607 | } |
| 608 | |
| 609 | /// Given a pattern operation, build the set of matcher predicates necessary to |
| 610 | /// match this pattern. |
| 611 | static Value buildPredicateList(pdl::PatternOp pattern, |
| 612 | PredicateBuilder &builder, |
| 613 | std::vector<PositionalPredicate> &predList, |
| 614 | DenseMap<Value, Position *> &valueToPosition) { |
| 615 | SmallVector<Value> roots = detectRoots(pattern); |
| 616 | |
| 617 | // Build the root ordering graph and compute the parent maps. |
| 618 | RootOrderingGraph graph; |
| 619 | ParentMaps parentMaps; |
| 620 | buildCostGraph(roots, graph, parentMaps); |
| 621 | LLVM_DEBUG({ |
| 622 | llvm::dbgs() << "Graph:\n" ; |
| 623 | for (auto &target : graph) { |
| 624 | llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first |
| 625 | << "\n" ; |
| 626 | for (auto &source : target.second) { |
| 627 | RootOrderingEntry &entry = source.second; |
| 628 | llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first |
| 629 | << ":" << entry.cost.second << " via " |
| 630 | << entry.connector.getLoc() << "\n" ; |
| 631 | } |
| 632 | } |
| 633 | }); |
| 634 | |
| 635 | // Solve the optimal branching problem for each candidate root, or use the |
| 636 | // provided one. |
| 637 | Value bestRoot = pattern.getRewriter().getRoot(); |
| 638 | OptimalBranching::EdgeList bestEdges; |
| 639 | if (!bestRoot) { |
| 640 | unsigned bestCost = 0; |
| 641 | LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n" ); |
| 642 | for (Value root : roots) { |
| 643 | OptimalBranching solver(graph, root); |
| 644 | unsigned cost = solver.solve(); |
| 645 | LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n" ); |
| 646 | if (!bestRoot || bestCost > cost) { |
| 647 | bestCost = cost; |
| 648 | bestRoot = root; |
| 649 | bestEdges = solver.preOrderTraversal(nodes: roots); |
| 650 | } |
| 651 | } |
| 652 | } else { |
| 653 | OptimalBranching solver(graph, bestRoot); |
| 654 | solver.solve(); |
| 655 | bestEdges = solver.preOrderTraversal(nodes: roots); |
| 656 | } |
| 657 | |
| 658 | // Print the best solution. |
| 659 | LLVM_DEBUG({ |
| 660 | llvm::dbgs() << "Best tree:\n" ; |
| 661 | for (const std::pair<Value, Value> &edge : bestEdges) { |
| 662 | llvm::dbgs() << " * " << edge.first; |
| 663 | if (edge.second) |
| 664 | llvm::dbgs() << " <- " << edge.second; |
| 665 | llvm::dbgs() << "\n" ; |
| 666 | } |
| 667 | }); |
| 668 | |
| 669 | LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n" ); |
| 670 | LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n" ); |
| 671 | |
| 672 | // The best root is the starting point for the traversal. Get the tree |
| 673 | // predicates for the DAG rooted at bestRoot. |
| 674 | getTreePredicates(predList, val: bestRoot, builder, inputs&: valueToPosition, |
| 675 | pos: builder.getRoot()); |
| 676 | |
| 677 | // Traverse the selected optimal branching. For all edges in order, traverse |
| 678 | // up starting from the connector, until the candidate root is reached, and |
| 679 | // call getTreePredicates at every node along the way. |
| 680 | for (const auto &it : llvm::enumerate(First&: bestEdges)) { |
| 681 | Value target = it.value().first; |
| 682 | Value source = it.value().second; |
| 683 | |
| 684 | // Check if we already visited the target root. This happens in two cases: |
| 685 | // 1) the initial root (bestRoot); |
| 686 | // 2) a root that is dominated by (contained in the subtree rooted at) an |
| 687 | // already visited root. |
| 688 | if (valueToPosition.count(Val: target)) |
| 689 | continue; |
| 690 | |
| 691 | // Determine the connector. |
| 692 | Value connector = graph[target][source].connector; |
| 693 | assert(connector && "invalid edge" ); |
| 694 | LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n" ); |
| 695 | DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(Val: target); |
| 696 | Position *pos = valueToPosition.lookup(Val: connector); |
| 697 | assert(pos && "connector has not been traversed yet" ); |
| 698 | |
| 699 | // Traverse from the connector upwards towards the target root. |
| 700 | for (Value value = connector; value != target;) { |
| 701 | OpIndex opIndex = parentMap.lookup(Val: value); |
| 702 | assert(opIndex.parent && "missing parent" ); |
| 703 | visitUpward(predList, opIndex, builder, valueToPosition, pos, rootID: it.index()); |
| 704 | value = opIndex.parent; |
| 705 | } |
| 706 | } |
| 707 | |
| 708 | getNonTreePredicates(pattern, predList, builder, inputs&: valueToPosition); |
| 709 | |
| 710 | return bestRoot; |
| 711 | } |
| 712 | |
| 713 | //===----------------------------------------------------------------------===// |
| 714 | // Pattern Predicate Tree Merging |
| 715 | //===----------------------------------------------------------------------===// |
| 716 | |
| 717 | namespace { |
| 718 | |
| 719 | /// This class represents a specific predicate applied to a position, and |
| 720 | /// provides hashing and ordering operators. This class allows for computing a |
| 721 | /// frequence sum and ordering predicates based on a cost model. |
| 722 | struct OrderedPredicate { |
| 723 | OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) |
| 724 | : position(ip.first), question(ip.second) {} |
| 725 | OrderedPredicate(const PositionalPredicate &ip) |
| 726 | : position(ip.position), question(ip.question) {} |
| 727 | |
| 728 | /// The position this predicate is applied to. |
| 729 | Position *position; |
| 730 | |
| 731 | /// The question that is applied by this predicate onto the position. |
| 732 | Qualifier *question; |
| 733 | |
| 734 | /// The first and second order benefit sums. |
| 735 | /// The primary sum is the number of occurrences of this predicate among all |
| 736 | /// of the patterns. |
| 737 | unsigned primary = 0; |
| 738 | /// The secondary sum is a squared summation of the primary sum of all of the |
| 739 | /// predicates within each pattern that contains this predicate. This allows |
| 740 | /// for favoring predicates that are more commonly shared within a pattern, as |
| 741 | /// opposed to those shared across patterns. |
| 742 | unsigned secondary = 0; |
| 743 | |
| 744 | /// The tie breaking ID, used to preserve a deterministic (insertion) order |
| 745 | /// among all the predicates with the same priority, depth, and position / |
| 746 | /// predicate dependency. |
| 747 | unsigned id = 0; |
| 748 | |
| 749 | /// A map between a pattern operation and the answer to the predicate question |
| 750 | /// within that pattern. |
| 751 | DenseMap<Operation *, Qualifier *> patternToAnswer; |
| 752 | |
| 753 | /// Returns true if this predicate is ordered before `rhs`, based on the cost |
| 754 | /// model. |
| 755 | bool operator<(const OrderedPredicate &rhs) const { |
| 756 | // Sort by: |
| 757 | // * higher first and secondary order sums |
| 758 | // * lower depth |
| 759 | // * lower position dependency |
| 760 | // * lower predicate dependency |
| 761 | // * lower tie breaking ID |
| 762 | auto *rhsPos = rhs.position; |
| 763 | return std::make_tuple(args: primary, args: secondary, args: rhsPos->getOperationDepth(), |
| 764 | args: rhsPos->getKind(), args: rhs.question->getKind(), args: rhs.id) > |
| 765 | std::make_tuple(args: rhs.primary, args: rhs.secondary, |
| 766 | args: position->getOperationDepth(), args: position->getKind(), |
| 767 | args: question->getKind(), args: id); |
| 768 | } |
| 769 | }; |
| 770 | |
| 771 | /// A DenseMapInfo for OrderedPredicate based solely on the position and |
| 772 | /// question. |
| 773 | struct OrderedPredicateDenseInfo { |
| 774 | using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; |
| 775 | |
| 776 | static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } |
| 777 | static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } |
| 778 | static bool isEqual(const OrderedPredicate &lhs, |
| 779 | const OrderedPredicate &rhs) { |
| 780 | return lhs.position == rhs.position && lhs.question == rhs.question; |
| 781 | } |
| 782 | static unsigned getHashValue(const OrderedPredicate &p) { |
| 783 | return llvm::hash_combine(args: p.position, args: p.question); |
| 784 | } |
| 785 | }; |
| 786 | |
| 787 | /// This class wraps a set of ordered predicates that are used within a specific |
| 788 | /// pattern operation. |
| 789 | struct OrderedPredicateList { |
| 790 | OrderedPredicateList(pdl::PatternOp pattern, Value root) |
| 791 | : pattern(pattern), root(root) {} |
| 792 | |
| 793 | pdl::PatternOp pattern; |
| 794 | Value root; |
| 795 | DenseSet<OrderedPredicate *> predicates; |
| 796 | }; |
| 797 | } // namespace |
| 798 | |
| 799 | /// Returns true if the given matcher refers to the same predicate as the given |
| 800 | /// ordered predicate. This means that the position and questions of the two |
| 801 | /// match. |
| 802 | static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { |
| 803 | return node->getPosition() == predicate->position && |
| 804 | node->getQuestion() == predicate->question; |
| 805 | } |
| 806 | |
| 807 | /// Get or insert a child matcher for the given parent switch node, given a |
| 808 | /// predicate and parent pattern. |
| 809 | std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, |
| 810 | OrderedPredicate *predicate, |
| 811 | pdl::PatternOp pattern) { |
| 812 | assert(isSamePredicate(node, predicate) && |
| 813 | "expected matcher to equal the given predicate" ); |
| 814 | |
| 815 | auto it = predicate->patternToAnswer.find(Val: pattern); |
| 816 | assert(it != predicate->patternToAnswer.end() && |
| 817 | "expected pattern to exist in predicate" ); |
| 818 | return node->getChildren()[it->second]; |
| 819 | } |
| 820 | |
| 821 | /// Build the matcher CFG by "pushing" patterns through by sorted predicate |
| 822 | /// order. A pattern will traverse as far as possible using common predicates |
| 823 | /// and then either diverge from the CFG or reach the end of a branch and start |
| 824 | /// creating new nodes. |
| 825 | static void propagatePattern(std::unique_ptr<MatcherNode> &node, |
| 826 | OrderedPredicateList &list, |
| 827 | std::vector<OrderedPredicate *>::iterator current, |
| 828 | std::vector<OrderedPredicate *>::iterator end) { |
| 829 | if (current == end) { |
| 830 | // We've hit the end of a pattern, so create a successful result node. |
| 831 | node = |
| 832 | std::make_unique<SuccessNode>(args&: list.pattern, args&: list.root, args: std::move(node)); |
| 833 | |
| 834 | // If the pattern doesn't contain this predicate, ignore it. |
| 835 | } else if (!list.predicates.contains(V: *current)) { |
| 836 | propagatePattern(node, list, current: std::next(x: current), end); |
| 837 | |
| 838 | // If the current matcher node is invalid, create a new one for this |
| 839 | // position and continue propagation. |
| 840 | } else if (!node) { |
| 841 | // Create a new node at this position and continue |
| 842 | node = std::make_unique<SwitchNode>(args&: (*current)->position, |
| 843 | args&: (*current)->question); |
| 844 | propagatePattern( |
| 845 | node&: getOrCreateChild(node: cast<SwitchNode>(Val: &*node), predicate: *current, pattern: list.pattern), |
| 846 | list, current: std::next(x: current), end); |
| 847 | |
| 848 | // If the matcher has already been created, and it is for this predicate we |
| 849 | // continue propagation to the child. |
| 850 | } else if (isSamePredicate(node: node.get(), predicate: *current)) { |
| 851 | propagatePattern( |
| 852 | node&: getOrCreateChild(node: cast<SwitchNode>(Val: &*node), predicate: *current, pattern: list.pattern), |
| 853 | list, current: std::next(x: current), end); |
| 854 | |
| 855 | // If the matcher doesn't match the current predicate, insert a branch as |
| 856 | // the common set of matchers has diverged. |
| 857 | } else { |
| 858 | propagatePattern(node&: node->getFailureNode(), list, current, end); |
| 859 | } |
| 860 | } |
| 861 | |
| 862 | /// Fold any switch nodes nested under `node` to boolean nodes when possible. |
| 863 | /// `node` is updated in-place if it is a switch. |
| 864 | static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { |
| 865 | if (!node) |
| 866 | return; |
| 867 | |
| 868 | if (SwitchNode *switchNode = dyn_cast<SwitchNode>(Val: &*node)) { |
| 869 | SwitchNode::ChildMapT &children = switchNode->getChildren(); |
| 870 | for (auto &it : children) |
| 871 | foldSwitchToBool(node&: it.second); |
| 872 | |
| 873 | // If the node only contains one child, collapse it into a boolean predicate |
| 874 | // node. |
| 875 | if (children.size() == 1) { |
| 876 | auto *childIt = children.begin(); |
| 877 | node = std::make_unique<BoolNode>( |
| 878 | args: node->getPosition(), args: node->getQuestion(), args&: childIt->first, |
| 879 | args: std::move(childIt->second), args: std::move(node->getFailureNode())); |
| 880 | } |
| 881 | } else if (BoolNode *boolNode = dyn_cast<BoolNode>(Val: &*node)) { |
| 882 | foldSwitchToBool(node&: boolNode->getSuccessNode()); |
| 883 | } |
| 884 | |
| 885 | foldSwitchToBool(node&: node->getFailureNode()); |
| 886 | } |
| 887 | |
| 888 | /// Insert an exit node at the end of the failure path of the `root`. |
| 889 | static void insertExitNode(std::unique_ptr<MatcherNode> *root) { |
| 890 | while (*root) |
| 891 | root = &(*root)->getFailureNode(); |
| 892 | *root = std::make_unique<ExitNode>(); |
| 893 | } |
| 894 | |
| 895 | /// Sorts the range begin/end with the partial order given by cmp. |
| 896 | template <typename Iterator, typename Compare> |
| 897 | static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { |
| 898 | while (begin != end) { |
| 899 | // Cannot compute sortBeforeOthers in the predicate of stable_partition |
| 900 | // because stable_partition will not keep the [begin, end) range intact |
| 901 | // while it runs. |
| 902 | llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers; |
| 903 | for (auto i = begin; i != end; ++i) { |
| 904 | if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); })) |
| 905 | sortBeforeOthers.insert(*i); |
| 906 | } |
| 907 | |
| 908 | auto const next = std::stable_partition(begin, end, [&](auto const &a) { |
| 909 | return sortBeforeOthers.contains(a); |
| 910 | }); |
| 911 | assert(next != begin && "not a partial ordering" ); |
| 912 | begin = next; |
| 913 | } |
| 914 | } |
| 915 | |
| 916 | /// Returns true if 'b' depends on a result of 'a'. |
| 917 | static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) { |
| 918 | auto *cqa = dyn_cast<ConstraintQuestion>(Val: a->question); |
| 919 | if (!cqa) |
| 920 | return false; |
| 921 | |
| 922 | auto positionDependsOnA = [&](Position *p) { |
| 923 | auto *cp = dyn_cast<ConstraintPosition>(Val: p); |
| 924 | return cp && cp->getQuestion() == cqa; |
| 925 | }; |
| 926 | |
| 927 | if (auto *cqb = dyn_cast<ConstraintQuestion>(Val: b->question)) { |
| 928 | // Does any argument of b use a? |
| 929 | return llvm::any_of(Range: cqb->getArgs(), P: positionDependsOnA); |
| 930 | } |
| 931 | if (auto *equalTo = dyn_cast<EqualToQuestion>(Val: b->question)) { |
| 932 | return positionDependsOnA(b->position) || |
| 933 | positionDependsOnA(equalTo->getValue()); |
| 934 | } |
| 935 | return positionDependsOnA(b->position); |
| 936 | } |
| 937 | |
| 938 | /// Given a module containing PDL pattern operations, generate a matcher tree |
| 939 | /// using the patterns within the given module and return the root matcher node. |
| 940 | std::unique_ptr<MatcherNode> |
| 941 | MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, |
| 942 | DenseMap<Value, Position *> &valueToPosition) { |
| 943 | // The set of predicates contained within the pattern operations of the |
| 944 | // module. |
| 945 | struct PatternPredicates { |
| 946 | PatternPredicates(pdl::PatternOp pattern, Value root, |
| 947 | std::vector<PositionalPredicate> predicates) |
| 948 | : pattern(pattern), root(root), predicates(std::move(predicates)) {} |
| 949 | |
| 950 | /// A pattern. |
| 951 | pdl::PatternOp pattern; |
| 952 | |
| 953 | /// A root of the pattern chosen among the candidate roots in pdl.rewrite. |
| 954 | Value root; |
| 955 | |
| 956 | /// The extracted predicates for this pattern and root. |
| 957 | std::vector<PositionalPredicate> predicates; |
| 958 | }; |
| 959 | |
| 960 | SmallVector<PatternPredicates, 16> patternsAndPredicates; |
| 961 | for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { |
| 962 | std::vector<PositionalPredicate> predicateList; |
| 963 | Value root = |
| 964 | buildPredicateList(pattern, builder, predList&: predicateList, valueToPosition); |
| 965 | patternsAndPredicates.emplace_back(Args&: pattern, Args&: root, Args: std::move(predicateList)); |
| 966 | } |
| 967 | |
| 968 | // Associate a pattern result with each unique predicate. |
| 969 | DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; |
| 970 | for (auto &patternAndPredList : patternsAndPredicates) { |
| 971 | for (auto &predicate : patternAndPredList.predicates) { |
| 972 | auto it = uniqued.insert(V: predicate); |
| 973 | it.first->patternToAnswer.try_emplace(Key: patternAndPredList.pattern, |
| 974 | Args&: predicate.answer); |
| 975 | // Mark the insertion order (0-based indexing). |
| 976 | if (it.second) |
| 977 | it.first->id = uniqued.size() - 1; |
| 978 | } |
| 979 | } |
| 980 | |
| 981 | // Associate each pattern to a set of its ordered predicates for later lookup. |
| 982 | std::vector<OrderedPredicateList> lists; |
| 983 | lists.reserve(n: patternsAndPredicates.size()); |
| 984 | for (auto &patternAndPredList : patternsAndPredicates) { |
| 985 | OrderedPredicateList list(patternAndPredList.pattern, |
| 986 | patternAndPredList.root); |
| 987 | for (auto &predicate : patternAndPredList.predicates) { |
| 988 | OrderedPredicate *orderedPredicate = &*uniqued.find(V: predicate); |
| 989 | list.predicates.insert(V: orderedPredicate); |
| 990 | |
| 991 | // Increment the primary sum for each reference to a particular predicate. |
| 992 | ++orderedPredicate->primary; |
| 993 | } |
| 994 | lists.push_back(x: std::move(list)); |
| 995 | } |
| 996 | |
| 997 | // For a particular pattern, get the total primary sum and add it to the |
| 998 | // secondary sum of each predicate. Square the primary sums to emphasize |
| 999 | // shared predicates within rather than across patterns. |
| 1000 | for (auto &list : lists) { |
| 1001 | unsigned total = 0; |
| 1002 | for (auto *predicate : list.predicates) |
| 1003 | total += predicate->primary * predicate->primary; |
| 1004 | for (auto *predicate : list.predicates) |
| 1005 | predicate->secondary += total; |
| 1006 | } |
| 1007 | |
| 1008 | // Sort the set of predicates now that the cost primary and secondary sums |
| 1009 | // have been computed. |
| 1010 | std::vector<OrderedPredicate *> ordered; |
| 1011 | ordered.reserve(n: uniqued.size()); |
| 1012 | for (auto &ip : uniqued) |
| 1013 | ordered.push_back(x: &ip); |
| 1014 | llvm::sort(C&: ordered, Comp: [](OrderedPredicate *lhs, OrderedPredicate *rhs) { |
| 1015 | return *lhs < *rhs; |
| 1016 | }); |
| 1017 | |
| 1018 | // Mostly keep the now established order, but also ensure that |
| 1019 | // ConstraintQuestions come after the results they use. |
| 1020 | stableTopologicalSort(begin: ordered.begin(), end: ordered.end(), cmp: dependsOn); |
| 1021 | |
| 1022 | // Build the matchers for each of the pattern predicate lists. |
| 1023 | std::unique_ptr<MatcherNode> root; |
| 1024 | for (OrderedPredicateList &list : lists) |
| 1025 | propagatePattern(node&: root, list, current: ordered.begin(), end: ordered.end()); |
| 1026 | |
| 1027 | // Collapse the graph and insert the exit node. |
| 1028 | foldSwitchToBool(node&: root); |
| 1029 | insertExitNode(root: &root); |
| 1030 | return root; |
| 1031 | } |
| 1032 | |
| 1033 | //===----------------------------------------------------------------------===// |
| 1034 | // MatcherNode |
| 1035 | //===----------------------------------------------------------------------===// |
| 1036 | |
| 1037 | MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, |
| 1038 | std::unique_ptr<MatcherNode> failureNode) |
| 1039 | : position(p), question(q), failureNode(std::move(failureNode)), |
| 1040 | matcherTypeID(matcherTypeID) {} |
| 1041 | |
| 1042 | //===----------------------------------------------------------------------===// |
| 1043 | // BoolNode |
| 1044 | //===----------------------------------------------------------------------===// |
| 1045 | |
| 1046 | BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, |
| 1047 | std::unique_ptr<MatcherNode> successNode, |
| 1048 | std::unique_ptr<MatcherNode> failureNode) |
| 1049 | : MatcherNode(TypeID::get<BoolNode>(), position, question, |
| 1050 | std::move(failureNode)), |
| 1051 | answer(answer), successNode(std::move(successNode)) {} |
| 1052 | |
| 1053 | //===----------------------------------------------------------------------===// |
| 1054 | // SuccessNode |
| 1055 | //===----------------------------------------------------------------------===// |
| 1056 | |
| 1057 | SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, |
| 1058 | std::unique_ptr<MatcherNode> failureNode) |
| 1059 | : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, |
| 1060 | /*question=*/nullptr, std::move(failureNode)), |
| 1061 | pattern(pattern), root(root) {} |
| 1062 | |
| 1063 | //===----------------------------------------------------------------------===// |
| 1064 | // SwitchNode |
| 1065 | //===----------------------------------------------------------------------===// |
| 1066 | |
| 1067 | SwitchNode::SwitchNode(Position *position, Qualifier *question) |
| 1068 | : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} |
| 1069 | |