//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for "predicates" used when converting PDL into
// a matcher tree. Predicates are composed of three different parts:
//
//  * Positions
//    - A position refers to a specific location on the input DAG, i.e. an
//      existing MLIR entity being matched. These can be attributes, operands,
//      operations, results, and types (as well as operand/result groups,
//      users, for-each iterations, literals, and constraint results). Each
//      position also defines a relation to its parent. For example, the operand
//      `[0] -> 1` has a parent operation position `[0]`. The attribute
//      `[0, 1] -> "myAttr"` has parent operation position `[0, 1]`. The
//      operation `[0, 1]` has a parent operand edge `[0] -> 1` (i.e. it is the
//      defining op of operand 1). Apart from literal and constraint-result
//      positions, the only position without a parent is `[0]`, which refers to
//      the root operation.
//  * Questions
//    - A question refers to a query on a specific positional value. For
//      example, an operation name question checks the name of an operation
//      position.
//  * Answers
//    - An answer is the expected result of a question. For example, when
//      matching an operation with the name "foo.op", the question would be an
//      operation name question, with an expected answer of "foo.op".
//
//===----------------------------------------------------------------------===//
31 | |
32 | #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ |
33 | #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ |
34 | |
35 | #include "mlir/IR/MLIRContext.h" |
36 | #include "mlir/IR/OperationSupport.h" |
37 | #include "mlir/IR/PatternMatch.h" |
38 | #include "mlir/IR/Types.h" |
39 | |
40 | namespace mlir { |
41 | namespace pdl_to_pdl_interp { |
namespace Predicates {
/// Discriminator shared by every position, question, and answer class in this
/// file; each enumerator is the `Kind` template argument of the correspondingly
/// named PredicateBase subclass. The enumerators form three contiguous groups,
/// and the order inside each group is meaningful (see the group comments).
enum Kind : unsigned {
  // Positions, ordered by decreasing priority.
  OperationPos,        ///< OperationPosition
  OperandPos,          ///< OperandPosition
  OperandGroupPos,     ///< OperandGroupPosition
  AttributePos,        ///< AttributePosition
  ConstraintResultPos, ///< ConstraintPosition
  ResultPos,           ///< ResultPosition
  ResultGroupPos,      ///< ResultGroupPosition
  TypePos,             ///< TypePosition
  AttributeLiteralPos, ///< AttributeLiteralPosition
  TypeLiteralPos,      ///< TypeLiteralPosition
  UsersPos,            ///< UsersPosition
  ForEachPos,          ///< ForEachPosition

  // Questions, ordered by dependency and decreasing priority.
  IsNotNullQuestion,           ///< IsNotNullQuestion
  OperationNameQuestion,       ///< OperationNameQuestion
  TypeQuestion,                ///< TypeQuestion
  AttributeQuestion,           ///< AttributeQuestion
  OperandCountAtLeastQuestion, ///< OperandCountAtLeastQuestion
  OperandCountQuestion,        ///< OperandCountQuestion
  ResultCountAtLeastQuestion,  ///< ResultCountAtLeastQuestion
  ResultCountQuestion,         ///< ResultCountQuestion
  EqualToQuestion,             ///< EqualToQuestion
  ConstraintQuestion,          ///< ConstraintQuestion

  // Answers.
  AttributeAnswer,     ///< AttributeAnswer
  FalseAnswer,         ///< FalseAnswer
  OperationNameAnswer, ///< OperationNameAnswer
  TrueAnswer,          ///< TrueAnswer
  TypeAnswer,          ///< TypeAnswer
  UnsignedAnswer,      ///< UnsignedAnswer
};
} // namespace Predicates
80 | |
81 | /// Base class for all predicates, used to allow efficient pointer comparison. |
82 | template <typename ConcreteT, typename BaseT, typename Key, |
83 | Predicates::Kind Kind> |
84 | class PredicateBase : public BaseT { |
85 | public: |
86 | using KeyTy = Key; |
87 | using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>; |
88 | |
89 | template <typename KeyT> |
90 | explicit PredicateBase(KeyT &&key) |
91 | : BaseT(Kind), key(std::forward<KeyT>(key)) {} |
92 | |
93 | /// Get an instance of this position. |
94 | template <typename... Args> |
95 | static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { |
96 | return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...); |
97 | } |
98 | |
99 | /// Construct an instance with the given storage allocator. |
100 | template <typename KeyT> |
101 | static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, |
102 | KeyT &&key) { |
103 | return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key)); |
104 | } |
105 | |
106 | /// Utility methods required by the storage allocator. |
107 | bool operator==(const KeyTy &key) const { return this->key == key; } |
108 | static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } |
109 | |
110 | /// Return the key value of this predicate. |
111 | const KeyTy &getValue() const { return key; } |
112 | |
113 | protected: |
114 | KeyTy key; |
115 | }; |
116 | |
117 | /// Base storage for simple predicates that only unique with the kind. |
118 | template <typename ConcreteT, typename BaseT, Predicates::Kind Kind> |
119 | class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT { |
120 | public: |
121 | using Base = PredicateBase<ConcreteT, BaseT, void, Kind>; |
122 | |
123 | explicit PredicateBase() : BaseT(Kind) {} |
124 | |
125 | static ConcreteT *get(StorageUniquer &uniquer) { |
126 | return uniquer.get<ConcreteT>(); |
127 | } |
128 | static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } |
129 | }; |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // Positions |
133 | //===----------------------------------------------------------------------===// |
134 | |
135 | struct OperationPosition; |
136 | |
137 | /// A position describes a value on the input IR on which a predicate may be |
138 | /// applied, such as an operation or attribute. This enables re-use between |
139 | /// predicates, and assists generating bytecode and memory management. |
140 | /// |
141 | /// Operation positions form the base of other positions, which are formed |
142 | /// relative to a parent operation. Operations are anchored at Operand nodes, |
143 | /// except for the root operation which is parentless. |
144 | class Position : public StorageUniquer::BaseStorage { |
145 | public: |
146 | explicit Position(Predicates::Kind kind) : kind(kind) {} |
147 | virtual ~Position(); |
148 | |
149 | /// Returns the depth of the first ancestor operation position. |
150 | unsigned getOperationDepth() const; |
151 | |
152 | /// Returns the parent position. The root operation position has no parent. |
153 | Position *getParent() const { return parent; } |
154 | |
155 | /// Returns the kind of this position. |
156 | Predicates::Kind getKind() const { return kind; } |
157 | |
158 | protected: |
159 | /// Link to the parent position. |
160 | Position *parent = nullptr; |
161 | |
162 | private: |
163 | /// The kind of this position. |
164 | Predicates::Kind kind; |
165 | }; |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // AttributePosition |
169 | //===----------------------------------------------------------------------===// |
170 | |
171 | /// A position describing an attribute of an operation. |
172 | struct AttributePosition |
173 | : public PredicateBase<AttributePosition, Position, |
174 | std::pair<OperationPosition *, StringAttr>, |
175 | Predicates::AttributePos> { |
176 | explicit AttributePosition(const KeyTy &key); |
177 | |
178 | /// Returns the attribute name of this position. |
179 | StringAttr getName() const { return key.second; } |
180 | }; |
181 | |
182 | //===----------------------------------------------------------------------===// |
183 | // AttributeLiteralPosition |
184 | //===----------------------------------------------------------------------===// |
185 | |
186 | /// A position describing a literal attribute. |
187 | struct AttributeLiteralPosition |
188 | : public PredicateBase<AttributeLiteralPosition, Position, Attribute, |
189 | Predicates::AttributeLiteralPos> { |
190 | using PredicateBase::PredicateBase; |
191 | }; |
192 | |
193 | //===----------------------------------------------------------------------===// |
194 | // ForEachPosition |
195 | //===----------------------------------------------------------------------===// |
196 | |
197 | /// A position describing an iterative choice of an operation. |
198 | struct ForEachPosition : public PredicateBase<ForEachPosition, Position, |
199 | std::pair<Position *, unsigned>, |
200 | Predicates::ForEachPos> { |
201 | explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
202 | |
203 | /// Returns the ID, for differentiating various loops. |
204 | /// For upward traversals, this is the index of the root. |
205 | unsigned getID() const { return key.second; } |
206 | }; |
207 | |
208 | //===----------------------------------------------------------------------===// |
209 | // OperandPosition |
210 | //===----------------------------------------------------------------------===// |
211 | |
212 | /// A position describing an operand of an operation. |
213 | struct OperandPosition |
214 | : public PredicateBase<OperandPosition, Position, |
215 | std::pair<OperationPosition *, unsigned>, |
216 | Predicates::OperandPos> { |
217 | explicit OperandPosition(const KeyTy &key); |
218 | |
219 | /// Returns the operand number of this position. |
220 | unsigned getOperandNumber() const { return key.second; } |
221 | }; |
222 | |
223 | //===----------------------------------------------------------------------===// |
224 | // OperandGroupPosition |
225 | //===----------------------------------------------------------------------===// |
226 | |
227 | /// A position describing an operand group of an operation. |
228 | struct OperandGroupPosition |
229 | : public PredicateBase< |
230 | OperandGroupPosition, Position, |
231 | std::tuple<OperationPosition *, std::optional<unsigned>, bool>, |
232 | Predicates::OperandGroupPos> { |
233 | explicit OperandGroupPosition(const KeyTy &key); |
234 | |
235 | /// Returns a hash suitable for the given keytype. |
236 | static llvm::hash_code hashKey(const KeyTy &key) { |
237 | return llvm::hash_value(arg: key); |
238 | } |
239 | |
240 | /// Returns the group number of this position. If std::nullopt, this group |
241 | /// refers to all operands. |
242 | std::optional<unsigned> getOperandGroupNumber() const { |
243 | return std::get<1>(t: key); |
244 | } |
245 | |
246 | /// Returns if the operand group has unknown size. If false, the operand group |
247 | /// has at max one element. |
248 | bool isVariadic() const { return std::get<2>(t: key); } |
249 | }; |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // OperationPosition |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | /// An operation position describes an operation node in the IR. Other position |
256 | /// kinds are formed with respect to an operation position. |
257 | struct OperationPosition : public PredicateBase<OperationPosition, Position, |
258 | std::pair<Position *, unsigned>, |
259 | Predicates::OperationPos> { |
260 | explicit OperationPosition(const KeyTy &key) : Base(key) { |
261 | parent = key.first; |
262 | } |
263 | |
264 | /// Returns a hash suitable for the given keytype. |
265 | static llvm::hash_code hashKey(const KeyTy &key) { |
266 | return llvm::hash_value(arg: key); |
267 | } |
268 | |
269 | /// Gets the root position. |
270 | static OperationPosition *getRoot(StorageUniquer &uniquer) { |
271 | return Base::get(uniquer, args: nullptr, args: 0); |
272 | } |
273 | |
274 | /// Gets an operation position with the given parent. |
275 | static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { |
276 | return Base::get(uniquer, args&: parent, args: parent->getOperationDepth() + 1); |
277 | } |
278 | |
279 | /// Returns the depth of this position. |
280 | unsigned getDepth() const { return key.second; } |
281 | |
282 | /// Returns if this operation position corresponds to the root. |
283 | bool isRoot() const { return getDepth() == 0; } |
284 | |
285 | /// Returns if this operation represents an operand defining op. |
286 | bool isOperandDefiningOp() const; |
287 | }; |
288 | |
289 | //===----------------------------------------------------------------------===// |
290 | // ConstraintPosition |
291 | //===----------------------------------------------------------------------===// |
292 | |
293 | struct ConstraintQuestion; |
294 | |
295 | /// A position describing the result of a native constraint. It saves the |
296 | /// corresponding ConstraintQuestion and result index to enable referring |
297 | /// back to them |
298 | struct ConstraintPosition |
299 | : public PredicateBase<ConstraintPosition, Position, |
300 | std::pair<ConstraintQuestion *, unsigned>, |
301 | Predicates::ConstraintResultPos> { |
302 | using PredicateBase::PredicateBase; |
303 | |
304 | /// Returns the ConstraintQuestion to enable keeping track of the native |
305 | /// constraint this position stems from. |
306 | ConstraintQuestion *getQuestion() const { return key.first; } |
307 | |
308 | // Returns the result index of this position |
309 | unsigned getIndex() const { return key.second; } |
310 | }; |
311 | |
312 | //===----------------------------------------------------------------------===// |
313 | // ResultPosition |
314 | //===----------------------------------------------------------------------===// |
315 | |
316 | /// A position describing a result of an operation. |
317 | struct ResultPosition |
318 | : public PredicateBase<ResultPosition, Position, |
319 | std::pair<OperationPosition *, unsigned>, |
320 | Predicates::ResultPos> { |
321 | explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
322 | |
323 | /// Returns the result number of this position. |
324 | unsigned getResultNumber() const { return key.second; } |
325 | }; |
326 | |
327 | //===----------------------------------------------------------------------===// |
328 | // ResultGroupPosition |
329 | //===----------------------------------------------------------------------===// |
330 | |
331 | /// A position describing a result group of an operation. |
332 | struct ResultGroupPosition |
333 | : public PredicateBase< |
334 | ResultGroupPosition, Position, |
335 | std::tuple<OperationPosition *, std::optional<unsigned>, bool>, |
336 | Predicates::ResultGroupPos> { |
337 | explicit ResultGroupPosition(const KeyTy &key) : Base(key) { |
338 | parent = std::get<0>(t: key); |
339 | } |
340 | |
341 | /// Returns a hash suitable for the given keytype. |
342 | static llvm::hash_code hashKey(const KeyTy &key) { |
343 | return llvm::hash_value(arg: key); |
344 | } |
345 | |
346 | /// Returns the group number of this position. If std::nullopt, this group |
347 | /// refers to all results. |
348 | std::optional<unsigned> getResultGroupNumber() const { |
349 | return std::get<1>(t: key); |
350 | } |
351 | |
352 | /// Returns if the result group has unknown size. If false, the result group |
353 | /// has at max one element. |
354 | bool isVariadic() const { return std::get<2>(t: key); } |
355 | }; |
356 | |
357 | //===----------------------------------------------------------------------===// |
358 | // TypePosition |
359 | //===----------------------------------------------------------------------===// |
360 | |
361 | /// A position describing the result type of an entity, i.e. an Attribute, |
362 | /// Operand, Result, etc. |
363 | struct TypePosition : public PredicateBase<TypePosition, Position, Position *, |
364 | Predicates::TypePos> { |
365 | explicit TypePosition(const KeyTy &key) : Base(key) { |
366 | assert((isa<AttributePosition, OperandPosition, OperandGroupPosition, |
367 | ResultPosition, ResultGroupPosition>(key)) && |
368 | "expected parent to be an attribute, operand, or result"); |
369 | parent = key; |
370 | } |
371 | }; |
372 | |
373 | //===----------------------------------------------------------------------===// |
374 | // TypeLiteralPosition |
375 | //===----------------------------------------------------------------------===// |
376 | |
377 | /// A position describing a literal type or type range. The value is stored as |
378 | /// either a TypeAttr, or an ArrayAttr of TypeAttr. |
379 | struct TypeLiteralPosition |
380 | : public PredicateBase<TypeLiteralPosition, Position, Attribute, |
381 | Predicates::TypeLiteralPos> { |
382 | using PredicateBase::PredicateBase; |
383 | }; |
384 | |
385 | //===----------------------------------------------------------------------===// |
386 | // UsersPosition |
387 | //===----------------------------------------------------------------------===// |
388 | |
389 | /// A position describing the users of a value or a range of values. The second |
390 | /// value in the key indicates whether we choose users of a representative for |
391 | /// a range (this is true, e.g., in the upward traversals). |
392 | struct UsersPosition |
393 | : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>, |
394 | Predicates::UsersPos> { |
395 | explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } |
396 | |
397 | /// Returns a hash suitable for the given keytype. |
398 | static llvm::hash_code hashKey(const KeyTy &key) { |
399 | return llvm::hash_value(arg: key); |
400 | } |
401 | |
402 | /// Indicates whether to compute a range of a representative. |
403 | bool useRepresentative() const { return key.second; } |
404 | }; |
405 | |
406 | //===----------------------------------------------------------------------===// |
407 | // Qualifiers |
408 | //===----------------------------------------------------------------------===// |
409 | |
410 | /// An ordinal predicate consists of a "Question" and a set of acceptable |
411 | /// "Answers" (later converted to ordinal values). A predicate will query some |
412 | /// property of a positional value and decide what to do based on the result. |
413 | /// |
414 | /// This makes top-level predicate representations ordinal (SwitchOp). Later, |
415 | /// predicates that end up with only one acceptable answer (including all |
416 | /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the |
417 | /// matcher. |
418 | /// |
419 | /// For simplicity, both are represented as "qualifiers", with a base kind and |
420 | /// perhaps additional properties. For example, all OperationName predicates ask |
421 | /// the same question, but GenericConstraint predicates may ask different ones. |
422 | class Qualifier : public StorageUniquer::BaseStorage { |
423 | public: |
424 | explicit Qualifier(Predicates::Kind kind) : kind(kind) {} |
425 | |
426 | /// Returns the kind of this qualifier. |
427 | Predicates::Kind getKind() const { return kind; } |
428 | |
429 | private: |
430 | /// The kind of this position. |
431 | Predicates::Kind kind; |
432 | }; |
433 | |
434 | //===----------------------------------------------------------------------===// |
435 | // Answers |
436 | //===----------------------------------------------------------------------===// |
437 | |
438 | /// An Answer representing an `Attribute` value. |
439 | struct AttributeAnswer |
440 | : public PredicateBase<AttributeAnswer, Qualifier, Attribute, |
441 | Predicates::AttributeAnswer> { |
442 | using Base::Base; |
443 | }; |
444 | |
445 | /// An Answer representing an `OperationName` value. |
446 | struct OperationNameAnswer |
447 | : public PredicateBase<OperationNameAnswer, Qualifier, OperationName, |
448 | Predicates::OperationNameAnswer> { |
449 | using Base::Base; |
450 | }; |
451 | |
452 | /// An Answer representing a boolean `true` value. |
453 | struct TrueAnswer |
454 | : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> { |
455 | using Base::Base; |
456 | }; |
457 | |
458 | /// An Answer representing a boolean 'false' value. |
459 | struct FalseAnswer |
460 | : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> { |
461 | using Base::Base; |
462 | }; |
463 | |
464 | /// An Answer representing a `Type` value. The value is stored as either a |
465 | /// TypeAttr, or an ArrayAttr of TypeAttr. |
466 | struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute, |
467 | Predicates::TypeAnswer> { |
468 | using Base::Base; |
469 | }; |
470 | |
471 | /// An Answer representing an unsigned value. |
472 | struct UnsignedAnswer |
473 | : public PredicateBase<UnsignedAnswer, Qualifier, unsigned, |
474 | Predicates::UnsignedAnswer> { |
475 | using Base::Base; |
476 | }; |
477 | |
478 | //===----------------------------------------------------------------------===// |
479 | // Questions |
480 | //===----------------------------------------------------------------------===// |
481 | |
482 | /// Compare an `Attribute` to a constant value. |
483 | struct AttributeQuestion |
484 | : public PredicateBase<AttributeQuestion, Qualifier, void, |
485 | Predicates::AttributeQuestion> {}; |
486 | |
487 | /// Apply a parameterized constraint to multiple position values and possibly |
488 | /// produce results. |
489 | struct ConstraintQuestion |
490 | : public PredicateBase< |
491 | ConstraintQuestion, Qualifier, |
492 | std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>, |
493 | Predicates::ConstraintQuestion> { |
494 | using Base::Base; |
495 | |
496 | /// Return the name of the constraint. |
497 | StringRef getName() const { return std::get<0>(t: key); } |
498 | |
499 | /// Return the arguments of the constraint. |
500 | ArrayRef<Position *> getArgs() const { return std::get<1>(t: key); } |
501 | |
502 | /// Return the result types of the constraint. |
503 | ArrayRef<Type> getResultTypes() const { return std::get<2>(t: key); } |
504 | |
505 | /// Return the negation status of the constraint. |
506 | bool getIsNegated() const { return std::get<3>(t: key); } |
507 | |
508 | /// Construct an instance with the given storage allocator. |
509 | static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, |
510 | KeyTy key) { |
511 | return Base::construct(alloc, key: KeyTy{alloc.copyInto(str: std::get<0>(t&: key)), |
512 | alloc.copyInto(elements: std::get<1>(t&: key)), |
513 | alloc.copyInto(elements: std::get<2>(t&: key)), |
514 | std::get<3>(t&: key)}); |
515 | } |
516 | |
517 | /// Returns a hash suitable for the given keytype. |
518 | static llvm::hash_code hashKey(const KeyTy &key) { |
519 | return llvm::hash_value(arg: key); |
520 | } |
521 | }; |
522 | |
523 | /// Compare the equality of two values. |
524 | struct EqualToQuestion |
525 | : public PredicateBase<EqualToQuestion, Qualifier, Position *, |
526 | Predicates::EqualToQuestion> { |
527 | using Base::Base; |
528 | }; |
529 | |
530 | /// Compare a positional value with null, i.e. check if it exists. |
531 | struct IsNotNullQuestion |
532 | : public PredicateBase<IsNotNullQuestion, Qualifier, void, |
533 | Predicates::IsNotNullQuestion> {}; |
534 | |
535 | /// Compare the number of operands of an operation with a known value. |
536 | struct OperandCountQuestion |
537 | : public PredicateBase<OperandCountQuestion, Qualifier, void, |
538 | Predicates::OperandCountQuestion> {}; |
539 | struct OperandCountAtLeastQuestion |
540 | : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void, |
541 | Predicates::OperandCountAtLeastQuestion> {}; |
542 | |
543 | /// Compare the name of an operation with a known value. |
544 | struct OperationNameQuestion |
545 | : public PredicateBase<OperationNameQuestion, Qualifier, void, |
546 | Predicates::OperationNameQuestion> {}; |
547 | |
548 | /// Compare the number of results of an operation with a known value. |
549 | struct ResultCountQuestion |
550 | : public PredicateBase<ResultCountQuestion, Qualifier, void, |
551 | Predicates::ResultCountQuestion> {}; |
552 | struct ResultCountAtLeastQuestion |
553 | : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void, |
554 | Predicates::ResultCountAtLeastQuestion> {}; |
555 | |
556 | /// Compare the type of an attribute or value with a known type. |
557 | struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void, |
558 | Predicates::TypeQuestion> {}; |
559 | |
560 | //===----------------------------------------------------------------------===// |
561 | // PredicateUniquer |
562 | //===----------------------------------------------------------------------===// |
563 | |
564 | /// This class provides a storage uniquer that is used to allocate predicate |
565 | /// instances. |
566 | class PredicateUniquer : public StorageUniquer { |
567 | public: |
568 | PredicateUniquer() { |
569 | // Register the types of Positions with the uniquer. |
570 | registerParametricStorageType<AttributePosition>(); |
571 | registerParametricStorageType<AttributeLiteralPosition>(); |
572 | registerParametricStorageType<ConstraintPosition>(); |
573 | registerParametricStorageType<ForEachPosition>(); |
574 | registerParametricStorageType<OperandPosition>(); |
575 | registerParametricStorageType<OperandGroupPosition>(); |
576 | registerParametricStorageType<OperationPosition>(); |
577 | registerParametricStorageType<ResultPosition>(); |
578 | registerParametricStorageType<ResultGroupPosition>(); |
579 | registerParametricStorageType<TypePosition>(); |
580 | registerParametricStorageType<TypeLiteralPosition>(); |
581 | registerParametricStorageType<UsersPosition>(); |
582 | |
583 | // Register the types of Questions with the uniquer. |
584 | registerParametricStorageType<AttributeAnswer>(); |
585 | registerParametricStorageType<OperationNameAnswer>(); |
586 | registerParametricStorageType<TypeAnswer>(); |
587 | registerParametricStorageType<UnsignedAnswer>(); |
588 | registerSingletonStorageType<FalseAnswer>(); |
589 | registerSingletonStorageType<TrueAnswer>(); |
590 | |
591 | // Register the types of Answers with the uniquer. |
592 | registerParametricStorageType<ConstraintQuestion>(); |
593 | registerParametricStorageType<EqualToQuestion>(); |
594 | registerSingletonStorageType<AttributeQuestion>(); |
595 | registerSingletonStorageType<IsNotNullQuestion>(); |
596 | registerSingletonStorageType<OperandCountQuestion>(); |
597 | registerSingletonStorageType<OperandCountAtLeastQuestion>(); |
598 | registerSingletonStorageType<OperationNameQuestion>(); |
599 | registerSingletonStorageType<ResultCountQuestion>(); |
600 | registerSingletonStorageType<ResultCountAtLeastQuestion>(); |
601 | registerSingletonStorageType<TypeQuestion>(); |
602 | } |
603 | }; |
604 | |
605 | //===----------------------------------------------------------------------===// |
606 | // PredicateBuilder |
607 | //===----------------------------------------------------------------------===// |
608 | |
609 | /// This class provides utilities for constructing predicates. |
610 | class PredicateBuilder { |
611 | public: |
612 | PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) |
613 | : uniquer(uniquer), ctx(ctx) {} |
614 | |
615 | //===--------------------------------------------------------------------===// |
616 | // Positions |
617 | //===--------------------------------------------------------------------===// |
618 | |
619 | /// Returns the root operation position. |
620 | Position *getRoot() { return OperationPosition::getRoot(uniquer); } |
621 | |
622 | /// Returns the parent position defining the value held by the given operand. |
623 | OperationPosition *getOperandDefiningOp(Position *p) { |
624 | assert((isa<OperandPosition, OperandGroupPosition>(p)) && |
625 | "expected operand position"); |
626 | return OperationPosition::get(uniquer, parent: p); |
627 | } |
628 | |
629 | /// Returns the operation position equivalent to the given position. |
630 | OperationPosition *getPassthroughOp(Position *p) { |
631 | assert((isa<ForEachPosition>(p)) && "expected users position"); |
632 | return OperationPosition::get(uniquer, parent: p); |
633 | } |
634 | |
635 | // Returns a position for a new value created by a constraint. |
636 | ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, |
637 | unsigned index) { |
638 | return ConstraintPosition::get(uniquer, args: std::make_pair(x&: q, y&: index)); |
639 | } |
640 | |
641 | /// Returns an attribute position for an attribute of the given operation. |
642 | Position *getAttribute(OperationPosition *p, StringRef name) { |
643 | return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); |
644 | } |
645 | |
646 | /// Returns an attribute position for the given attribute. |
647 | Position *getAttributeLiteral(Attribute attr) { |
648 | return AttributeLiteralPosition::get(uniquer, args&: attr); |
649 | } |
650 | |
651 | Position *getForEach(Position *p, unsigned id) { |
652 | return ForEachPosition::get(uniquer, args&: p, args&: id); |
653 | } |
654 | |
655 | /// Returns an operand position for an operand of the given operation. |
656 | Position *getOperand(OperationPosition *p, unsigned operand) { |
657 | return OperandPosition::get(uniquer, args&: p, args&: operand); |
658 | } |
659 | |
660 | /// Returns a position for a group of operands of the given operation. |
661 | Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group, |
662 | bool isVariadic) { |
663 | return OperandGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic); |
664 | } |
665 | Position *getAllOperands(OperationPosition *p) { |
666 | return getOperandGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true); |
667 | } |
668 | |
669 | /// Returns a result position for a result of the given operation. |
670 | Position *getResult(OperationPosition *p, unsigned result) { |
671 | return ResultPosition::get(uniquer, args&: p, args&: result); |
672 | } |
673 | |
674 | /// Returns a position for a group of results of the given operation. |
675 | Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group, |
676 | bool isVariadic) { |
677 | return ResultGroupPosition::get(uniquer, args&: p, args&: group, args&: isVariadic); |
678 | } |
679 | Position *getAllResults(OperationPosition *p) { |
680 | return getResultGroup(p, /*group=*/group: std::nullopt, /*isVariadic=*/isVariadic: true); |
681 | } |
682 | |
683 | /// Returns a type position for the given entity. |
684 | Position *getType(Position *p) { return TypePosition::get(uniquer, args&: p); } |
685 | |
686 | /// Returns a type position for the given type value. The value is stored |
687 | /// as either a TypeAttr, or an ArrayAttr of TypeAttr. |
688 | Position *getTypeLiteral(Attribute attr) { |
689 | return TypeLiteralPosition::get(uniquer, args&: attr); |
690 | } |
691 | |
692 | /// Returns the users of a position using the value at the given operand. |
693 | UsersPosition *getUsers(Position *p, bool useRepresentative) { |
694 | assert((isa<OperandPosition, OperandGroupPosition, ResultPosition, |
695 | ResultGroupPosition>(p)) && |
696 | "expected result position"); |
697 | return UsersPosition::get(uniquer, args&: p, args&: useRepresentative); |
698 | } |
699 | |
700 | //===--------------------------------------------------------------------===// |
701 | // Qualifiers |
702 | //===--------------------------------------------------------------------===// |
703 | |
704 | /// An ordinal predicate consists of a "Question" and a set of acceptable |
705 | /// "Answers" (later converted to ordinal values). A predicate will query some |
706 | /// property of a positional value and decide what to do based on the result. |
707 | using Predicate = std::pair<Qualifier *, Qualifier *>; |
708 | |
709 | /// Create a predicate comparing an attribute to a known value. |
710 | Predicate getAttributeConstraint(Attribute attr) { |
711 | return {AttributeQuestion::get(uniquer), |
712 | AttributeAnswer::get(uniquer, args&: attr)}; |
713 | } |
714 | |
715 | /// Create a predicate checking if two values are equal. |
716 | Predicate getEqualTo(Position *pos) { |
717 | return {EqualToQuestion::get(uniquer, args&: pos), TrueAnswer::get(uniquer)}; |
718 | } |
719 | |
720 | /// Create a predicate checking if two values are not equal. |
721 | Predicate getNotEqualTo(Position *pos) { |
722 | return {EqualToQuestion::get(uniquer, args&: pos), FalseAnswer::get(uniquer)}; |
723 | } |
724 | |
725 | /// Create a predicate that applies a generic constraint. |
726 | Predicate getConstraint(StringRef name, ArrayRef<Position *> args, |
727 | ArrayRef<Type> resultTypes, bool isNegated) { |
728 | return {ConstraintQuestion::get( |
729 | uniquer, args: std::make_tuple(args&: name, args&: args, args&: resultTypes, args&: isNegated)), |
730 | TrueAnswer::get(uniquer)}; |
731 | } |
732 | |
733 | /// Create a predicate comparing a value with null. |
734 | Predicate getIsNotNull() { |
735 | return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; |
736 | } |
737 | |
738 | /// Create a predicate comparing the number of operands of an operation to a |
739 | /// known value. |
740 | Predicate getOperandCount(unsigned count) { |
741 | return {OperandCountQuestion::get(uniquer), |
742 | UnsignedAnswer::get(uniquer, args&: count)}; |
743 | } |
744 | Predicate getOperandCountAtLeast(unsigned count) { |
745 | return {OperandCountAtLeastQuestion::get(uniquer), |
746 | UnsignedAnswer::get(uniquer, args&: count)}; |
747 | } |
748 | |
749 | /// Create a predicate comparing the name of an operation to a known value. |
750 | Predicate getOperationName(StringRef name) { |
751 | return {OperationNameQuestion::get(uniquer), |
752 | OperationNameAnswer::get(uniquer, args: OperationName(name, ctx))}; |
753 | } |
754 | |
755 | /// Create a predicate comparing the number of results of an operation to a |
756 | /// known value. |
757 | Predicate getResultCount(unsigned count) { |
758 | return {ResultCountQuestion::get(uniquer), |
759 | UnsignedAnswer::get(uniquer, args&: count)}; |
760 | } |
761 | Predicate getResultCountAtLeast(unsigned count) { |
762 | return {ResultCountAtLeastQuestion::get(uniquer), |
763 | UnsignedAnswer::get(uniquer, args&: count)}; |
764 | } |
765 | |
766 | /// Create a predicate comparing the type of an attribute or value to a known |
767 | /// type. The value is stored as either a TypeAttr, or an ArrayAttr of |
768 | /// TypeAttr. |
769 | Predicate getTypeConstraint(Attribute type) { |
770 | return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, args&: type)}; |
771 | } |
772 | |
773 | private: |
774 | /// The uniquer used when allocating predicate nodes. |
775 | PredicateUniquer &uniquer; |
776 | |
777 | /// The current MLIR context. |
778 | MLIRContext *ctx; |
779 | }; |
780 | |
781 | } // namespace pdl_to_pdl_interp |
782 | } // namespace mlir |
783 | |
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
785 |
Definitions
- Kind
- PredicateBase
- PredicateBase
- get
- construct
- operator==
- classof
- getValue
- PredicateBase
- PredicateBase
- get
- classof
- Position
- Position
- getParent
- getKind
- AttributePosition
- getName
- AttributeLiteralPosition
- ForEachPosition
- ForEachPosition
- getID
- OperandPosition
- getOperandNumber
- OperandGroupPosition
- hashKey
- getOperandGroupNumber
- isVariadic
- OperationPosition
- OperationPosition
- hashKey
- getRoot
- get
- getDepth
- isRoot
- ConstraintPosition
- getQuestion
- getIndex
- ResultPosition
- ResultPosition
- getResultNumber
- ResultGroupPosition
- ResultGroupPosition
- hashKey
- getResultGroupNumber
- isVariadic
- TypePosition
- TypePosition
- TypeLiteralPosition
- UsersPosition
- UsersPosition
- hashKey
- useRepresentative
- Qualifier
- Qualifier
- getKind
- AttributeAnswer
- OperationNameAnswer
- TrueAnswer
- FalseAnswer
- TypeAnswer
- UnsignedAnswer
- AttributeQuestion
- ConstraintQuestion
- getName
- getArgs
- getResultTypes
- getIsNegated
- construct
- hashKey
- EqualToQuestion
- IsNotNullQuestion
- OperandCountQuestion
- OperandCountAtLeastQuestion
- OperationNameQuestion
- ResultCountQuestion
- ResultCountAtLeastQuestion
- TypeQuestion
- PredicateUniquer
- PredicateUniquer
- PredicateBuilder
- PredicateBuilder
- getRoot
- getOperandDefiningOp
- getPassthroughOp
- getConstraintPosition
- getAttribute
- getAttributeLiteral
- getForEach
- getOperand
- getOperandGroup
- getAllOperands
- getResult
- getResultGroup
- getAllResults
- getType
- getTypeLiteral
- getUsers
- getAttributeConstraint
- getEqualTo
- getNotEqualTo
- getConstraint
- getIsNotNull
- getOperandCount
- getOperandCountAtLeast
- getOperationName
- getResultCount
- getResultCountAtLeast
Improve your Profiling and Debugging skills
Find out more