| 1 | //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains definitions for "predicates" used when converting PDL into |
| 10 | // a matcher tree. Predicates are composed of three different parts: |
| 11 | // |
| 12 | // * Positions |
| 13 | // - A position refers to a specific location on the input DAG, i.e. an |
| 14 | // existing MLIR entity being matched. These can be attributes, operands, |
| 15 | // operations, results, and types. Each position also defines a relation to |
| 16 | // its parent. For example, the operand `[0] -> 1` has a parent operation |
| 17 | // position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation |
| 18 | // position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge |
| 19 | // `[0] -> 1` (i.e. it is the defining op of operand 1). The only position |
| 20 | // without a parent is `[0]`, which refers to the root operation. |
| 21 | // * Questions |
| 22 | // - A question refers to a query on a specific positional value. For |
| 23 | // example, an operation name question checks the name of an operation |
| 24 | // position. |
| 25 | // * Answers |
//   - An answer is the expected result of a question. For example, when
//     matching an operation with the name "foo.op", the question would be an
//     operation name question, with an expected answer of "foo.op".
| 29 | // |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | |
| 32 | #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ |
| 33 | #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ |
| 34 | |
| 35 | #include "mlir/IR/MLIRContext.h" |
| 36 | #include "mlir/IR/OperationSupport.h" |
| 37 | #include "mlir/IR/PatternMatch.h" |
| 38 | #include "mlir/IR/Types.h" |
| 39 | |
| 40 | namespace mlir { |
| 41 | namespace pdl_to_pdl_interp { |
namespace Predicates {
/// An enumeration of the kinds of predicates. Each concrete predicate class
/// passes its kind as the `Kind` template argument of PredicateBase, which
/// uses it for LLVM-style RTTI (`classof`).
///
/// NOTE(review): enumerator order appears to be significant. The group
/// comments below describe positions and questions as ordered by priority.
/// Confirm how consumers compare kinds before reordering or inserting entries.
enum Kind : unsigned {
  /// Positions, ordered by decreasing priority.
  OperationPos,
  OperandPos,
  OperandGroupPos,
  AttributePos,
  ConstraintResultPos,
  ResultPos,
  ResultGroupPos,
  TypePos,
  AttributeLiteralPos,
  TypeLiteralPos,
  UsersPos,
  ForEachPos,

  // Questions, ordered by dependency and decreasing priority.
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountAtLeastQuestion,
  OperandCountQuestion,
  ResultCountAtLeastQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  // Answers (listed alphabetically).
  AttributeAnswer,
  FalseAnswer,
  OperationNameAnswer,
  TrueAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // namespace Predicates
| 80 | |
| 81 | /// Base class for all predicates, used to allow efficient pointer comparison. |
| 82 | template <typename ConcreteT, typename BaseT, typename Key, |
| 83 | Predicates::Kind Kind> |
| 84 | class PredicateBase : public BaseT { |
| 85 | public: |
| 86 | using KeyTy = Key; |
| 87 | using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>; |
| 88 | |
| 89 | template <typename KeyT> |
| 90 | explicit PredicateBase(KeyT &&key) |
| 91 | : BaseT(Kind), key(std::forward<KeyT>(key)) {} |
| 92 | |
| 93 | /// Get an instance of this position. |
| 94 | template <typename... Args> |
| 95 | static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { |
| 96 | return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...); |
| 97 | } |
| 98 | |
| 99 | /// Construct an instance with the given storage allocator. |
| 100 | template <typename KeyT> |
| 101 | static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, |
| 102 | KeyT &&key) { |
| 103 | return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key)); |
| 104 | } |
| 105 | |
| 106 | /// Utility methods required by the storage allocator. |
| 107 | bool operator==(const KeyTy &key) const { return this->key == key; } |
| 108 | static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } |
| 109 | |
| 110 | /// Return the key value of this predicate. |
| 111 | const KeyTy &getValue() const { return key; } |
| 112 | |
| 113 | protected: |
| 114 | KeyTy key; |
| 115 | }; |
| 116 | |
| 117 | /// Base storage for simple predicates that only unique with the kind. |
| 118 | template <typename ConcreteT, typename BaseT, Predicates::Kind Kind> |
| 119 | class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT { |
| 120 | public: |
| 121 | using Base = PredicateBase<ConcreteT, BaseT, void, Kind>; |
| 122 | |
| 123 | explicit PredicateBase() : BaseT(Kind) {} |
| 124 | |
| 125 | static ConcreteT *get(StorageUniquer &uniquer) { |
| 126 | return uniquer.get<ConcreteT>(); |
| 127 | } |
| 128 | static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } |
| 129 | }; |
| 130 | |
| 131 | //===----------------------------------------------------------------------===// |
| 132 | // Positions |
| 133 | //===----------------------------------------------------------------------===// |
| 134 | |
| 135 | struct OperationPosition; |
| 136 | |
| 137 | /// A position describes a value on the input IR on which a predicate may be |
| 138 | /// applied, such as an operation or attribute. This enables re-use between |
| 139 | /// predicates, and assists generating bytecode and memory management. |
| 140 | /// |
| 141 | /// Operation positions form the base of other positions, which are formed |
| 142 | /// relative to a parent operation. Operations are anchored at Operand nodes, |
| 143 | /// except for the root operation which is parentless. |
| 144 | class Position : public StorageUniquer::BaseStorage { |
| 145 | public: |
| 146 | explicit Position(Predicates::Kind kind) : kind(kind) {} |
| 147 | virtual ~Position(); |
| 148 | |
| 149 | /// Returns the depth of the first ancestor operation position. |
| 150 | unsigned getOperationDepth() const; |
| 151 | |
| 152 | /// Returns the parent position. The root operation position has no parent. |
| 153 | Position *getParent() const { return parent; } |
| 154 | |
| 155 | /// Returns the kind of this position. |
| 156 | Predicates::Kind getKind() const { return kind; } |
| 157 | |
| 158 | protected: |
| 159 | /// Link to the parent position. |
| 160 | Position *parent = nullptr; |
| 161 | |
| 162 | private: |
| 163 | /// The kind of this position. |
| 164 | Predicates::Kind kind; |
| 165 | }; |
| 166 | |
| 167 | //===----------------------------------------------------------------------===// |
| 168 | // AttributePosition |
| 169 | //===----------------------------------------------------------------------===// |
| 170 | |
| 171 | /// A position describing an attribute of an operation. |
| 172 | struct AttributePosition |
| 173 | : public PredicateBase<AttributePosition, Position, |
| 174 | std::pair<OperationPosition *, StringAttr>, |
| 175 | Predicates::AttributePos> { |
| 176 | explicit AttributePosition(const KeyTy &key); |
| 177 | |
| 178 | /// Returns the attribute name of this position. |
| 179 | StringAttr getName() const { return key.second; } |
| 180 | }; |
| 181 | |
| 182 | //===----------------------------------------------------------------------===// |
| 183 | // AttributeLiteralPosition |
| 184 | //===----------------------------------------------------------------------===// |
| 185 | |
| 186 | /// A position describing a literal attribute. |
| 187 | struct AttributeLiteralPosition |
| 188 | : public PredicateBase<AttributeLiteralPosition, Position, Attribute, |
| 189 | Predicates::AttributeLiteralPos> { |
| 190 | using PredicateBase::PredicateBase; |
| 191 | }; |
| 192 | |
| 193 | //===----------------------------------------------------------------------===// |
| 194 | // ForEachPosition |
| 195 | //===----------------------------------------------------------------------===// |
| 196 | |
| 197 | /// A position describing an iterative choice of an operation. |
| 198 | struct ForEachPosition : public PredicateBase<ForEachPosition, Position, |
| 199 | std::pair<Position *, unsigned>, |
| 200 | Predicates::ForEachPos> { |
| 201 | explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
| 202 | |
| 203 | /// Returns the ID, for differentiating various loops. |
| 204 | /// For upward traversals, this is the index of the root. |
| 205 | unsigned getID() const { return key.second; } |
| 206 | }; |
| 207 | |
| 208 | //===----------------------------------------------------------------------===// |
| 209 | // OperandPosition |
| 210 | //===----------------------------------------------------------------------===// |
| 211 | |
| 212 | /// A position describing an operand of an operation. |
| 213 | struct OperandPosition |
| 214 | : public PredicateBase<OperandPosition, Position, |
| 215 | std::pair<OperationPosition *, unsigned>, |
| 216 | Predicates::OperandPos> { |
| 217 | explicit OperandPosition(const KeyTy &key); |
| 218 | |
| 219 | /// Returns the operand number of this position. |
| 220 | unsigned getOperandNumber() const { return key.second; } |
| 221 | }; |
| 222 | |
| 223 | //===----------------------------------------------------------------------===// |
| 224 | // OperandGroupPosition |
| 225 | //===----------------------------------------------------------------------===// |
| 226 | |
| 227 | /// A position describing an operand group of an operation. |
| 228 | struct OperandGroupPosition |
| 229 | : public PredicateBase< |
| 230 | OperandGroupPosition, Position, |
| 231 | std::tuple<OperationPosition *, std::optional<unsigned>, bool>, |
| 232 | Predicates::OperandGroupPos> { |
| 233 | explicit OperandGroupPosition(const KeyTy &key); |
| 234 | |
| 235 | /// Returns a hash suitable for the given keytype. |
| 236 | static llvm::hash_code hashKey(const KeyTy &key) { |
| 237 | return llvm::hash_value(arg: key); |
| 238 | } |
| 239 | |
| 240 | /// Returns the group number of this position. If std::nullopt, this group |
| 241 | /// refers to all operands. |
| 242 | std::optional<unsigned> getOperandGroupNumber() const { |
| 243 | return std::get<1>(t: key); |
| 244 | } |
| 245 | |
| 246 | /// Returns if the operand group has unknown size. If false, the operand group |
| 247 | /// has at max one element. |
| 248 | bool isVariadic() const { return std::get<2>(t: key); } |
| 249 | }; |
| 250 | |
| 251 | //===----------------------------------------------------------------------===// |
| 252 | // OperationPosition |
| 253 | //===----------------------------------------------------------------------===// |
| 254 | |
| 255 | /// An operation position describes an operation node in the IR. Other position |
| 256 | /// kinds are formed with respect to an operation position. |
| 257 | struct OperationPosition : public PredicateBase<OperationPosition, Position, |
| 258 | std::pair<Position *, unsigned>, |
| 259 | Predicates::OperationPos> { |
| 260 | explicit OperationPosition(const KeyTy &key) : Base(key) { |
| 261 | parent = key.first; |
| 262 | } |
| 263 | |
| 264 | /// Returns a hash suitable for the given keytype. |
| 265 | static llvm::hash_code hashKey(const KeyTy &key) { |
| 266 | return llvm::hash_value(arg: key); |
| 267 | } |
| 268 | |
| 269 | /// Gets the root position. |
| 270 | static OperationPosition *getRoot(StorageUniquer &uniquer) { |
| 271 | return Base::get(uniquer, args: nullptr, args: 0); |
| 272 | } |
| 273 | |
| 274 | /// Gets an operation position with the given parent. |
| 275 | static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { |
| 276 | return Base::get(uniquer, args&: parent, args: parent->getOperationDepth() + 1); |
| 277 | } |
| 278 | |
| 279 | /// Returns the depth of this position. |
| 280 | unsigned getDepth() const { return key.second; } |
| 281 | |
| 282 | /// Returns if this operation position corresponds to the root. |
| 283 | bool isRoot() const { return getDepth() == 0; } |
| 284 | |
| 285 | /// Returns if this operation represents an operand defining op. |
| 286 | bool isOperandDefiningOp() const; |
| 287 | }; |
| 288 | |
| 289 | //===----------------------------------------------------------------------===// |
| 290 | // ConstraintPosition |
| 291 | //===----------------------------------------------------------------------===// |
| 292 | |
| 293 | struct ConstraintQuestion; |
| 294 | |
| 295 | /// A position describing the result of a native constraint. It saves the |
| 296 | /// corresponding ConstraintQuestion and result index to enable referring |
| 297 | /// back to them |
| 298 | struct ConstraintPosition |
| 299 | : public PredicateBase<ConstraintPosition, Position, |
| 300 | std::pair<ConstraintQuestion *, unsigned>, |
| 301 | Predicates::ConstraintResultPos> { |
| 302 | using PredicateBase::PredicateBase; |
| 303 | |
| 304 | /// Returns the ConstraintQuestion to enable keeping track of the native |
| 305 | /// constraint this position stems from. |
| 306 | ConstraintQuestion *getQuestion() const { return key.first; } |
| 307 | |
| 308 | // Returns the result index of this position |
| 309 | unsigned getIndex() const { return key.second; } |
| 310 | }; |
| 311 | |
| 312 | //===----------------------------------------------------------------------===// |
| 313 | // ResultPosition |
| 314 | //===----------------------------------------------------------------------===// |
| 315 | |
| 316 | /// A position describing a result of an operation. |
| 317 | struct ResultPosition |
| 318 | : public PredicateBase<ResultPosition, Position, |
| 319 | std::pair<OperationPosition *, unsigned>, |
| 320 | Predicates::ResultPos> { |
| 321 | explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
| 322 | |
| 323 | /// Returns the result number of this position. |
| 324 | unsigned getResultNumber() const { return key.second; } |
| 325 | }; |
| 326 | |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | // ResultGroupPosition |
| 329 | //===----------------------------------------------------------------------===// |
| 330 | |
| 331 | /// A position describing a result group of an operation. |
| 332 | struct ResultGroupPosition |
| 333 | : public PredicateBase< |
| 334 | ResultGroupPosition, Position, |
| 335 | std::tuple<OperationPosition *, std::optional<unsigned>, bool>, |
| 336 | Predicates::ResultGroupPos> { |
| 337 | explicit ResultGroupPosition(const KeyTy &key) : Base(key) { |
| 338 | parent = std::get<0>(t: key); |
| 339 | } |
| 340 | |
| 341 | /// Returns a hash suitable for the given keytype. |
| 342 | static llvm::hash_code hashKey(const KeyTy &key) { |
| 343 | return llvm::hash_value(arg: key); |
| 344 | } |
| 345 | |
| 346 | /// Returns the group number of this position. If std::nullopt, this group |
| 347 | /// refers to all results. |
| 348 | std::optional<unsigned> getResultGroupNumber() const { |
| 349 | return std::get<1>(t: key); |
| 350 | } |
| 351 | |
| 352 | /// Returns if the result group has unknown size. If false, the result group |
| 353 | /// has at max one element. |
| 354 | bool isVariadic() const { return std::get<2>(t: key); } |
| 355 | }; |
| 356 | |
| 357 | //===----------------------------------------------------------------------===// |
| 358 | // TypePosition |
| 359 | //===----------------------------------------------------------------------===// |
| 360 | |
| 361 | /// A position describing the result type of an entity, i.e. an Attribute, |
| 362 | /// Operand, Result, etc. |
| 363 | struct TypePosition : public PredicateBase<TypePosition, Position, Position *, |
| 364 | Predicates::TypePos> { |
| 365 | explicit TypePosition(const KeyTy &key) : Base(key) { |
| 366 | assert((isa<AttributePosition, OperandPosition, OperandGroupPosition, |
| 367 | ResultPosition, ResultGroupPosition>(key)) && |
| 368 | "expected parent to be an attribute, operand, or result" ); |
| 369 | parent = key; |
| 370 | } |
| 371 | }; |
| 372 | |
| 373 | //===----------------------------------------------------------------------===// |
| 374 | // TypeLiteralPosition |
| 375 | //===----------------------------------------------------------------------===// |
| 376 | |
| 377 | /// A position describing a literal type or type range. The value is stored as |
| 378 | /// either a TypeAttr, or an ArrayAttr of TypeAttr. |
| 379 | struct TypeLiteralPosition |
| 380 | : public PredicateBase<TypeLiteralPosition, Position, Attribute, |
| 381 | Predicates::TypeLiteralPos> { |
| 382 | using PredicateBase::PredicateBase; |
| 383 | }; |
| 384 | |
| 385 | //===----------------------------------------------------------------------===// |
| 386 | // UsersPosition |
| 387 | //===----------------------------------------------------------------------===// |
| 388 | |
| 389 | /// A position describing the users of a value or a range of values. The second |
| 390 | /// value in the key indicates whether we choose users of a representative for |
| 391 | /// a range (this is true, e.g., in the upward traversals). |
| 392 | struct UsersPosition |
| 393 | : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>, |
| 394 | Predicates::UsersPos> { |
| 395 | explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
| 396 | |
| 397 | /// Returns a hash suitable for the given keytype. |
| 398 | static llvm::hash_code hashKey(const KeyTy &key) { |
| 399 | return llvm::hash_value(arg: key); |
| 400 | } |
| 401 | |
| 402 | /// Indicates whether to compute a range of a representative. |
| 403 | bool useRepresentative() const { return key.second; } |
| 404 | }; |
| 405 | |
| 406 | //===----------------------------------------------------------------------===// |
| 407 | // Qualifiers |
| 408 | //===----------------------------------------------------------------------===// |
| 409 | |
| 410 | /// An ordinal predicate consists of a "Question" and a set of acceptable |
| 411 | /// "Answers" (later converted to ordinal values). A predicate will query some |
| 412 | /// property of a positional value and decide what to do based on the result. |
| 413 | /// |
| 414 | /// This makes top-level predicate representations ordinal (SwitchOp). Later, |
| 415 | /// predicates that end up with only one acceptable answer (including all |
| 416 | /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the |
| 417 | /// matcher. |
| 418 | /// |
| 419 | /// For simplicity, both are represented as "qualifiers", with a base kind and |
| 420 | /// perhaps additional properties. For example, all OperationName predicates ask |
| 421 | /// the same question, but GenericConstraint predicates may ask different ones. |
| 422 | class Qualifier : public StorageUniquer::BaseStorage { |
| 423 | public: |
| 424 | explicit Qualifier(Predicates::Kind kind) : kind(kind) {} |
| 425 | |
| 426 | /// Returns the kind of this qualifier. |
| 427 | Predicates::Kind getKind() const { return kind; } |
| 428 | |
| 429 | private: |
| 430 | /// The kind of this position. |
| 431 | Predicates::Kind kind; |
| 432 | }; |
| 433 | |
| 434 | //===----------------------------------------------------------------------===// |
| 435 | // Answers |
| 436 | //===----------------------------------------------------------------------===// |
| 437 | |
| 438 | /// An Answer representing an `Attribute` value. |
| 439 | struct AttributeAnswer |
| 440 | : public PredicateBase<AttributeAnswer, Qualifier, Attribute, |
| 441 | Predicates::AttributeAnswer> { |
| 442 | using Base::Base; |
| 443 | }; |
| 444 | |
| 445 | /// An Answer representing an `OperationName` value. |
| 446 | struct OperationNameAnswer |
| 447 | : public PredicateBase<OperationNameAnswer, Qualifier, OperationName, |
| 448 | Predicates::OperationNameAnswer> { |
| 449 | using Base::Base; |
| 450 | }; |
| 451 | |
| 452 | /// An Answer representing a boolean `true` value. |
| 453 | struct TrueAnswer |
| 454 | : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> { |
| 455 | using Base::Base; |
| 456 | }; |
| 457 | |
| 458 | /// An Answer representing a boolean 'false' value. |
| 459 | struct FalseAnswer |
| 460 | : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> { |
| 461 | using Base::Base; |
| 462 | }; |
| 463 | |
| 464 | /// An Answer representing a `Type` value. The value is stored as either a |
| 465 | /// TypeAttr, or an ArrayAttr of TypeAttr. |
| 466 | struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute, |
| 467 | Predicates::TypeAnswer> { |
| 468 | using Base::Base; |
| 469 | }; |
| 470 | |
| 471 | /// An Answer representing an unsigned value. |
| 472 | struct UnsignedAnswer |
| 473 | : public PredicateBase<UnsignedAnswer, Qualifier, unsigned, |
| 474 | Predicates::UnsignedAnswer> { |
| 475 | using Base::Base; |
| 476 | }; |
| 477 | |
| 478 | //===----------------------------------------------------------------------===// |
| 479 | // Questions |
| 480 | //===----------------------------------------------------------------------===// |
| 481 | |
| 482 | /// Compare an `Attribute` to a constant value. |
| 483 | struct AttributeQuestion |
| 484 | : public PredicateBase<AttributeQuestion, Qualifier, void, |
| 485 | Predicates::AttributeQuestion> {}; |
| 486 | |
| 487 | /// Apply a parameterized constraint to multiple position values and possibly |
| 488 | /// produce results. |
| 489 | struct ConstraintQuestion |
| 490 | : public PredicateBase< |
| 491 | ConstraintQuestion, Qualifier, |
| 492 | std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>, |
| 493 | Predicates::ConstraintQuestion> { |
| 494 | using Base::Base; |
| 495 | |
| 496 | /// Return the name of the constraint. |
| 497 | StringRef getName() const { return std::get<0>(t: key); } |
| 498 | |
| 499 | /// Return the arguments of the constraint. |
| 500 | ArrayRef<Position *> getArgs() const { return std::get<1>(t: key); } |
| 501 | |
| 502 | /// Return the result types of the constraint. |
| 503 | ArrayRef<Type> getResultTypes() const { return std::get<2>(t: key); } |
| 504 | |
| 505 | /// Return the negation status of the constraint. |
| 506 | bool getIsNegated() const { return std::get<3>(t: key); } |
| 507 | |
| 508 | /// Construct an instance with the given storage allocator. |
| 509 | static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, |
| 510 | KeyTy key) { |
| 511 | return Base::construct(alloc, key: KeyTy{alloc.copyInto(str: std::get<0>(t&: key)), |
| 512 | alloc.copyInto(elements: std::get<1>(t&: key)), |
| 513 | alloc.copyInto(elements: std::get<2>(t&: key)), |
| 514 | std::get<3>(t&: key)}); |
| 515 | } |
| 516 | |
| 517 | /// Returns a hash suitable for the given keytype. |
| 518 | static llvm::hash_code hashKey(const KeyTy &key) { |
| 519 | return llvm::hash_value(arg: key); |
| 520 | } |
| 521 | }; |
| 522 | |
| 523 | /// Compare the equality of two values. |
| 524 | struct EqualToQuestion |
| 525 | : public PredicateBase<EqualToQuestion, Qualifier, Position *, |
| 526 | Predicates::EqualToQuestion> { |
| 527 | using Base::Base; |
| 528 | }; |
| 529 | |
| 530 | /// Compare a positional value with null, i.e. check if it exists. |
| 531 | struct IsNotNullQuestion |
| 532 | : public PredicateBase<IsNotNullQuestion, Qualifier, void, |
| 533 | Predicates::IsNotNullQuestion> {}; |
| 534 | |
| 535 | /// Compare the number of operands of an operation with a known value. |
| 536 | struct OperandCountQuestion |
| 537 | : public PredicateBase<OperandCountQuestion, Qualifier, void, |
| 538 | Predicates::OperandCountQuestion> {}; |
| 539 | struct OperandCountAtLeastQuestion |
| 540 | : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void, |
| 541 | Predicates::OperandCountAtLeastQuestion> {}; |
| 542 | |
| 543 | /// Compare the name of an operation with a known value. |
| 544 | struct OperationNameQuestion |
| 545 | : public PredicateBase<OperationNameQuestion, Qualifier, void, |
| 546 | Predicates::OperationNameQuestion> {}; |
| 547 | |
| 548 | /// Compare the number of results of an operation with a known value. |
| 549 | struct ResultCountQuestion |
| 550 | : public PredicateBase<ResultCountQuestion, Qualifier, void, |
| 551 | Predicates::ResultCountQuestion> {}; |
| 552 | struct ResultCountAtLeastQuestion |
| 553 | : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void, |
| 554 | Predicates::ResultCountAtLeastQuestion> {}; |
| 555 | |
| 556 | /// Compare the type of an attribute or value with a known type. |
| 557 | struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void, |
| 558 | Predicates::TypeQuestion> {}; |
| 559 | |
| 560 | //===----------------------------------------------------------------------===// |
| 561 | // PredicateUniquer |
| 562 | //===----------------------------------------------------------------------===// |
| 563 | |
| 564 | /// This class provides a storage uniquer that is used to allocate predicate |
| 565 | /// instances. |
| 566 | class PredicateUniquer : public StorageUniquer { |
| 567 | public: |
| 568 | PredicateUniquer() { |
| 569 | // Register the types of Positions with the uniquer. |
| 570 | registerParametricStorageType<AttributePosition>(); |
| 571 | registerParametricStorageType<AttributeLiteralPosition>(); |
| 572 | registerParametricStorageType<ConstraintPosition>(); |
| 573 | registerParametricStorageType<ForEachPosition>(); |
| 574 | registerParametricStorageType<OperandPosition>(); |
| 575 | registerParametricStorageType<OperandGroupPosition>(); |
| 576 | registerParametricStorageType<OperationPosition>(); |
| 577 | registerParametricStorageType<ResultPosition>(); |
| 578 | registerParametricStorageType<ResultGroupPosition>(); |
| 579 | registerParametricStorageType<TypePosition>(); |
| 580 | registerParametricStorageType<TypeLiteralPosition>(); |
| 581 | registerParametricStorageType<UsersPosition>(); |
| 582 | |
| 583 | // Register the types of Questions with the uniquer. |
| 584 | registerParametricStorageType<AttributeAnswer>(); |
| 585 | registerParametricStorageType<OperationNameAnswer>(); |
| 586 | registerParametricStorageType<TypeAnswer>(); |
| 587 | registerParametricStorageType<UnsignedAnswer>(); |
| 588 | registerSingletonStorageType<FalseAnswer>(); |
| 589 | registerSingletonStorageType<TrueAnswer>(); |
| 590 | |
| 591 | // Register the types of Answers with the uniquer. |
| 592 | registerParametricStorageType<ConstraintQuestion>(); |
| 593 | registerParametricStorageType<EqualToQuestion>(); |
| 594 | registerSingletonStorageType<AttributeQuestion>(); |
| 595 | registerSingletonStorageType<IsNotNullQuestion>(); |
| 596 | registerSingletonStorageType<OperandCountQuestion>(); |
| 597 | registerSingletonStorageType<OperandCountAtLeastQuestion>(); |
| 598 | registerSingletonStorageType<OperationNameQuestion>(); |
| 599 | registerSingletonStorageType<ResultCountQuestion>(); |
| 600 | registerSingletonStorageType<ResultCountAtLeastQuestion>(); |
| 601 | registerSingletonStorageType<TypeQuestion>(); |
| 602 | } |
| 603 | }; |
| 604 | |
| 605 | //===----------------------------------------------------------------------===// |
| 606 | // PredicateBuilder |
| 607 | //===----------------------------------------------------------------------===// |
| 608 | |
| 609 | /// This class provides utilities for constructing predicates. |
| 610 | class PredicateBuilder { |
| 611 | public: |
| 612 | PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) |
| 613 | : uniquer(uniquer), ctx(ctx) {} |
| 614 | |
| 615 | //===--------------------------------------------------------------------===// |
| 616 | // Positions |
| 617 | //===--------------------------------------------------------------------===// |
| 618 | |
| 619 | /// Returns the root operation position. |
| 620 | Position *getRoot() { return OperationPosition::getRoot(uniquer); } |
| 621 | |
| 622 | /// Returns the parent position defining the value held by the given operand. |
| 623 | OperationPosition *getOperandDefiningOp(Position *p) { |
| 624 | assert((isa<OperandPosition, OperandGroupPosition>(p)) && |
| 625 | "expected operand position" ); |
| 626 | return OperationPosition::get(uniquer, parent: p); |
| 627 | } |
| 628 | |
| 629 | /// Returns the operation position equivalent to the given position. |
| 630 | OperationPosition *getPassthroughOp(Position *p) { |
| 631 | assert((isa<ForEachPosition>(p)) && "expected users position" ); |
| 632 | return OperationPosition::get(uniquer, parent: p); |
| 633 | } |
| 634 | |
| 635 | // Returns a position for a new value created by a constraint. |
| 636 | ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, |
| 637 | unsigned index) { |
| 638 | return ConstraintPosition::get(uniquer, args: std::make_pair(x&: q, y&: index)); |
| 639 | } |
| 640 | |
| 641 | /// Returns an attribute position for an attribute of the given operation. |
| 642 | Position *getAttribute(OperationPosition *p, StringRef name) { |
| 643 | return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); |
| 644 | } |
| 645 | |
| 646 | /// Returns an attribute position for the given attribute. |
| 647 | Position *getAttributeLiteral(Attribute attr) { |
| 648 | return AttributeLiteralPosition::get(uniquer, args&: attr); |
| 649 | } |
| 650 | |
| 651 | Position *getForEach(Position *p, unsigned id) { |
| 652 | return ForEachPosition::get(uniquer, args&: p, args&: id); |
| 653 | } |
| 654 | |
| 655 | /// Returns an operand position for an operand of the given operation. |
| 656 | Position *getOperand(OperationPosition *p, unsigned operand) { |
| 657 | return OperandPosition::get(uniquer, args&: p, args&: operand); |
| 658 | } |
| 659 | |
| 660 | /// Returns a position for a group of operands of the given operation. |
| 661 | Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group, |
| 662 | bool isVariadic) { |
| 663 | return OperandGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic); |
| 664 | } |
| 665 | Position *getAllOperands(OperationPosition *p) { |
| 666 | return getOperandGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true); |
| 667 | } |
| 668 | |
| 669 | /// Returns a result position for a result of the given operation. |
| 670 | Position *getResult(OperationPosition *p, unsigned result) { |
| 671 | return ResultPosition::get(uniquer, args&: p, args&: result); |
| 672 | } |
| 673 | |
| 674 | /// Returns a position for a group of results of the given operation. |
| 675 | Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group, |
| 676 | bool isVariadic) { |
| 677 | return ResultGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic); |
| 678 | } |
| 679 | Position *getAllResults(OperationPosition *p) { |
| 680 | return getResultGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true); |
| 681 | } |
| 682 | |
| 683 | /// Returns a type position for the given entity. |
| 684 | Position *getType(Position *p) { return TypePosition::get(uniquer, args&: p); } |
| 685 | |
| 686 | /// Returns a type position for the given type value. The value is stored |
| 687 | /// as either a TypeAttr, or an ArrayAttr of TypeAttr. |
| 688 | Position *getTypeLiteral(Attribute attr) { |
| 689 | return TypeLiteralPosition::get(uniquer, args&: attr); |
| 690 | } |
| 691 | |
| 692 | /// Returns the users of a position using the value at the given operand. |
| 693 | UsersPosition *getUsers(Position *p, bool useRepresentative) { |
| 694 | assert((isa<OperandPosition, OperandGroupPosition, ResultPosition, |
| 695 | ResultGroupPosition>(p)) && |
| 696 | "expected result position" ); |
| 697 | return UsersPosition::get(uniquer, args&: p, args&: useRepresentative); |
| 698 | } |
| 699 | |
| 700 | //===--------------------------------------------------------------------===// |
| 701 | // Qualifiers |
| 702 | //===--------------------------------------------------------------------===// |
| 703 | |
/// An ordinal predicate consists of a "Question" and a set of acceptable
/// "Answers" (later converted to ordinal values). A predicate will query some
/// property of a positional value and decide what to do based on the result.
///
/// Represented as a pair in which `first` is the question qualifier and
/// `second` is the expected answer qualifier. Every builder below returns
/// `{question, answer}` in that order.
using Predicate = std::pair<Qualifier *, Qualifier *>;
| 708 | |
| 709 | /// Create a predicate comparing an attribute to a known value. |
| 710 | Predicate getAttributeConstraint(Attribute attr) { |
| 711 | return {AttributeQuestion::get(uniquer), |
| 712 | AttributeAnswer::get(uniquer, args&: attr)}; |
| 713 | } |
| 714 | |
| 715 | /// Create a predicate checking if two values are equal. |
| 716 | Predicate getEqualTo(Position *pos) { |
| 717 | return {EqualToQuestion::get(uniquer, args&: pos), TrueAnswer::get(uniquer)}; |
| 718 | } |
| 719 | |
| 720 | /// Create a predicate checking if two values are not equal. |
| 721 | Predicate getNotEqualTo(Position *pos) { |
| 722 | return {EqualToQuestion::get(uniquer, args&: pos), FalseAnswer::get(uniquer)}; |
| 723 | } |
| 724 | |
| 725 | /// Create a predicate that applies a generic constraint. |
| 726 | Predicate getConstraint(StringRef name, ArrayRef<Position *> args, |
| 727 | ArrayRef<Type> resultTypes, bool isNegated) { |
| 728 | return {ConstraintQuestion::get( |
| 729 | uniquer, args: std::make_tuple(args&: name, args&: args, args&: resultTypes, args&: isNegated)), |
| 730 | TrueAnswer::get(uniquer)}; |
| 731 | } |
| 732 | |
| 733 | /// Create a predicate comparing a value with null. |
| 734 | Predicate getIsNotNull() { |
| 735 | return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; |
| 736 | } |
| 737 | |
| 738 | /// Create a predicate comparing the number of operands of an operation to a |
| 739 | /// known value. |
| 740 | Predicate getOperandCount(unsigned count) { |
| 741 | return {OperandCountQuestion::get(uniquer), |
| 742 | UnsignedAnswer::get(uniquer, args&: count)}; |
| 743 | } |
| 744 | Predicate getOperandCountAtLeast(unsigned count) { |
| 745 | return {OperandCountAtLeastQuestion::get(uniquer), |
| 746 | UnsignedAnswer::get(uniquer, args&: count)}; |
| 747 | } |
| 748 | |
| 749 | /// Create a predicate comparing the name of an operation to a known value. |
| 750 | Predicate getOperationName(StringRef name) { |
| 751 | return {OperationNameQuestion::get(uniquer), |
| 752 | OperationNameAnswer::get(uniquer, args: OperationName(name, ctx))}; |
| 753 | } |
| 754 | |
| 755 | /// Create a predicate comparing the number of results of an operation to a |
| 756 | /// known value. |
| 757 | Predicate getResultCount(unsigned count) { |
| 758 | return {ResultCountQuestion::get(uniquer), |
| 759 | UnsignedAnswer::get(uniquer, args&: count)}; |
| 760 | } |
| 761 | Predicate getResultCountAtLeast(unsigned count) { |
| 762 | return {ResultCountAtLeastQuestion::get(uniquer), |
| 763 | UnsignedAnswer::get(uniquer, args&: count)}; |
| 764 | } |
| 765 | |
| 766 | /// Create a predicate comparing the type of an attribute or value to a known |
| 767 | /// type. The value is stored as either a TypeAttr, or an ArrayAttr of |
| 768 | /// TypeAttr. |
| 769 | Predicate getTypeConstraint(Attribute type) { |
| 770 | return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, args&: type)}; |
| 771 | } |
| 772 | |
private:
/// The uniquer used when allocating predicate nodes. Every position, question,
/// and answer built above is obtained through it. It is held by reference, so
/// the referenced uniquer must outlive this builder.
PredicateUniquer &uniquer;

/// The current MLIR context. getOperationName uses it to construct the
/// OperationName it expects.
MLIRContext *ctx;
| 779 | }; |
| 780 | |
| 781 | } // namespace pdl_to_pdl_interp |
| 782 | } // namespace mlir |
| 783 | |
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
| 785 | |