1 | //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains definitions for "predicates" used when converting PDL into |
10 | // a matcher tree. Predicates are composed of three different parts: |
11 | // |
12 | // * Positions |
13 | // - A position refers to a specific location on the input DAG, i.e. an |
14 | // existing MLIR entity being matched. These can be attributes, operands, |
15 | // operations, results, and types. Each position also defines a relation to |
16 | // its parent. For example, the operand `[0] -> 1` has a parent operation |
17 | // position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation |
18 | // position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge |
19 | // `[0] -> 1` (i.e. it is the defining op of operand 1). The only position |
20 | // without a parent is `[0]`, which refers to the root operation. |
21 | // * Questions |
22 | // - A question refers to a query on a specific positional value. For |
23 | // example, an operation name question checks the name of an operation |
24 | // position. |
25 | // * Answers |
//   - An answer is the expected result of a question. For example, when
//     matching an operation with the name "foo.op", the question would be an
//     operation name question, with an expected answer of "foo.op".
29 | // |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ |
33 | #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ |
34 | |
35 | #include "mlir/IR/MLIRContext.h" |
36 | #include "mlir/IR/OperationSupport.h" |
37 | #include "mlir/IR/PatternMatch.h" |
38 | #include "mlir/IR/Types.h" |
39 | |
40 | namespace mlir { |
41 | namespace pdl_to_pdl_interp { |
namespace Predicates {
/// Discriminator stored in every predicate node (position, question, or
/// answer). The enumerators form three contiguous groups; the relative order
/// inside the first two groups is meaningful (see the group comments), so do
/// not reorder them casually.
enum Kind : unsigned {
  //--- Positions, ordered by decreasing priority. ---------------------------
  OperationPos,        // OperationPosition
  OperandPos,          // OperandPosition
  OperandGroupPos,     // OperandGroupPosition
  AttributePos,        // AttributePosition
  ConstraintResultPos, // ConstraintPosition
  ResultPos,           // ResultPosition
  ResultGroupPos,      // ResultGroupPosition
  TypePos,             // TypePosition
  AttributeLiteralPos, // AttributeLiteralPosition
  TypeLiteralPos,      // TypeLiteralPosition
  UsersPos,            // UsersPosition
  ForEachPos,          // ForEachPosition

  //--- Questions, ordered by dependency and decreasing priority. ------------
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountAtLeastQuestion,
  OperandCountQuestion,
  ResultCountAtLeastQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  //--- Answers. --------------------------------------------------------------
  AttributeAnswer,
  FalseAnswer,
  OperationNameAnswer,
  TrueAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // namespace Predicates
80 | |
81 | /// Base class for all predicates, used to allow efficient pointer comparison. |
82 | template <typename ConcreteT, typename BaseT, typename Key, |
83 | Predicates::Kind Kind> |
84 | class PredicateBase : public BaseT { |
85 | public: |
86 | using KeyTy = Key; |
87 | using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>; |
88 | |
89 | template <typename KeyT> |
90 | explicit PredicateBase(KeyT &&key) |
91 | : BaseT(Kind), key(std::forward<KeyT>(key)) {} |
92 | |
93 | /// Get an instance of this position. |
94 | template <typename... Args> |
95 | static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { |
96 | return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...); |
97 | } |
98 | |
99 | /// Construct an instance with the given storage allocator. |
100 | template <typename KeyT> |
101 | static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, |
102 | KeyT &&key) { |
103 | return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key)); |
104 | } |
105 | |
106 | /// Utility methods required by the storage allocator. |
107 | bool operator==(const KeyTy &key) const { return this->key == key; } |
108 | static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } |
109 | |
110 | /// Return the key value of this predicate. |
111 | const KeyTy &getValue() const { return key; } |
112 | |
113 | protected: |
114 | KeyTy key; |
115 | }; |
116 | |
117 | /// Base storage for simple predicates that only unique with the kind. |
118 | template <typename ConcreteT, typename BaseT, Predicates::Kind Kind> |
119 | class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT { |
120 | public: |
121 | using Base = PredicateBase<ConcreteT, BaseT, void, Kind>; |
122 | |
123 | explicit PredicateBase() : BaseT(Kind) {} |
124 | |
125 | static ConcreteT *get(StorageUniquer &uniquer) { |
126 | return uniquer.get<ConcreteT>(); |
127 | } |
128 | static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } |
129 | }; |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // Positions |
133 | //===----------------------------------------------------------------------===// |
134 | |
135 | struct OperationPosition; |
136 | |
137 | /// A position describes a value on the input IR on which a predicate may be |
138 | /// applied, such as an operation or attribute. This enables re-use between |
139 | /// predicates, and assists generating bytecode and memory management. |
140 | /// |
141 | /// Operation positions form the base of other positions, which are formed |
142 | /// relative to a parent operation. Operations are anchored at Operand nodes, |
143 | /// except for the root operation which is parentless. |
144 | class Position : public StorageUniquer::BaseStorage { |
145 | public: |
146 | explicit Position(Predicates::Kind kind) : kind(kind) {} |
147 | virtual ~Position(); |
148 | |
149 | /// Returns the depth of the first ancestor operation position. |
150 | unsigned getOperationDepth() const; |
151 | |
152 | /// Returns the parent position. The root operation position has no parent. |
153 | Position *getParent() const { return parent; } |
154 | |
155 | /// Returns the kind of this position. |
156 | Predicates::Kind getKind() const { return kind; } |
157 | |
158 | protected: |
159 | /// Link to the parent position. |
160 | Position *parent = nullptr; |
161 | |
162 | private: |
163 | /// The kind of this position. |
164 | Predicates::Kind kind; |
165 | }; |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // AttributePosition |
169 | |
170 | /// A position describing an attribute of an operation. |
171 | struct AttributePosition |
172 | : public PredicateBase<AttributePosition, Position, |
173 | std::pair<OperationPosition *, StringAttr>, |
174 | Predicates::AttributePos> { |
175 | explicit AttributePosition(const KeyTy &key); |
176 | |
177 | /// Returns the attribute name of this position. |
178 | StringAttr getName() const { return key.second; } |
179 | }; |
180 | |
181 | //===----------------------------------------------------------------------===// |
182 | // AttributeLiteralPosition |
183 | |
184 | /// A position describing a literal attribute. |
185 | struct AttributeLiteralPosition |
186 | : public PredicateBase<AttributeLiteralPosition, Position, Attribute, |
187 | Predicates::AttributeLiteralPos> { |
188 | using PredicateBase::PredicateBase; |
189 | }; |
190 | |
191 | //===----------------------------------------------------------------------===// |
192 | // ForEachPosition |
193 | |
194 | /// A position describing an iterative choice of an operation. |
195 | struct ForEachPosition : public PredicateBase<ForEachPosition, Position, |
196 | std::pair<Position *, unsigned>, |
197 | Predicates::ForEachPos> { |
198 | explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
199 | |
200 | /// Returns the ID, for differentiating various loops. |
201 | /// For upward traversals, this is the index of the root. |
202 | unsigned getID() const { return key.second; } |
203 | }; |
204 | |
205 | //===----------------------------------------------------------------------===// |
206 | // OperandPosition |
207 | |
208 | /// A position describing an operand of an operation. |
209 | struct OperandPosition |
210 | : public PredicateBase<OperandPosition, Position, |
211 | std::pair<OperationPosition *, unsigned>, |
212 | Predicates::OperandPos> { |
213 | explicit OperandPosition(const KeyTy &key); |
214 | |
215 | /// Returns the operand number of this position. |
216 | unsigned getOperandNumber() const { return key.second; } |
217 | }; |
218 | |
219 | //===----------------------------------------------------------------------===// |
220 | // OperandGroupPosition |
221 | |
222 | /// A position describing an operand group of an operation. |
223 | struct OperandGroupPosition |
224 | : public PredicateBase< |
225 | OperandGroupPosition, Position, |
226 | std::tuple<OperationPosition *, std::optional<unsigned>, bool>, |
227 | Predicates::OperandGroupPos> { |
228 | explicit OperandGroupPosition(const KeyTy &key); |
229 | |
230 | /// Returns a hash suitable for the given keytype. |
231 | static llvm::hash_code hashKey(const KeyTy &key) { |
232 | return llvm::hash_value(arg: key); |
233 | } |
234 | |
235 | /// Returns the group number of this position. If std::nullopt, this group |
236 | /// refers to all operands. |
237 | std::optional<unsigned> getOperandGroupNumber() const { |
238 | return std::get<1>(t: key); |
239 | } |
240 | |
241 | /// Returns if the operand group has unknown size. If false, the operand group |
242 | /// has at max one element. |
243 | bool isVariadic() const { return std::get<2>(t: key); } |
244 | }; |
245 | |
246 | //===----------------------------------------------------------------------===// |
247 | // OperationPosition |
248 | |
249 | /// An operation position describes an operation node in the IR. Other position |
250 | /// kinds are formed with respect to an operation position. |
251 | struct OperationPosition : public PredicateBase<OperationPosition, Position, |
252 | std::pair<Position *, unsigned>, |
253 | Predicates::OperationPos> { |
254 | explicit OperationPosition(const KeyTy &key) : Base(key) { |
255 | parent = key.first; |
256 | } |
257 | |
258 | /// Returns a hash suitable for the given keytype. |
259 | static llvm::hash_code hashKey(const KeyTy &key) { |
260 | return llvm::hash_value(arg: key); |
261 | } |
262 | |
263 | /// Gets the root position. |
264 | static OperationPosition *getRoot(StorageUniquer &uniquer) { |
265 | return Base::get(uniquer, args: nullptr, args: 0); |
266 | } |
267 | |
268 | /// Gets an operation position with the given parent. |
269 | static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { |
270 | return Base::get(uniquer, args&: parent, args: parent->getOperationDepth() + 1); |
271 | } |
272 | |
273 | /// Returns the depth of this position. |
274 | unsigned getDepth() const { return key.second; } |
275 | |
276 | /// Returns if this operation position corresponds to the root. |
277 | bool isRoot() const { return getDepth() == 0; } |
278 | |
279 | /// Returns if this operation represents an operand defining op. |
280 | bool isOperandDefiningOp() const; |
281 | }; |
282 | |
283 | //===----------------------------------------------------------------------===// |
284 | // ConstraintPosition |
285 | |
286 | struct ConstraintQuestion; |
287 | |
288 | /// A position describing the result of a native constraint. It saves the |
289 | /// corresponding ConstraintQuestion and result index to enable referring |
290 | /// back to them |
291 | struct ConstraintPosition |
292 | : public PredicateBase<ConstraintPosition, Position, |
293 | std::pair<ConstraintQuestion *, unsigned>, |
294 | Predicates::ConstraintResultPos> { |
295 | using PredicateBase::PredicateBase; |
296 | |
297 | /// Returns the ConstraintQuestion to enable keeping track of the native |
298 | /// constraint this position stems from. |
299 | ConstraintQuestion *getQuestion() const { return key.first; } |
300 | |
301 | // Returns the result index of this position |
302 | unsigned getIndex() const { return key.second; } |
303 | }; |
304 | |
305 | //===----------------------------------------------------------------------===// |
306 | // ResultPosition |
307 | |
308 | /// A position describing a result of an operation. |
309 | struct ResultPosition |
310 | : public PredicateBase<ResultPosition, Position, |
311 | std::pair<OperationPosition *, unsigned>, |
312 | Predicates::ResultPos> { |
313 | explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
314 | |
315 | /// Returns the result number of this position. |
316 | unsigned getResultNumber() const { return key.second; } |
317 | }; |
318 | |
319 | //===----------------------------------------------------------------------===// |
320 | // ResultGroupPosition |
321 | |
322 | /// A position describing a result group of an operation. |
323 | struct ResultGroupPosition |
324 | : public PredicateBase< |
325 | ResultGroupPosition, Position, |
326 | std::tuple<OperationPosition *, std::optional<unsigned>, bool>, |
327 | Predicates::ResultGroupPos> { |
328 | explicit ResultGroupPosition(const KeyTy &key) : Base(key) { |
329 | parent = std::get<0>(t: key); |
330 | } |
331 | |
332 | /// Returns a hash suitable for the given keytype. |
333 | static llvm::hash_code hashKey(const KeyTy &key) { |
334 | return llvm::hash_value(arg: key); |
335 | } |
336 | |
337 | /// Returns the group number of this position. If std::nullopt, this group |
338 | /// refers to all results. |
339 | std::optional<unsigned> getResultGroupNumber() const { |
340 | return std::get<1>(t: key); |
341 | } |
342 | |
343 | /// Returns if the result group has unknown size. If false, the result group |
344 | /// has at max one element. |
345 | bool isVariadic() const { return std::get<2>(t: key); } |
346 | }; |
347 | |
348 | //===----------------------------------------------------------------------===// |
349 | // TypePosition |
350 | |
351 | /// A position describing the result type of an entity, i.e. an Attribute, |
352 | /// Operand, Result, etc. |
353 | struct TypePosition : public PredicateBase<TypePosition, Position, Position *, |
354 | Predicates::TypePos> { |
355 | explicit TypePosition(const KeyTy &key) : Base(key) { |
356 | assert((isa<AttributePosition, OperandPosition, OperandGroupPosition, |
357 | ResultPosition, ResultGroupPosition>(key)) && |
358 | "expected parent to be an attribute, operand, or result" ); |
359 | parent = key; |
360 | } |
361 | }; |
362 | |
363 | //===----------------------------------------------------------------------===// |
364 | // TypeLiteralPosition |
365 | |
366 | /// A position describing a literal type or type range. The value is stored as |
367 | /// either a TypeAttr, or an ArrayAttr of TypeAttr. |
368 | struct TypeLiteralPosition |
369 | : public PredicateBase<TypeLiteralPosition, Position, Attribute, |
370 | Predicates::TypeLiteralPos> { |
371 | using PredicateBase::PredicateBase; |
372 | }; |
373 | |
374 | //===----------------------------------------------------------------------===// |
375 | // UsersPosition |
376 | |
377 | /// A position describing the users of a value or a range of values. The second |
378 | /// value in the key indicates whether we choose users of a representative for |
379 | /// a range (this is true, e.g., in the upward traversals). |
380 | struct UsersPosition |
381 | : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>, |
382 | Predicates::UsersPos> { |
383 | explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
384 | |
385 | /// Returns a hash suitable for the given keytype. |
386 | static llvm::hash_code hashKey(const KeyTy &key) { |
387 | return llvm::hash_value(arg: key); |
388 | } |
389 | |
390 | /// Indicates whether to compute a range of a representative. |
391 | bool useRepresentative() const { return key.second; } |
392 | }; |
393 | |
394 | //===----------------------------------------------------------------------===// |
395 | // Qualifiers |
396 | //===----------------------------------------------------------------------===// |
397 | |
398 | /// An ordinal predicate consists of a "Question" and a set of acceptable |
399 | /// "Answers" (later converted to ordinal values). A predicate will query some |
400 | /// property of a positional value and decide what to do based on the result. |
401 | /// |
402 | /// This makes top-level predicate representations ordinal (SwitchOp). Later, |
403 | /// predicates that end up with only one acceptable answer (including all |
404 | /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the |
405 | /// matcher. |
406 | /// |
407 | /// For simplicity, both are represented as "qualifiers", with a base kind and |
408 | /// perhaps additional properties. For example, all OperationName predicates ask |
409 | /// the same question, but GenericConstraint predicates may ask different ones. |
410 | class Qualifier : public StorageUniquer::BaseStorage { |
411 | public: |
412 | explicit Qualifier(Predicates::Kind kind) : kind(kind) {} |
413 | |
414 | /// Returns the kind of this qualifier. |
415 | Predicates::Kind getKind() const { return kind; } |
416 | |
417 | private: |
418 | /// The kind of this position. |
419 | Predicates::Kind kind; |
420 | }; |
421 | |
422 | //===----------------------------------------------------------------------===// |
423 | // Answers |
424 | |
425 | /// An Answer representing an `Attribute` value. |
426 | struct AttributeAnswer |
427 | : public PredicateBase<AttributeAnswer, Qualifier, Attribute, |
428 | Predicates::AttributeAnswer> { |
429 | using Base::Base; |
430 | }; |
431 | |
432 | /// An Answer representing an `OperationName` value. |
433 | struct OperationNameAnswer |
434 | : public PredicateBase<OperationNameAnswer, Qualifier, OperationName, |
435 | Predicates::OperationNameAnswer> { |
436 | using Base::Base; |
437 | }; |
438 | |
439 | /// An Answer representing a boolean `true` value. |
440 | struct TrueAnswer |
441 | : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> { |
442 | using Base::Base; |
443 | }; |
444 | |
445 | /// An Answer representing a boolean 'false' value. |
446 | struct FalseAnswer |
447 | : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> { |
448 | using Base::Base; |
449 | }; |
450 | |
451 | /// An Answer representing a `Type` value. The value is stored as either a |
452 | /// TypeAttr, or an ArrayAttr of TypeAttr. |
453 | struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute, |
454 | Predicates::TypeAnswer> { |
455 | using Base::Base; |
456 | }; |
457 | |
458 | /// An Answer representing an unsigned value. |
459 | struct UnsignedAnswer |
460 | : public PredicateBase<UnsignedAnswer, Qualifier, unsigned, |
461 | Predicates::UnsignedAnswer> { |
462 | using Base::Base; |
463 | }; |
464 | |
465 | //===----------------------------------------------------------------------===// |
466 | // Questions |
467 | |
468 | /// Compare an `Attribute` to a constant value. |
469 | struct AttributeQuestion |
470 | : public PredicateBase<AttributeQuestion, Qualifier, void, |
471 | Predicates::AttributeQuestion> {}; |
472 | |
473 | /// Apply a parameterized constraint to multiple position values and possibly |
474 | /// produce results. |
475 | struct ConstraintQuestion |
476 | : public PredicateBase< |
477 | ConstraintQuestion, Qualifier, |
478 | std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>, |
479 | Predicates::ConstraintQuestion> { |
480 | using Base::Base; |
481 | |
482 | /// Return the name of the constraint. |
483 | StringRef getName() const { return std::get<0>(t: key); } |
484 | |
485 | /// Return the arguments of the constraint. |
486 | ArrayRef<Position *> getArgs() const { return std::get<1>(t: key); } |
487 | |
488 | /// Return the result types of the constraint. |
489 | ArrayRef<Type> getResultTypes() const { return std::get<2>(t: key); } |
490 | |
491 | /// Return the negation status of the constraint. |
492 | bool getIsNegated() const { return std::get<3>(t: key); } |
493 | |
494 | /// Construct an instance with the given storage allocator. |
495 | static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, |
496 | KeyTy key) { |
497 | return Base::construct(alloc, key: KeyTy{alloc.copyInto(str: std::get<0>(t&: key)), |
498 | alloc.copyInto(elements: std::get<1>(t&: key)), |
499 | alloc.copyInto(elements: std::get<2>(t&: key)), |
500 | std::get<3>(t&: key)}); |
501 | } |
502 | |
503 | /// Returns a hash suitable for the given keytype. |
504 | static llvm::hash_code hashKey(const KeyTy &key) { |
505 | return llvm::hash_value(arg: key); |
506 | } |
507 | }; |
508 | |
509 | /// Compare the equality of two values. |
510 | struct EqualToQuestion |
511 | : public PredicateBase<EqualToQuestion, Qualifier, Position *, |
512 | Predicates::EqualToQuestion> { |
513 | using Base::Base; |
514 | }; |
515 | |
516 | /// Compare a positional value with null, i.e. check if it exists. |
517 | struct IsNotNullQuestion |
518 | : public PredicateBase<IsNotNullQuestion, Qualifier, void, |
519 | Predicates::IsNotNullQuestion> {}; |
520 | |
521 | /// Compare the number of operands of an operation with a known value. |
522 | struct OperandCountQuestion |
523 | : public PredicateBase<OperandCountQuestion, Qualifier, void, |
524 | Predicates::OperandCountQuestion> {}; |
525 | struct OperandCountAtLeastQuestion |
526 | : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void, |
527 | Predicates::OperandCountAtLeastQuestion> {}; |
528 | |
529 | /// Compare the name of an operation with a known value. |
530 | struct OperationNameQuestion |
531 | : public PredicateBase<OperationNameQuestion, Qualifier, void, |
532 | Predicates::OperationNameQuestion> {}; |
533 | |
534 | /// Compare the number of results of an operation with a known value. |
535 | struct ResultCountQuestion |
536 | : public PredicateBase<ResultCountQuestion, Qualifier, void, |
537 | Predicates::ResultCountQuestion> {}; |
538 | struct ResultCountAtLeastQuestion |
539 | : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void, |
540 | Predicates::ResultCountAtLeastQuestion> {}; |
541 | |
542 | /// Compare the type of an attribute or value with a known type. |
543 | struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void, |
544 | Predicates::TypeQuestion> {}; |
545 | |
546 | //===----------------------------------------------------------------------===// |
547 | // PredicateUniquer |
548 | //===----------------------------------------------------------------------===// |
549 | |
550 | /// This class provides a storage uniquer that is used to allocate predicate |
551 | /// instances. |
552 | class PredicateUniquer : public StorageUniquer { |
553 | public: |
554 | PredicateUniquer() { |
555 | // Register the types of Positions with the uniquer. |
556 | registerParametricStorageType<AttributePosition>(); |
557 | registerParametricStorageType<AttributeLiteralPosition>(); |
558 | registerParametricStorageType<ConstraintPosition>(); |
559 | registerParametricStorageType<ForEachPosition>(); |
560 | registerParametricStorageType<OperandPosition>(); |
561 | registerParametricStorageType<OperandGroupPosition>(); |
562 | registerParametricStorageType<OperationPosition>(); |
563 | registerParametricStorageType<ResultPosition>(); |
564 | registerParametricStorageType<ResultGroupPosition>(); |
565 | registerParametricStorageType<TypePosition>(); |
566 | registerParametricStorageType<TypeLiteralPosition>(); |
567 | registerParametricStorageType<UsersPosition>(); |
568 | |
569 | // Register the types of Questions with the uniquer. |
570 | registerParametricStorageType<AttributeAnswer>(); |
571 | registerParametricStorageType<OperationNameAnswer>(); |
572 | registerParametricStorageType<TypeAnswer>(); |
573 | registerParametricStorageType<UnsignedAnswer>(); |
574 | registerSingletonStorageType<FalseAnswer>(); |
575 | registerSingletonStorageType<TrueAnswer>(); |
576 | |
577 | // Register the types of Answers with the uniquer. |
578 | registerParametricStorageType<ConstraintQuestion>(); |
579 | registerParametricStorageType<EqualToQuestion>(); |
580 | registerSingletonStorageType<AttributeQuestion>(); |
581 | registerSingletonStorageType<IsNotNullQuestion>(); |
582 | registerSingletonStorageType<OperandCountQuestion>(); |
583 | registerSingletonStorageType<OperandCountAtLeastQuestion>(); |
584 | registerSingletonStorageType<OperationNameQuestion>(); |
585 | registerSingletonStorageType<ResultCountQuestion>(); |
586 | registerSingletonStorageType<ResultCountAtLeastQuestion>(); |
587 | registerSingletonStorageType<TypeQuestion>(); |
588 | } |
589 | }; |
590 | |
591 | //===----------------------------------------------------------------------===// |
592 | // PredicateBuilder |
593 | //===----------------------------------------------------------------------===// |
594 | |
595 | /// This class provides utilities for constructing predicates. |
596 | class PredicateBuilder { |
597 | public: |
598 | PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) |
599 | : uniquer(uniquer), ctx(ctx) {} |
600 | |
601 | //===--------------------------------------------------------------------===// |
602 | // Positions |
603 | //===--------------------------------------------------------------------===// |
604 | |
605 | /// Returns the root operation position. |
606 | Position *getRoot() { return OperationPosition::getRoot(uniquer); } |
607 | |
608 | /// Returns the parent position defining the value held by the given operand. |
609 | OperationPosition *getOperandDefiningOp(Position *p) { |
610 | assert((isa<OperandPosition, OperandGroupPosition>(p)) && |
611 | "expected operand position" ); |
612 | return OperationPosition::get(uniquer, parent: p); |
613 | } |
614 | |
615 | /// Returns the operation position equivalent to the given position. |
616 | OperationPosition *getPassthroughOp(Position *p) { |
617 | assert((isa<ForEachPosition>(p)) && "expected users position" ); |
618 | return OperationPosition::get(uniquer, parent: p); |
619 | } |
620 | |
621 | // Returns a position for a new value created by a constraint. |
622 | ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, |
623 | unsigned index) { |
624 | return ConstraintPosition::get(uniquer, args: std::make_pair(x&: q, y&: index)); |
625 | } |
626 | |
627 | /// Returns an attribute position for an attribute of the given operation. |
628 | Position *getAttribute(OperationPosition *p, StringRef name) { |
629 | return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); |
630 | } |
631 | |
632 | /// Returns an attribute position for the given attribute. |
633 | Position *getAttributeLiteral(Attribute attr) { |
634 | return AttributeLiteralPosition::get(uniquer, args&: attr); |
635 | } |
636 | |
637 | Position *getForEach(Position *p, unsigned id) { |
638 | return ForEachPosition::get(uniquer, args&: p, args&: id); |
639 | } |
640 | |
641 | /// Returns an operand position for an operand of the given operation. |
642 | Position *getOperand(OperationPosition *p, unsigned operand) { |
643 | return OperandPosition::get(uniquer, args&: p, args&: operand); |
644 | } |
645 | |
646 | /// Returns a position for a group of operands of the given operation. |
647 | Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group, |
648 | bool isVariadic) { |
649 | return OperandGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic); |
650 | } |
651 | Position *getAllOperands(OperationPosition *p) { |
652 | return getOperandGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true); |
653 | } |
654 | |
655 | /// Returns a result position for a result of the given operation. |
656 | Position *getResult(OperationPosition *p, unsigned result) { |
657 | return ResultPosition::get(uniquer, args&: p, args&: result); |
658 | } |
659 | |
660 | /// Returns a position for a group of results of the given operation. |
661 | Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group, |
662 | bool isVariadic) { |
663 | return ResultGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic); |
664 | } |
665 | Position *getAllResults(OperationPosition *p) { |
666 | return getResultGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true); |
667 | } |
668 | |
669 | /// Returns a type position for the given entity. |
670 | Position *getType(Position *p) { return TypePosition::get(uniquer, args&: p); } |
671 | |
672 | /// Returns a type position for the given type value. The value is stored |
673 | /// as either a TypeAttr, or an ArrayAttr of TypeAttr. |
674 | Position *getTypeLiteral(Attribute attr) { |
675 | return TypeLiteralPosition::get(uniquer, args&: attr); |
676 | } |
677 | |
678 | /// Returns the users of a position using the value at the given operand. |
679 | UsersPosition *getUsers(Position *p, bool useRepresentative) { |
680 | assert((isa<OperandPosition, OperandGroupPosition, ResultPosition, |
681 | ResultGroupPosition>(p)) && |
682 | "expected result position" ); |
683 | return UsersPosition::get(uniquer, args&: p, args&: useRepresentative); |
684 | } |
685 | |
686 | //===--------------------------------------------------------------------===// |
687 | // Qualifiers |
688 | //===--------------------------------------------------------------------===// |
689 | |
690 | /// An ordinal predicate consists of a "Question" and a set of acceptable |
691 | /// "Answers" (later converted to ordinal values). A predicate will query some |
692 | /// property of a positional value and decide what to do based on the result. |
693 | using Predicate = std::pair<Qualifier *, Qualifier *>; |
694 | |
695 | /// Create a predicate comparing an attribute to a known value. |
696 | Predicate getAttributeConstraint(Attribute attr) { |
697 | return {AttributeQuestion::get(uniquer), |
698 | AttributeAnswer::get(uniquer, args&: attr)}; |
699 | } |
700 | |
701 | /// Create a predicate checking if two values are equal. |
702 | Predicate getEqualTo(Position *pos) { |
703 | return {EqualToQuestion::get(uniquer, args&: pos), TrueAnswer::get(uniquer)}; |
704 | } |
705 | |
706 | /// Create a predicate checking if two values are not equal. |
707 | Predicate getNotEqualTo(Position *pos) { |
708 | return {EqualToQuestion::get(uniquer, args&: pos), FalseAnswer::get(uniquer)}; |
709 | } |
710 | |
711 | /// Create a predicate that applies a generic constraint. |
712 | Predicate getConstraint(StringRef name, ArrayRef<Position *> args, |
713 | ArrayRef<Type> resultTypes, bool isNegated) { |
714 | return {ConstraintQuestion::get( |
715 | uniquer, args: std::make_tuple(args&: name, args&: args, args&: resultTypes, args&: isNegated)), |
716 | TrueAnswer::get(uniquer)}; |
717 | } |
718 | |
719 | /// Create a predicate comparing a value with null. |
720 | Predicate getIsNotNull() { |
721 | return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; |
722 | } |
723 | |
724 | /// Create a predicate comparing the number of operands of an operation to a |
725 | /// known value. |
726 | Predicate getOperandCount(unsigned count) { |
727 | return {OperandCountQuestion::get(uniquer), |
728 | UnsignedAnswer::get(uniquer, args&: count)}; |
729 | } |
730 | Predicate getOperandCountAtLeast(unsigned count) { |
731 | return {OperandCountAtLeastQuestion::get(uniquer), |
732 | UnsignedAnswer::get(uniquer, args&: count)}; |
733 | } |
734 | |
735 | /// Create a predicate comparing the name of an operation to a known value. |
736 | Predicate getOperationName(StringRef name) { |
737 | return {OperationNameQuestion::get(uniquer), |
738 | OperationNameAnswer::get(uniquer, args: OperationName(name, ctx))}; |
739 | } |
740 | |
741 | /// Create a predicate comparing the number of results of an operation to a |
742 | /// known value. |
743 | Predicate getResultCount(unsigned count) { |
744 | return {ResultCountQuestion::get(uniquer), |
745 | UnsignedAnswer::get(uniquer, args&: count)}; |
746 | } |
747 | Predicate getResultCountAtLeast(unsigned count) { |
748 | return {ResultCountAtLeastQuestion::get(uniquer), |
749 | UnsignedAnswer::get(uniquer, args&: count)}; |
750 | } |
751 | |
752 | /// Create a predicate comparing the type of an attribute or value to a known |
753 | /// type. The value is stored as either a TypeAttr, or an ArrayAttr of |
754 | /// TypeAttr. |
755 | Predicate getTypeConstraint(Attribute type) { |
756 | return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, args&: type)}; |
757 | } |
758 | |
759 | private: |
760 | /// The uniquer used when allocating predicate nodes. |
761 | PredicateUniquer &uniquer; |
762 | |
763 | /// The current MLIR context. |
764 | MLIRContext *ctx; |
765 | }; |
766 | |
767 | } // namespace pdl_to_pdl_interp |
768 | } // namespace mlir |
769 | |
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
771 | |