1//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions for "predicates" used when converting PDL into
10// a matcher tree. Predicates are composed of three different parts:
11//
12// * Positions
13// - A position refers to a specific location on the input DAG, i.e. an
14// existing MLIR entity being matched. These can be attributes, operands,
15// operations, results, and types. Each position also defines a relation to
16// its parent. For example, the operand `[0] -> 1` has a parent operation
17// position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
18// position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
19// `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
20// without a parent is `[0]`, which refers to the root operation.
21// * Questions
22// - A question refers to a query on a specific positional value. For
23// example, an operation name question checks the name of an operation
24// position.
25// * Answers
26// - An answer is the expected result of a question. For example, when
27// matching an operation with the name "foo.op". The question would be an
28// operation name question, with an expected answer of "foo.op".
29//
30//===----------------------------------------------------------------------===//
31
32#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
34
35#include "mlir/IR/MLIRContext.h"
36#include "mlir/IR/OperationSupport.h"
37#include "mlir/IR/PatternMatch.h"
38#include "mlir/IR/Types.h"
39
40namespace mlir {
41namespace pdl_to_pdl_interp {
namespace Predicates {
/// An enumeration of the kinds of predicates. The enumerators are grouped into
/// positions, questions, and answers; the relative order of the enumerators is
/// meaningful (see the group comments below) and must not be changed casually.
enum Kind : unsigned {
  // Positions, ordered by decreasing priority.
  OperationPos,
  OperandPos,
  OperandGroupPos,
  AttributePos,
  ConstraintResultPos,
  ResultPos,
  ResultGroupPos,
  TypePos,
  AttributeLiteralPos,
  TypeLiteralPos,
  UsersPos,
  ForEachPos,

  // Questions, ordered by dependency and decreasing priority.
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountAtLeastQuestion,
  OperandCountQuestion,
  ResultCountAtLeastQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  // Answers.
  AttributeAnswer,
  FalseAnswer,
  OperationNameAnswer,
  TrueAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // namespace Predicates
80
81/// Base class for all predicates, used to allow efficient pointer comparison.
82template <typename ConcreteT, typename BaseT, typename Key,
83 Predicates::Kind Kind>
84class PredicateBase : public BaseT {
85public:
86 using KeyTy = Key;
87 using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>;
88
89 template <typename KeyT>
90 explicit PredicateBase(KeyT &&key)
91 : BaseT(Kind), key(std::forward<KeyT>(key)) {}
92
93 /// Get an instance of this position.
94 template <typename... Args>
95 static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
96 return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
97 }
98
99 /// Construct an instance with the given storage allocator.
100 template <typename KeyT>
101 static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
102 KeyT &&key) {
103 return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
104 }
105
106 /// Utility methods required by the storage allocator.
107 bool operator==(const KeyTy &key) const { return this->key == key; }
108 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
109
110 /// Return the key value of this predicate.
111 const KeyTy &getValue() const { return key; }
112
113protected:
114 KeyTy key;
115};
116
117/// Base storage for simple predicates that only unique with the kind.
118template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
119class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
120public:
121 using Base = PredicateBase<ConcreteT, BaseT, void, Kind>;
122
123 explicit PredicateBase() : BaseT(Kind) {}
124
125 static ConcreteT *get(StorageUniquer &uniquer) {
126 return uniquer.get<ConcreteT>();
127 }
128 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
129};
130
131//===----------------------------------------------------------------------===//
132// Positions
133//===----------------------------------------------------------------------===//
134
135struct OperationPosition;
136
137/// A position describes a value on the input IR on which a predicate may be
138/// applied, such as an operation or attribute. This enables re-use between
139/// predicates, and assists generating bytecode and memory management.
140///
141/// Operation positions form the base of other positions, which are formed
142/// relative to a parent operation. Operations are anchored at Operand nodes,
143/// except for the root operation which is parentless.
144class Position : public StorageUniquer::BaseStorage {
145public:
146 explicit Position(Predicates::Kind kind) : kind(kind) {}
147 virtual ~Position();
148
149 /// Returns the depth of the first ancestor operation position.
150 unsigned getOperationDepth() const;
151
152 /// Returns the parent position. The root operation position has no parent.
153 Position *getParent() const { return parent; }
154
155 /// Returns the kind of this position.
156 Predicates::Kind getKind() const { return kind; }
157
158protected:
159 /// Link to the parent position.
160 Position *parent = nullptr;
161
162private:
163 /// The kind of this position.
164 Predicates::Kind kind;
165};
166
167//===----------------------------------------------------------------------===//
168// AttributePosition
169
170/// A position describing an attribute of an operation.
171struct AttributePosition
172 : public PredicateBase<AttributePosition, Position,
173 std::pair<OperationPosition *, StringAttr>,
174 Predicates::AttributePos> {
175 explicit AttributePosition(const KeyTy &key);
176
177 /// Returns the attribute name of this position.
178 StringAttr getName() const { return key.second; }
179};
180
181//===----------------------------------------------------------------------===//
182// AttributeLiteralPosition
183
184/// A position describing a literal attribute.
185struct AttributeLiteralPosition
186 : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
187 Predicates::AttributeLiteralPos> {
188 using PredicateBase::PredicateBase;
189};
190
191//===----------------------------------------------------------------------===//
192// ForEachPosition
193
194/// A position describing an iterative choice of an operation.
195struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
196 std::pair<Position *, unsigned>,
197 Predicates::ForEachPos> {
198 explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
199
200 /// Returns the ID, for differentiating various loops.
201 /// For upward traversals, this is the index of the root.
202 unsigned getID() const { return key.second; }
203};
204
205//===----------------------------------------------------------------------===//
206// OperandPosition
207
208/// A position describing an operand of an operation.
209struct OperandPosition
210 : public PredicateBase<OperandPosition, Position,
211 std::pair<OperationPosition *, unsigned>,
212 Predicates::OperandPos> {
213 explicit OperandPosition(const KeyTy &key);
214
215 /// Returns the operand number of this position.
216 unsigned getOperandNumber() const { return key.second; }
217};
218
219//===----------------------------------------------------------------------===//
220// OperandGroupPosition
221
222/// A position describing an operand group of an operation.
223struct OperandGroupPosition
224 : public PredicateBase<
225 OperandGroupPosition, Position,
226 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
227 Predicates::OperandGroupPos> {
228 explicit OperandGroupPosition(const KeyTy &key);
229
230 /// Returns a hash suitable for the given keytype.
231 static llvm::hash_code hashKey(const KeyTy &key) {
232 return llvm::hash_value(arg: key);
233 }
234
235 /// Returns the group number of this position. If std::nullopt, this group
236 /// refers to all operands.
237 std::optional<unsigned> getOperandGroupNumber() const {
238 return std::get<1>(t: key);
239 }
240
241 /// Returns if the operand group has unknown size. If false, the operand group
242 /// has at max one element.
243 bool isVariadic() const { return std::get<2>(t: key); }
244};
245
246//===----------------------------------------------------------------------===//
247// OperationPosition
248
249/// An operation position describes an operation node in the IR. Other position
250/// kinds are formed with respect to an operation position.
251struct OperationPosition : public PredicateBase<OperationPosition, Position,
252 std::pair<Position *, unsigned>,
253 Predicates::OperationPos> {
254 explicit OperationPosition(const KeyTy &key) : Base(key) {
255 parent = key.first;
256 }
257
258 /// Returns a hash suitable for the given keytype.
259 static llvm::hash_code hashKey(const KeyTy &key) {
260 return llvm::hash_value(arg: key);
261 }
262
263 /// Gets the root position.
264 static OperationPosition *getRoot(StorageUniquer &uniquer) {
265 return Base::get(uniquer, args: nullptr, args: 0);
266 }
267
268 /// Gets an operation position with the given parent.
269 static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
270 return Base::get(uniquer, args&: parent, args: parent->getOperationDepth() + 1);
271 }
272
273 /// Returns the depth of this position.
274 unsigned getDepth() const { return key.second; }
275
276 /// Returns if this operation position corresponds to the root.
277 bool isRoot() const { return getDepth() == 0; }
278
279 /// Returns if this operation represents an operand defining op.
280 bool isOperandDefiningOp() const;
281};
282
283//===----------------------------------------------------------------------===//
284// ConstraintPosition
285
286struct ConstraintQuestion;
287
288/// A position describing the result of a native constraint. It saves the
289/// corresponding ConstraintQuestion and result index to enable referring
290/// back to them
291struct ConstraintPosition
292 : public PredicateBase<ConstraintPosition, Position,
293 std::pair<ConstraintQuestion *, unsigned>,
294 Predicates::ConstraintResultPos> {
295 using PredicateBase::PredicateBase;
296
297 /// Returns the ConstraintQuestion to enable keeping track of the native
298 /// constraint this position stems from.
299 ConstraintQuestion *getQuestion() const { return key.first; }
300
301 // Returns the result index of this position
302 unsigned getIndex() const { return key.second; }
303};
304
305//===----------------------------------------------------------------------===//
306// ResultPosition
307
308/// A position describing a result of an operation.
309struct ResultPosition
310 : public PredicateBase<ResultPosition, Position,
311 std::pair<OperationPosition *, unsigned>,
312 Predicates::ResultPos> {
313 explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
314
315 /// Returns the result number of this position.
316 unsigned getResultNumber() const { return key.second; }
317};
318
319//===----------------------------------------------------------------------===//
320// ResultGroupPosition
321
322/// A position describing a result group of an operation.
323struct ResultGroupPosition
324 : public PredicateBase<
325 ResultGroupPosition, Position,
326 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
327 Predicates::ResultGroupPos> {
328 explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
329 parent = std::get<0>(t: key);
330 }
331
332 /// Returns a hash suitable for the given keytype.
333 static llvm::hash_code hashKey(const KeyTy &key) {
334 return llvm::hash_value(arg: key);
335 }
336
337 /// Returns the group number of this position. If std::nullopt, this group
338 /// refers to all results.
339 std::optional<unsigned> getResultGroupNumber() const {
340 return std::get<1>(t: key);
341 }
342
343 /// Returns if the result group has unknown size. If false, the result group
344 /// has at max one element.
345 bool isVariadic() const { return std::get<2>(t: key); }
346};
347
348//===----------------------------------------------------------------------===//
349// TypePosition
350
351/// A position describing the result type of an entity, i.e. an Attribute,
352/// Operand, Result, etc.
353struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
354 Predicates::TypePos> {
355 explicit TypePosition(const KeyTy &key) : Base(key) {
356 assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
357 ResultPosition, ResultGroupPosition>(key)) &&
358 "expected parent to be an attribute, operand, or result");
359 parent = key;
360 }
361};
362
363//===----------------------------------------------------------------------===//
364// TypeLiteralPosition
365
366/// A position describing a literal type or type range. The value is stored as
367/// either a TypeAttr, or an ArrayAttr of TypeAttr.
368struct TypeLiteralPosition
369 : public PredicateBase<TypeLiteralPosition, Position, Attribute,
370 Predicates::TypeLiteralPos> {
371 using PredicateBase::PredicateBase;
372};
373
374//===----------------------------------------------------------------------===//
375// UsersPosition
376
377/// A position describing the users of a value or a range of values. The second
378/// value in the key indicates whether we choose users of a representative for
379/// a range (this is true, e.g., in the upward traversals).
380struct UsersPosition
381 : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
382 Predicates::UsersPos> {
383 explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
384
385 /// Returns a hash suitable for the given keytype.
386 static llvm::hash_code hashKey(const KeyTy &key) {
387 return llvm::hash_value(arg: key);
388 }
389
390 /// Indicates whether to compute a range of a representative.
391 bool useRepresentative() const { return key.second; }
392};
393
394//===----------------------------------------------------------------------===//
395// Qualifiers
396//===----------------------------------------------------------------------===//
397
398/// An ordinal predicate consists of a "Question" and a set of acceptable
399/// "Answers" (later converted to ordinal values). A predicate will query some
400/// property of a positional value and decide what to do based on the result.
401///
402/// This makes top-level predicate representations ordinal (SwitchOp). Later,
403/// predicates that end up with only one acceptable answer (including all
404/// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
405/// matcher.
406///
407/// For simplicity, both are represented as "qualifiers", with a base kind and
408/// perhaps additional properties. For example, all OperationName predicates ask
409/// the same question, but GenericConstraint predicates may ask different ones.
410class Qualifier : public StorageUniquer::BaseStorage {
411public:
412 explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
413
414 /// Returns the kind of this qualifier.
415 Predicates::Kind getKind() const { return kind; }
416
417private:
418 /// The kind of this position.
419 Predicates::Kind kind;
420};
421
422//===----------------------------------------------------------------------===//
423// Answers
424
425/// An Answer representing an `Attribute` value.
426struct AttributeAnswer
427 : public PredicateBase<AttributeAnswer, Qualifier, Attribute,
428 Predicates::AttributeAnswer> {
429 using Base::Base;
430};
431
432/// An Answer representing an `OperationName` value.
433struct OperationNameAnswer
434 : public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
435 Predicates::OperationNameAnswer> {
436 using Base::Base;
437};
438
439/// An Answer representing a boolean `true` value.
440struct TrueAnswer
441 : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
442 using Base::Base;
443};
444
445/// An Answer representing a boolean 'false' value.
446struct FalseAnswer
447 : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
448 using Base::Base;
449};
450
451/// An Answer representing a `Type` value. The value is stored as either a
452/// TypeAttr, or an ArrayAttr of TypeAttr.
453struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
454 Predicates::TypeAnswer> {
455 using Base::Base;
456};
457
458/// An Answer representing an unsigned value.
459struct UnsignedAnswer
460 : public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
461 Predicates::UnsignedAnswer> {
462 using Base::Base;
463};
464
465//===----------------------------------------------------------------------===//
466// Questions
467
468/// Compare an `Attribute` to a constant value.
469struct AttributeQuestion
470 : public PredicateBase<AttributeQuestion, Qualifier, void,
471 Predicates::AttributeQuestion> {};
472
473/// Apply a parameterized constraint to multiple position values and possibly
474/// produce results.
475struct ConstraintQuestion
476 : public PredicateBase<
477 ConstraintQuestion, Qualifier,
478 std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479 Predicates::ConstraintQuestion> {
480 using Base::Base;
481
482 /// Return the name of the constraint.
483 StringRef getName() const { return std::get<0>(t: key); }
484
485 /// Return the arguments of the constraint.
486 ArrayRef<Position *> getArgs() const { return std::get<1>(t: key); }
487
488 /// Return the result types of the constraint.
489 ArrayRef<Type> getResultTypes() const { return std::get<2>(t: key); }
490
491 /// Return the negation status of the constraint.
492 bool getIsNegated() const { return std::get<3>(t: key); }
493
494 /// Construct an instance with the given storage allocator.
495 static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
496 KeyTy key) {
497 return Base::construct(alloc, key: KeyTy{alloc.copyInto(str: std::get<0>(t&: key)),
498 alloc.copyInto(elements: std::get<1>(t&: key)),
499 alloc.copyInto(elements: std::get<2>(t&: key)),
500 std::get<3>(t&: key)});
501 }
502
503 /// Returns a hash suitable for the given keytype.
504 static llvm::hash_code hashKey(const KeyTy &key) {
505 return llvm::hash_value(arg: key);
506 }
507};
508
509/// Compare the equality of two values.
510struct EqualToQuestion
511 : public PredicateBase<EqualToQuestion, Qualifier, Position *,
512 Predicates::EqualToQuestion> {
513 using Base::Base;
514};
515
516/// Compare a positional value with null, i.e. check if it exists.
517struct IsNotNullQuestion
518 : public PredicateBase<IsNotNullQuestion, Qualifier, void,
519 Predicates::IsNotNullQuestion> {};
520
521/// Compare the number of operands of an operation with a known value.
522struct OperandCountQuestion
523 : public PredicateBase<OperandCountQuestion, Qualifier, void,
524 Predicates::OperandCountQuestion> {};
525struct OperandCountAtLeastQuestion
526 : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
527 Predicates::OperandCountAtLeastQuestion> {};
528
529/// Compare the name of an operation with a known value.
530struct OperationNameQuestion
531 : public PredicateBase<OperationNameQuestion, Qualifier, void,
532 Predicates::OperationNameQuestion> {};
533
534/// Compare the number of results of an operation with a known value.
535struct ResultCountQuestion
536 : public PredicateBase<ResultCountQuestion, Qualifier, void,
537 Predicates::ResultCountQuestion> {};
538struct ResultCountAtLeastQuestion
539 : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
540 Predicates::ResultCountAtLeastQuestion> {};
541
542/// Compare the type of an attribute or value with a known type.
543struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
544 Predicates::TypeQuestion> {};
545
546//===----------------------------------------------------------------------===//
547// PredicateUniquer
548//===----------------------------------------------------------------------===//
549
/// This class provides a storage uniquer that is used to allocate predicate
/// instances. Its constructor registers every position, answer, and question
/// type declared in this file: keyed predicates as parametric storage types,
/// key-less predicates as singleton storage types.
class PredicateUniquer : public StorageUniquer {
public:
  PredicateUniquer() {
    // Register the types of Positions with the uniquer.
    registerParametricStorageType<AttributePosition>();
    registerParametricStorageType<AttributeLiteralPosition>();
    registerParametricStorageType<ConstraintPosition>();
    registerParametricStorageType<ForEachPosition>();
    registerParametricStorageType<OperandPosition>();
    registerParametricStorageType<OperandGroupPosition>();
    registerParametricStorageType<OperationPosition>();
    registerParametricStorageType<ResultPosition>();
    registerParametricStorageType<ResultGroupPosition>();
    registerParametricStorageType<TypePosition>();
    registerParametricStorageType<TypeLiteralPosition>();
    registerParametricStorageType<UsersPosition>();

    // Register the types of Answers with the uniquer.
    registerParametricStorageType<AttributeAnswer>();
    registerParametricStorageType<OperationNameAnswer>();
    registerParametricStorageType<TypeAnswer>();
    registerParametricStorageType<UnsignedAnswer>();
    registerSingletonStorageType<FalseAnswer>();
    registerSingletonStorageType<TrueAnswer>();

    // Register the types of Questions with the uniquer.
    registerParametricStorageType<ConstraintQuestion>();
    registerParametricStorageType<EqualToQuestion>();
    registerSingletonStorageType<AttributeQuestion>();
    registerSingletonStorageType<IsNotNullQuestion>();
    registerSingletonStorageType<OperandCountQuestion>();
    registerSingletonStorageType<OperandCountAtLeastQuestion>();
    registerSingletonStorageType<OperationNameQuestion>();
    registerSingletonStorageType<ResultCountQuestion>();
    registerSingletonStorageType<ResultCountAtLeastQuestion>();
    registerSingletonStorageType<TypeQuestion>();
  }
};
590
591//===----------------------------------------------------------------------===//
592// PredicateBuilder
593//===----------------------------------------------------------------------===//
594
595/// This class provides utilities for constructing predicates.
596class PredicateBuilder {
597public:
598 PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
599 : uniquer(uniquer), ctx(ctx) {}
600
601 //===--------------------------------------------------------------------===//
602 // Positions
603 //===--------------------------------------------------------------------===//
604
605 /// Returns the root operation position.
606 Position *getRoot() { return OperationPosition::getRoot(uniquer); }
607
608 /// Returns the parent position defining the value held by the given operand.
609 OperationPosition *getOperandDefiningOp(Position *p) {
610 assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
611 "expected operand position");
612 return OperationPosition::get(uniquer, parent: p);
613 }
614
615 /// Returns the operation position equivalent to the given position.
616 OperationPosition *getPassthroughOp(Position *p) {
617 assert((isa<ForEachPosition>(p)) && "expected users position");
618 return OperationPosition::get(uniquer, parent: p);
619 }
620
621 // Returns a position for a new value created by a constraint.
622 ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623 unsigned index) {
624 return ConstraintPosition::get(uniquer, args: std::make_pair(x&: q, y&: index));
625 }
626
627 /// Returns an attribute position for an attribute of the given operation.
628 Position *getAttribute(OperationPosition *p, StringRef name) {
629 return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
630 }
631
632 /// Returns an attribute position for the given attribute.
633 Position *getAttributeLiteral(Attribute attr) {
634 return AttributeLiteralPosition::get(uniquer, args&: attr);
635 }
636
637 Position *getForEach(Position *p, unsigned id) {
638 return ForEachPosition::get(uniquer, args&: p, args&: id);
639 }
640
641 /// Returns an operand position for an operand of the given operation.
642 Position *getOperand(OperationPosition *p, unsigned operand) {
643 return OperandPosition::get(uniquer, args&: p, args&: operand);
644 }
645
646 /// Returns a position for a group of operands of the given operation.
647 Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group,
648 bool isVariadic) {
649 return OperandGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic);
650 }
651 Position *getAllOperands(OperationPosition *p) {
652 return getOperandGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true);
653 }
654
655 /// Returns a result position for a result of the given operation.
656 Position *getResult(OperationPosition *p, unsigned result) {
657 return ResultPosition::get(uniquer, args&: p, args&: result);
658 }
659
660 /// Returns a position for a group of results of the given operation.
661 Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group,
662 bool isVariadic) {
663 return ResultGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic);
664 }
665 Position *getAllResults(OperationPosition *p) {
666 return getResultGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true);
667 }
668
669 /// Returns a type position for the given entity.
670 Position *getType(Position *p) { return TypePosition::get(uniquer, args&: p); }
671
672 /// Returns a type position for the given type value. The value is stored
673 /// as either a TypeAttr, or an ArrayAttr of TypeAttr.
674 Position *getTypeLiteral(Attribute attr) {
675 return TypeLiteralPosition::get(uniquer, args&: attr);
676 }
677
678 /// Returns the users of a position using the value at the given operand.
679 UsersPosition *getUsers(Position *p, bool useRepresentative) {
680 assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
681 ResultGroupPosition>(p)) &&
682 "expected result position");
683 return UsersPosition::get(uniquer, args&: p, args&: useRepresentative);
684 }
685
686 //===--------------------------------------------------------------------===//
687 // Qualifiers
688 //===--------------------------------------------------------------------===//
689
690 /// An ordinal predicate consists of a "Question" and a set of acceptable
691 /// "Answers" (later converted to ordinal values). A predicate will query some
692 /// property of a positional value and decide what to do based on the result.
693 using Predicate = std::pair<Qualifier *, Qualifier *>;
694
695 /// Create a predicate comparing an attribute to a known value.
696 Predicate getAttributeConstraint(Attribute attr) {
697 return {AttributeQuestion::get(uniquer),
698 AttributeAnswer::get(uniquer, args&: attr)};
699 }
700
701 /// Create a predicate checking if two values are equal.
702 Predicate getEqualTo(Position *pos) {
703 return {EqualToQuestion::get(uniquer, args&: pos), TrueAnswer::get(uniquer)};
704 }
705
706 /// Create a predicate checking if two values are not equal.
707 Predicate getNotEqualTo(Position *pos) {
708 return {EqualToQuestion::get(uniquer, args&: pos), FalseAnswer::get(uniquer)};
709 }
710
711 /// Create a predicate that applies a generic constraint.
712 Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713 ArrayRef<Type> resultTypes, bool isNegated) {
714 return {ConstraintQuestion::get(
715 uniquer, args: std::make_tuple(args&: name, args&: args, args&: resultTypes, args&: isNegated)),
716 TrueAnswer::get(uniquer)};
717 }
718
719 /// Create a predicate comparing a value with null.
720 Predicate getIsNotNull() {
721 return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
722 }
723
724 /// Create a predicate comparing the number of operands of an operation to a
725 /// known value.
726 Predicate getOperandCount(unsigned count) {
727 return {OperandCountQuestion::get(uniquer),
728 UnsignedAnswer::get(uniquer, args&: count)};
729 }
730 Predicate getOperandCountAtLeast(unsigned count) {
731 return {OperandCountAtLeastQuestion::get(uniquer),
732 UnsignedAnswer::get(uniquer, args&: count)};
733 }
734
735 /// Create a predicate comparing the name of an operation to a known value.
736 Predicate getOperationName(StringRef name) {
737 return {OperationNameQuestion::get(uniquer),
738 OperationNameAnswer::get(uniquer, args: OperationName(name, ctx))};
739 }
740
741 /// Create a predicate comparing the number of results of an operation to a
742 /// known value.
743 Predicate getResultCount(unsigned count) {
744 return {ResultCountQuestion::get(uniquer),
745 UnsignedAnswer::get(uniquer, args&: count)};
746 }
747 Predicate getResultCountAtLeast(unsigned count) {
748 return {ResultCountAtLeastQuestion::get(uniquer),
749 UnsignedAnswer::get(uniquer, args&: count)};
750 }
751
752 /// Create a predicate comparing the type of an attribute or value to a known
753 /// type. The value is stored as either a TypeAttr, or an ArrayAttr of
754 /// TypeAttr.
755 Predicate getTypeConstraint(Attribute type) {
756 return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, args&: type)};
757 }
758
759private:
760 /// The uniquer used when allocating predicate nodes.
761 PredicateUniquer &uniquer;
762
763 /// The current MLIR context.
764 MLIRContext *ctx;
765};
766
767} // namespace pdl_to_pdl_interp
768} // namespace mlir
769
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
771

// Source: mlir/lib/Conversion/PDLToPDLInterp/Predicate.h