1 | //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types is to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
16 | // |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | #ifndef MLIR_IR_OPDEFINITION_H |
20 | #define MLIR_IR_OPDEFINITION_H |
21 | |
22 | #include "mlir/IR/Dialect.h" |
23 | #include "mlir/IR/ODSSupport.h" |
24 | #include "mlir/IR/Operation.h" |
25 | #include "llvm/Support/PointerLikeTypeTraits.h" |
26 | |
27 | #include <optional> |
28 | #include <type_traits> |
29 | |
30 | namespace mlir { |
31 | class Builder; |
32 | class OpBuilder; |
33 | |
34 | /// This class implements `Optional` functionality for ParseResult. We don't |
35 | /// directly use Optional here, because it provides an implicit conversion |
36 | /// to 'bool' which we want to avoid. This class is used to implement tri-state |
37 | /// 'parseOptional' functions that may have a failure mode when parsing that |
38 | /// shouldn't be attributed to "not present". |
39 | class OptionalParseResult { |
40 | public: |
41 | OptionalParseResult() = default; |
42 | OptionalParseResult(LogicalResult result) : impl(result) {} |
43 | OptionalParseResult(ParseResult result) : impl(result) {} |
44 | OptionalParseResult(const InFlightDiagnostic &) |
45 | : OptionalParseResult(failure()) {} |
46 | OptionalParseResult(std::nullopt_t) : impl(std::nullopt) {} |
47 | |
48 | /// Returns true if we contain a valid ParseResult value. |
49 | bool has_value() const { return impl.has_value(); } |
50 | |
51 | /// Access the internal ParseResult value. |
52 | ParseResult value() const { return *impl; } |
53 | ParseResult operator*() const { return value(); } |
54 | |
55 | private: |
56 | std::optional<ParseResult> impl; |
57 | }; |
58 | |
59 | // These functions are out-of-line utilities, which avoids them being template |
60 | // instantiated/duplicated. |
61 | namespace impl { |
62 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the |
63 | /// region's only block if it does not have a terminator already. If the region |
64 | /// is empty, insert a new block first. `buildTerminatorOp` should return the |
65 | /// terminator operation to insert. |
66 | void ensureRegionTerminator( |
67 | Region ®ion, OpBuilder &builder, Location loc, |
68 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); |
69 | void ensureRegionTerminator( |
70 | Region ®ion, Builder &builder, Location loc, |
71 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); |
72 | |
73 | } // namespace impl |
74 | |
/// Marker type used as an operation's properties when it does not declare a
/// `Properties` type of its own.
struct EmptyProperties {};

/// Type trait that picks an operation's properties type: `OpT::Properties`
/// when the op declares such a nested type, `EmptyProperties` otherwise.
template <class OpT, class = void>
struct PropertiesSelector {
  using type = EmptyProperties;
};
/// Selected through std::void_t SFINAE whenever `OpT::Properties` names a type.
template <class OpT>
struct PropertiesSelector<OpT, std::void_t<typename OpT::Properties>> {
  using type = typename OpT::Properties;
};
89 | |
90 | /// This is the concrete base class that holds the operation pointer and has |
91 | /// non-generic methods that only depend on State (to avoid having them |
92 | /// instantiated on template types that don't affect them. |
93 | /// |
94 | /// This also has the fallback implementations of customization hooks for when |
95 | /// they aren't customized. |
96 | class OpState { |
97 | public: |
98 | /// Ops are pointer-like, so we allow conversion to bool. |
99 | explicit operator bool() { return getOperation() != nullptr; } |
100 | |
101 | /// This implicitly converts to Operation*. |
102 | operator Operation *() const { return state; } |
103 | |
104 | /// Shortcut of `->` to access a member of Operation. |
105 | Operation *operator->() const { return state; } |
106 | |
107 | /// Return the operation that this refers to. |
108 | Operation *getOperation() { return state; } |
109 | |
110 | /// Return the context this operation belongs to. |
111 | MLIRContext *getContext() { return getOperation()->getContext(); } |
112 | |
113 | /// Print the operation to the given stream. |
114 | void print(raw_ostream &os, OpPrintingFlags flags = std::nullopt) { |
115 | state->print(os, flags); |
116 | } |
117 | void print(raw_ostream &os, AsmState &asmState) { |
118 | state->print(os, state&: asmState); |
119 | } |
120 | |
121 | /// Dump this operation. |
122 | void dump() { state->dump(); } |
123 | |
124 | /// The source location the operation was defined or derived from. |
125 | Location getLoc() { return state->getLoc(); } |
126 | |
127 | /// Return true if there are no users of any results of this operation. |
128 | bool use_empty() { return state->use_empty(); } |
129 | |
130 | /// Remove this operation from its parent block and delete it. |
131 | void erase() { state->erase(); } |
132 | |
133 | /// Emit an error with the op name prefixed, like "'dim' op " which is |
134 | /// convenient for verifiers. |
135 | InFlightDiagnostic emitOpError(const Twine &message = {}); |
136 | |
137 | /// Emit an error about fatal conditions with this operation, reporting up to |
138 | /// any diagnostic handlers that may be listening. |
139 | InFlightDiagnostic emitError(const Twine &message = {}); |
140 | |
141 | /// Emit a warning about this operation, reporting up to any diagnostic |
142 | /// handlers that may be listening. |
143 | InFlightDiagnostic emitWarning(const Twine &message = {}); |
144 | |
145 | /// Emit a remark about this operation, reporting up to any diagnostic |
146 | /// handlers that may be listening. |
147 | InFlightDiagnostic (const Twine &message = {}); |
148 | |
149 | /// Walk the operation by calling the callback for each nested operation |
150 | /// (including this one), block or region, depending on the callback provided. |
151 | /// The order in which regions, blocks and operations the same nesting level |
152 | /// are visited (e.g., lexicographical or reverse lexicographical order) is |
153 | /// determined by 'Iterator'. The walk order for enclosing regions, blocks |
154 | /// and operations with respect to their nested ones is specified by 'Order' |
155 | /// (post-order by default). A callback on a block or operation is allowed to |
156 | /// erase that block or operation if either: |
157 | /// * the walk is in post-order, or |
158 | /// * the walk is in pre-order and the walk is skipped after the erasure. |
159 | /// See Operation::walk for more details. |
160 | template <WalkOrder Order = WalkOrder::PostOrder, |
161 | typename Iterator = ForwardIterator, typename FnT, |
162 | typename RetT = detail::walkResultType<FnT>> |
163 | std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1, |
164 | RetT> |
165 | walk(FnT &&callback) { |
166 | return state->walk<Order, Iterator>(std::forward<FnT>(callback)); |
167 | } |
168 | |
169 | /// Generic walker with a stage aware callback. Walk the operation by calling |
170 | /// the callback for each nested operation (including this one) N+1 times, |
171 | /// where N is the number of regions attached to that operation. |
172 | /// |
173 | /// The callback method can take any of the following forms: |
174 | /// void(Operation *, const WalkStage &) : Walk all operation opaquely |
175 | /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...}); |
176 | /// void(OpT, const WalkStage &) : Walk all operations of the given derived |
177 | /// type. |
178 | /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...}); |
179 | /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations, |
180 | /// but allow for interruption/skipping. |
181 | /// * op.walk([](... op, const WalkStage &stage) { |
182 | /// // Skip the walk of this op based on some invariant. |
183 | /// if (some_invariant) |
184 | /// return WalkResult::skip(); |
185 | /// // Interrupt, i.e cancel, the walk based on some invariant. |
186 | /// if (another_invariant) |
187 | /// return WalkResult::interrupt(); |
188 | /// return WalkResult::advance(); |
189 | /// }); |
190 | template <typename FnT, typename RetT = detail::walkResultType<FnT>> |
191 | std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 2, |
192 | RetT> |
193 | walk(FnT &&callback) { |
194 | return state->walk(std::forward<FnT>(callback)); |
195 | } |
196 | |
197 | // These are default implementations of customization hooks. |
198 | public: |
199 | /// This hook returns any canonicalization pattern rewrites that the operation |
200 | /// supports, for use by the canonicalization pass. |
201 | static void getCanonicalizationPatterns(RewritePatternSet &results, |
202 | MLIRContext *context) {} |
203 | |
204 | /// This hook populates any unset default attrs. |
205 | static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {} |
206 | |
207 | protected: |
208 | /// If the concrete type didn't implement a custom verifier hook, just fall |
209 | /// back to this one which accepts everything. |
210 | LogicalResult verify() { return success(); } |
211 | LogicalResult verifyRegions() { return success(); } |
212 | |
213 | /// Parse the custom form of an operation. Unless overridden, this method will |
214 | /// first try to get an operation parser from the op's dialect. Otherwise the |
215 | /// custom assembly form of an op is always rejected. Op implementations |
216 | /// should implement this to return failure. On success, they should fill in |
217 | /// result with the fields to use. |
218 | static ParseResult parse(OpAsmParser &parser, OperationState &result); |
219 | |
220 | /// Print the operation. Unless overridden, this method will first try to get |
221 | /// an operation printer from the dialect. Otherwise, it prints the operation |
222 | /// in generic form. |
223 | static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect); |
224 | |
225 | /// Parse properties as a Attribute. |
226 | static ParseResult genericParseProperties(OpAsmParser &parser, |
227 | Attribute &result); |
228 | |
229 | /// Print the properties as a Attribute with names not included within |
230 | /// 'elidedProps' |
231 | static void genericPrintProperties(OpAsmPrinter &p, Attribute properties, |
232 | ArrayRef<StringRef> elidedProps = {}); |
233 | |
234 | /// Print an operation name, eliding the dialect prefix if necessary. |
235 | static void printOpName(Operation *op, OpAsmPrinter &p, |
236 | StringRef defaultDialect); |
237 | |
238 | /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, |
239 | /// so we can cast it away here. |
240 | explicit OpState(Operation *state) : state(state) {} |
241 | |
242 | /// For all op which don't have properties, we keep a single instance of |
243 | /// `EmptyProperties` to be used where a reference to a properties is needed: |
244 | /// this allow to bind a pointer to the reference without triggering UB. |
245 | static EmptyProperties &getEmptyProperties() { |
246 | static EmptyProperties emptyProperties; |
247 | return emptyProperties; |
248 | } |
249 | |
250 | private: |
251 | Operation *state; |
252 | |
253 | /// Allow access to internal hook implementation methods. |
254 | friend RegisteredOperationName; |
255 | }; |
256 | |
257 | // Allow comparing operators. |
258 | inline bool operator==(OpState lhs, OpState rhs) { |
259 | return lhs.getOperation() == rhs.getOperation(); |
260 | } |
261 | inline bool operator!=(OpState lhs, OpState rhs) { |
262 | return lhs.getOperation() != rhs.getOperation(); |
263 | } |
264 | |
265 | raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr); |
266 | |
267 | /// This class represents a single result from folding an operation. |
268 | class OpFoldResult : public PointerUnion<Attribute, Value> { |
269 | using PointerUnion<Attribute, Value>::PointerUnion; |
270 | |
271 | public: |
272 | void dump() const { llvm::errs() << *this << "\n" ; } |
273 | |
274 | MLIRContext *getContext() const { |
275 | return is<Attribute>() ? get<Attribute>().getContext() |
276 | : get<Value>().getContext(); |
277 | } |
278 | }; |
279 | |
280 | // Temporarily exit the MLIR namespace to add casting support as later code in |
281 | // this uses it. The CastInfo must come after the OpFoldResult definition and |
282 | // before any cast function calls depending on CastInfo. |
283 | |
284 | } // namespace mlir |
285 | |
286 | namespace llvm { |
287 | |
288 | // Allow llvm::cast style functions. |
289 | template <typename To> |
290 | struct CastInfo<To, mlir::OpFoldResult> |
291 | : public CastInfo<To, mlir::OpFoldResult::PointerUnion> {}; |
292 | |
293 | template <typename To> |
294 | struct CastInfo<To, const mlir::OpFoldResult> |
295 | : public CastInfo<To, const mlir::OpFoldResult::PointerUnion> {}; |
296 | |
297 | } // namespace llvm |
298 | |
299 | namespace mlir { |
300 | |
301 | /// Allow printing to a stream. |
302 | inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { |
303 | if (Value value = llvm::dyn_cast_if_present<Value>(Val&: ofr)) |
304 | value.print(os); |
305 | else |
306 | llvm::dyn_cast_if_present<Attribute>(Val&: ofr).print(os); |
307 | return os; |
308 | } |
309 | /// Allow printing to a stream. |
310 | inline raw_ostream &operator<<(raw_ostream &os, OpState op) { |
311 | op.print(os, flags: OpPrintingFlags().useLocalScope()); |
312 | return os; |
313 | } |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // Operation Trait Types |
317 | //===----------------------------------------------------------------------===// |
318 | |
319 | namespace OpTrait { |
320 | |
321 | // These functions are out-of-line implementations of the methods in the |
322 | // corresponding trait classes. This avoids them being template |
323 | // instantiated/duplicated. |
324 | namespace impl { |
325 | LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands, |
326 | SmallVectorImpl<OpFoldResult> &results); |
327 | OpFoldResult foldIdempotent(Operation *op); |
328 | OpFoldResult foldInvolution(Operation *op); |
329 | LogicalResult verifyZeroOperands(Operation *op); |
330 | LogicalResult verifyOneOperand(Operation *op); |
331 | LogicalResult verifyNOperands(Operation *op, unsigned numOperands); |
332 | LogicalResult verifyIsIdempotent(Operation *op); |
333 | LogicalResult verifyIsInvolution(Operation *op); |
334 | LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); |
335 | LogicalResult verifyOperandsAreFloatLike(Operation *op); |
336 | LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); |
337 | LogicalResult verifySameTypeOperands(Operation *op); |
338 | LogicalResult verifyZeroRegions(Operation *op); |
339 | LogicalResult verifyOneRegion(Operation *op); |
340 | LogicalResult verifyNRegions(Operation *op, unsigned numRegions); |
341 | LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions); |
342 | LogicalResult verifyZeroResults(Operation *op); |
343 | LogicalResult verifyOneResult(Operation *op); |
344 | LogicalResult verifyNResults(Operation *op, unsigned numOperands); |
345 | LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands); |
346 | LogicalResult verifySameOperandsShape(Operation *op); |
347 | LogicalResult verifySameOperandsAndResultShape(Operation *op); |
348 | LogicalResult verifySameOperandsElementType(Operation *op); |
349 | LogicalResult verifySameOperandsAndResultElementType(Operation *op); |
350 | LogicalResult verifySameOperandsAndResultType(Operation *op); |
351 | LogicalResult verifySameOperandsAndResultRank(Operation *op); |
352 | LogicalResult verifyResultsAreBoolLike(Operation *op); |
353 | LogicalResult verifyResultsAreFloatLike(Operation *op); |
354 | LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op); |
355 | LogicalResult verifyIsTerminator(Operation *op); |
356 | LogicalResult verifyZeroSuccessors(Operation *op); |
357 | LogicalResult verifyOneSuccessor(Operation *op); |
358 | LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); |
359 | LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); |
360 | LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, |
361 | StringRef valueGroupName, |
362 | size_t expectedCount); |
363 | LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); |
364 | LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); |
365 | LogicalResult verifyNoRegionArguments(Operation *op); |
366 | LogicalResult verifyElementwise(Operation *op); |
367 | LogicalResult verifyIsIsolatedFromAbove(Operation *op); |
368 | } // namespace impl |
369 | |
370 | /// Helper class for implementing traits. Clients are not expected to interact |
371 | /// with this directly, so its members are all protected. |
372 | template <typename ConcreteType, template <typename> class TraitType> |
373 | class TraitBase { |
374 | protected: |
375 | /// Return the ultimate Operation being worked on. |
376 | Operation *getOperation() { |
377 | auto *concrete = static_cast<ConcreteType *>(this); |
378 | return concrete->getOperation(); |
379 | } |
380 | }; |
381 | |
382 | //===----------------------------------------------------------------------===// |
383 | // Operand Traits |
384 | |
385 | namespace detail { |
386 | /// Utility trait base that provides accessors for derived traits that have |
387 | /// multiple operands. |
388 | template <typename ConcreteType, template <typename> class TraitType> |
389 | struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> { |
390 | using operand_iterator = Operation::operand_iterator; |
391 | using operand_range = Operation::operand_range; |
392 | using operand_type_iterator = Operation::operand_type_iterator; |
393 | using operand_type_range = Operation::operand_type_range; |
394 | |
395 | /// Return the number of operands. |
396 | unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } |
397 | |
398 | /// Return the operand at index 'i'. |
399 | Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } |
400 | |
401 | /// Set the operand at index 'i' to 'value'. |
402 | void setOperand(unsigned i, Value value) { |
403 | this->getOperation()->setOperand(i, value); |
404 | } |
405 | |
406 | /// Operand iterator access. |
407 | operand_iterator operand_begin() { |
408 | return this->getOperation()->operand_begin(); |
409 | } |
410 | operand_iterator operand_end() { return this->getOperation()->operand_end(); } |
411 | operand_range getOperands() { return this->getOperation()->getOperands(); } |
412 | |
413 | /// Operand type access. |
414 | operand_type_iterator operand_type_begin() { |
415 | return this->getOperation()->operand_type_begin(); |
416 | } |
417 | operand_type_iterator operand_type_end() { |
418 | return this->getOperation()->operand_type_end(); |
419 | } |
420 | operand_type_range getOperandTypes() { |
421 | return this->getOperation()->getOperandTypes(); |
422 | } |
423 | }; |
424 | } // namespace detail |
425 | |
426 | /// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc. |
427 | /// It should be run after core traits and before any other user defined traits. |
428 | /// In order to run it in the correct order, wrap it with OpInvariants trait so |
429 | /// that tblgen will be able to put it in the right order. |
430 | template <typename ConcreteType> |
431 | class OpInvariants : public TraitBase<ConcreteType, OpInvariants> { |
432 | public: |
433 | static LogicalResult verifyTrait(Operation *op) { |
434 | return cast<ConcreteType>(op).verifyInvariantsImpl(); |
435 | } |
436 | }; |
437 | |
438 | /// This class provides the API for ops that are known to have no |
439 | /// SSA operand. |
440 | template <typename ConcreteType> |
441 | class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> { |
442 | public: |
443 | static LogicalResult verifyTrait(Operation *op) { |
444 | return impl::verifyZeroOperands(op); |
445 | } |
446 | |
447 | private: |
448 | // Disable these. |
449 | void getOperand() {} |
450 | void setOperand() {} |
451 | }; |
452 | |
453 | /// This class provides the API for ops that are known to have exactly one |
454 | /// SSA operand. |
455 | template <typename ConcreteType> |
456 | class OneOperand : public TraitBase<ConcreteType, OneOperand> { |
457 | public: |
458 | Value getOperand() { return this->getOperation()->getOperand(0); } |
459 | |
460 | void setOperand(Value value) { this->getOperation()->setOperand(0, value); } |
461 | |
462 | static LogicalResult verifyTrait(Operation *op) { |
463 | return impl::verifyOneOperand(op); |
464 | } |
465 | }; |
466 | |
467 | /// This class provides the API for ops that are known to have a specified |
468 | /// number of operands. This is used as a trait like this: |
469 | /// |
470 | /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> { |
471 | /// |
472 | template <unsigned N> |
473 | class NOperands { |
474 | public: |
475 | static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2" ); |
476 | |
477 | template <typename ConcreteType> |
478 | class Impl |
479 | : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> { |
480 | public: |
481 | static LogicalResult verifyTrait(Operation *op) { |
482 | return impl::verifyNOperands(op, numOperands: N); |
483 | } |
484 | }; |
485 | }; |
486 | |
487 | /// This class provides the API for ops that are known to have a at least a |
488 | /// specified number of operands. This is used as a trait like this: |
489 | /// |
490 | /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> { |
491 | /// |
492 | template <unsigned N> |
493 | class AtLeastNOperands { |
494 | public: |
495 | template <typename ConcreteType> |
496 | class Impl : public detail::MultiOperandTraitBase<ConcreteType, |
497 | AtLeastNOperands<N>::Impl> { |
498 | public: |
499 | static LogicalResult verifyTrait(Operation *op) { |
500 | return impl::verifyAtLeastNOperands(op, numOperands: N); |
501 | } |
502 | }; |
503 | }; |
504 | |
505 | /// This class provides the API for ops which have an unknown number of |
506 | /// SSA operands. |
507 | template <typename ConcreteType> |
508 | class VariadicOperands |
509 | : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {}; |
510 | |
511 | //===----------------------------------------------------------------------===// |
512 | // Region Traits |
513 | |
514 | /// This class provides verification for ops that are known to have zero |
515 | /// regions. |
516 | template <typename ConcreteType> |
517 | class ZeroRegions : public TraitBase<ConcreteType, ZeroRegions> { |
518 | public: |
519 | static LogicalResult verifyTrait(Operation *op) { |
520 | return impl::verifyZeroRegions(op); |
521 | } |
522 | }; |
523 | |
524 | namespace detail { |
525 | /// Utility trait base that provides accessors for derived traits that have |
526 | /// multiple regions. |
527 | template <typename ConcreteType, template <typename> class TraitType> |
528 | struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> { |
529 | using region_iterator = MutableArrayRef<Region>; |
530 | using region_range = RegionRange; |
531 | |
532 | /// Return the number of regions. |
533 | unsigned getNumRegions() { return this->getOperation()->getNumRegions(); } |
534 | |
535 | /// Return the region at `index`. |
536 | Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); } |
537 | |
538 | /// Region iterator access. |
539 | region_iterator region_begin() { |
540 | return this->getOperation()->region_begin(); |
541 | } |
542 | region_iterator region_end() { return this->getOperation()->region_end(); } |
543 | region_range getRegions() { return this->getOperation()->getRegions(); } |
544 | }; |
545 | } // namespace detail |
546 | |
547 | /// This class provides APIs for ops that are known to have a single region. |
548 | template <typename ConcreteType> |
549 | class OneRegion : public TraitBase<ConcreteType, OneRegion> { |
550 | public: |
551 | Region &getRegion() { return this->getOperation()->getRegion(0); } |
552 | |
553 | /// Returns a range of operations within the region of this operation. |
554 | auto getOps() { return getRegion().getOps(); } |
555 | template <typename OpT> |
556 | auto getOps() { |
557 | return getRegion().template getOps<OpT>(); |
558 | } |
559 | |
560 | static LogicalResult verifyTrait(Operation *op) { |
561 | return impl::verifyOneRegion(op); |
562 | } |
563 | }; |
564 | |
565 | /// This class provides the API for ops that are known to have a specified |
566 | /// number of regions. |
567 | template <unsigned N> |
568 | class NRegions { |
569 | public: |
570 | static_assert(N > 1, "use ZeroRegions/OneRegion for N < 2" ); |
571 | |
572 | template <typename ConcreteType> |
573 | class Impl |
574 | : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> { |
575 | public: |
576 | static LogicalResult verifyTrait(Operation *op) { |
577 | return impl::verifyNRegions(op, numRegions: N); |
578 | } |
579 | }; |
580 | }; |
581 | |
582 | /// This class provides APIs for ops that are known to have at least a specified |
583 | /// number of regions. |
584 | template <unsigned N> |
585 | class AtLeastNRegions { |
586 | public: |
587 | template <typename ConcreteType> |
588 | class Impl : public detail::MultiRegionTraitBase<ConcreteType, |
589 | AtLeastNRegions<N>::Impl> { |
590 | public: |
591 | static LogicalResult verifyTrait(Operation *op) { |
592 | return impl::verifyAtLeastNRegions(op, numRegions: N); |
593 | } |
594 | }; |
595 | }; |
596 | |
597 | /// This class provides the API for ops which have an unknown number of |
598 | /// regions. |
599 | template <typename ConcreteType> |
600 | class VariadicRegions |
601 | : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {}; |
602 | |
603 | //===----------------------------------------------------------------------===// |
604 | // Result Traits |
605 | |
606 | /// This class provides return value APIs for ops that are known to have |
607 | /// zero results. |
608 | template <typename ConcreteType> |
609 | class ZeroResults : public TraitBase<ConcreteType, ZeroResults> { |
610 | public: |
611 | static LogicalResult verifyTrait(Operation *op) { |
612 | return impl::verifyZeroResults(op); |
613 | } |
614 | }; |
615 | |
616 | namespace detail { |
617 | /// Utility trait base that provides accessors for derived traits that have |
618 | /// multiple results. |
619 | template <typename ConcreteType, template <typename> class TraitType> |
620 | struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> { |
621 | using result_iterator = Operation::result_iterator; |
622 | using result_range = Operation::result_range; |
623 | using result_type_iterator = Operation::result_type_iterator; |
624 | using result_type_range = Operation::result_type_range; |
625 | |
626 | /// Return the number of results. |
627 | unsigned getNumResults() { return this->getOperation()->getNumResults(); } |
628 | |
629 | /// Return the result at index 'i'. |
630 | Value getResult(unsigned i) { return this->getOperation()->getResult(i); } |
631 | |
632 | /// Replace all uses of results of this operation with the provided 'values'. |
633 | /// 'values' may correspond to an existing operation, or a range of 'Value'. |
634 | template <typename ValuesT> |
635 | void replaceAllUsesWith(ValuesT &&values) { |
636 | this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values)); |
637 | } |
638 | |
639 | /// Return the type of the `i`-th result. |
640 | Type getType(unsigned i) { return getResult(i).getType(); } |
641 | |
642 | /// Result iterator access. |
643 | result_iterator result_begin() { |
644 | return this->getOperation()->result_begin(); |
645 | } |
646 | result_iterator result_end() { return this->getOperation()->result_end(); } |
647 | result_range getResults() { return this->getOperation()->getResults(); } |
648 | |
649 | /// Result type access. |
650 | result_type_iterator result_type_begin() { |
651 | return this->getOperation()->result_type_begin(); |
652 | } |
653 | result_type_iterator result_type_end() { |
654 | return this->getOperation()->result_type_end(); |
655 | } |
656 | result_type_range getResultTypes() { |
657 | return this->getOperation()->getResultTypes(); |
658 | } |
659 | }; |
660 | } // namespace detail |
661 | |
662 | /// This class provides return value APIs for ops that are known to have a |
663 | /// single result. ResultType is the concrete type returned by getType(). |
664 | template <typename ConcreteType> |
665 | class OneResult : public TraitBase<ConcreteType, OneResult> { |
666 | public: |
667 | /// Replace all uses of 'this' value with the new value, updating anything |
668 | /// in the IR that uses 'this' to use the other value instead. When this |
669 | /// returns there are zero uses of 'this'. |
670 | void replaceAllUsesWith(Value newValue) { |
671 | this->getOperation()->getResult(0).replaceAllUsesWith(newValue); |
672 | } |
673 | |
674 | /// Replace all uses of 'this' value with the result of 'op'. |
675 | void replaceAllUsesWith(Operation *op) { |
676 | this->getOperation()->replaceAllUsesWith(op); |
677 | } |
678 | |
679 | static LogicalResult verifyTrait(Operation *op) { |
680 | return impl::verifyOneResult(op); |
681 | } |
682 | }; |
683 | |
684 | /// This trait is used for return value APIs for ops that are known to have a |
685 | /// specific type other than `Type`. This allows the "getType()" member to be |
686 | /// more specific for an op. This should be used in conjunction with OneResult, |
687 | /// and occur in the trait list before OneResult. |
688 | template <typename ResultType> |
689 | class OneTypedResult { |
690 | public: |
691 | /// This class provides return value APIs for ops that are known to have a |
692 | /// single result. ResultType is the concrete type returned by getType(). |
693 | template <typename ConcreteType> |
694 | class Impl |
695 | : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { |
696 | public: |
697 | mlir::TypedValue<ResultType> getResult() { |
698 | return cast<mlir::TypedValue<ResultType>>( |
699 | this->getOperation()->getResult(0)); |
700 | } |
701 | |
702 | /// If the operation returns a single value, then the Op can be implicitly |
703 | /// converted to a Value. This yields the value of the only result. |
704 | operator mlir::TypedValue<ResultType>() { return getResult(); } |
705 | |
706 | ResultType getType() { return getResult().getType(); } |
707 | }; |
708 | }; |
709 | |
710 | /// This class provides the API for ops that are known to have a specified |
711 | /// number of results. This is used as a trait like this: |
712 | /// |
713 | /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> { |
714 | /// |
715 | template <unsigned N> |
716 | class NResults { |
717 | public: |
718 | static_assert(N > 1, "use ZeroResults/OneResult for N < 2" ); |
719 | |
720 | template <typename ConcreteType> |
721 | class Impl |
722 | : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> { |
723 | public: |
724 | static LogicalResult verifyTrait(Operation *op) { |
725 | return impl::verifyNResults(op, numOperands: N); |
726 | } |
727 | }; |
728 | }; |
729 | |
730 | /// This class provides the API for ops that are known to have at least a |
731 | /// specified number of results. This is used as a trait like this: |
732 | /// |
733 | /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> { |
734 | /// |
735 | template <unsigned N> |
736 | class AtLeastNResults { |
737 | public: |
738 | template <typename ConcreteType> |
739 | class Impl : public detail::MultiResultTraitBase<ConcreteType, |
740 | AtLeastNResults<N>::Impl> { |
741 | public: |
742 | static LogicalResult verifyTrait(Operation *op) { |
743 | return impl::verifyAtLeastNResults(op, numOperands: N); |
744 | } |
745 | }; |
746 | }; |
747 | |
748 | /// This class provides the API for ops which have an unknown number of |
749 | /// results. |
750 | template <typename ConcreteType> |
751 | class VariadicResults |
752 | : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {}; |
753 | |
754 | //===----------------------------------------------------------------------===// |
755 | // Terminator Traits |
756 | |
757 | /// This class indicates that the regions associated with this op don't have |
758 | /// terminators. |
759 | template <typename ConcreteType> |
760 | class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {}; |
761 | |
762 | /// This class provides the API for ops that are known to be terminators. |
763 | template <typename ConcreteType> |
764 | class IsTerminator : public TraitBase<ConcreteType, IsTerminator> { |
765 | public: |
766 | static LogicalResult verifyTrait(Operation *op) { |
767 | return impl::verifyIsTerminator(op); |
768 | } |
769 | }; |
770 | |
771 | /// This class provides verification for ops that are known to have zero |
772 | /// successors. |
773 | template <typename ConcreteType> |
774 | class ZeroSuccessors : public TraitBase<ConcreteType, ZeroSuccessors> { |
775 | public: |
776 | static LogicalResult verifyTrait(Operation *op) { |
777 | return impl::verifyZeroSuccessors(op); |
778 | } |
779 | }; |
780 | |
781 | namespace detail { |
782 | /// Utility trait base that provides accessors for derived traits that have |
783 | /// multiple successors. |
784 | template <typename ConcreteType, template <typename> class TraitType> |
785 | struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> { |
786 | using succ_iterator = Operation::succ_iterator; |
787 | using succ_range = SuccessorRange; |
788 | |
789 | /// Return the number of successors. |
790 | unsigned getNumSuccessors() { |
791 | return this->getOperation()->getNumSuccessors(); |
792 | } |
793 | |
794 | /// Return the successor at `index`. |
795 | Block *getSuccessor(unsigned i) { |
796 | return this->getOperation()->getSuccessor(i); |
797 | } |
798 | |
799 | /// Set the successor at `index`. |
800 | void setSuccessor(Block *block, unsigned i) { |
801 | return this->getOperation()->setSuccessor(block, i); |
802 | } |
803 | |
804 | /// Successor iterator access. |
805 | succ_iterator succ_begin() { return this->getOperation()->succ_begin(); } |
806 | succ_iterator succ_end() { return this->getOperation()->succ_end(); } |
807 | succ_range getSuccessors() { return this->getOperation()->getSuccessors(); } |
808 | }; |
809 | } // namespace detail |
810 | |
811 | /// This class provides APIs for ops that are known to have a single successor. |
812 | template <typename ConcreteType> |
813 | class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> { |
814 | public: |
815 | Block *getSuccessor() { return this->getOperation()->getSuccessor(0); } |
816 | void setSuccessor(Block *succ) { |
817 | this->getOperation()->setSuccessor(succ, 0); |
818 | } |
819 | |
820 | static LogicalResult verifyTrait(Operation *op) { |
821 | return impl::verifyOneSuccessor(op); |
822 | } |
823 | }; |
824 | |
825 | /// This class provides the API for ops that are known to have a specified |
826 | /// number of successors. |
827 | template <unsigned N> |
828 | class NSuccessors { |
829 | public: |
830 | static_assert(N > 1, "use ZeroSuccessors/OneSuccessor for N < 2" ); |
831 | |
832 | template <typename ConcreteType> |
833 | class Impl : public detail::MultiSuccessorTraitBase<ConcreteType, |
834 | NSuccessors<N>::Impl> { |
835 | public: |
836 | static LogicalResult verifyTrait(Operation *op) { |
837 | return impl::verifyNSuccessors(op, numSuccessors: N); |
838 | } |
839 | }; |
840 | }; |
841 | |
842 | /// This class provides APIs for ops that are known to have at least a specified |
843 | /// number of successors. |
844 | template <unsigned N> |
845 | class AtLeastNSuccessors { |
846 | public: |
847 | template <typename ConcreteType> |
848 | class Impl |
849 | : public detail::MultiSuccessorTraitBase<ConcreteType, |
850 | AtLeastNSuccessors<N>::Impl> { |
851 | public: |
852 | static LogicalResult verifyTrait(Operation *op) { |
853 | return impl::verifyAtLeastNSuccessors(op, numSuccessors: N); |
854 | } |
855 | }; |
856 | }; |
857 | |
858 | /// This class provides the API for ops which have an unknown number of |
859 | /// successors. |
860 | template <typename ConcreteType> |
861 | class VariadicSuccessors |
862 | : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> { |
863 | }; |
864 | |
865 | //===----------------------------------------------------------------------===// |
866 | // SingleBlock |
867 | |
868 | /// This class provides APIs and verifiers for ops with regions having a single |
869 | /// block. |
870 | template <typename ConcreteType> |
871 | struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> { |
872 | public: |
873 | static LogicalResult verifyTrait(Operation *op) { |
874 | for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { |
875 | Region ®ion = op->getRegion(index: i); |
876 | |
877 | // Empty regions are fine. |
878 | if (region.empty()) |
879 | continue; |
880 | |
881 | // Non-empty regions must contain a single basic block. |
882 | if (!llvm::hasSingleElement(C&: region)) |
883 | return op->emitOpError(message: "expects region #" ) |
884 | << i << " to have 0 or 1 blocks" ; |
885 | |
886 | if (!ConcreteType::template hasTrait<NoTerminator>()) { |
887 | Block &block = region.front(); |
888 | if (block.empty()) |
889 | return op->emitOpError() << "expects a non-empty block" ; |
890 | } |
891 | } |
892 | return success(); |
893 | } |
894 | |
895 | Block *getBody(unsigned idx = 0) { |
896 | Region ®ion = this->getOperation()->getRegion(idx); |
897 | assert(!region.empty() && "unexpected empty region" ); |
898 | return ®ion.front(); |
899 | } |
900 | Region &getBodyRegion(unsigned idx = 0) { |
901 | return this->getOperation()->getRegion(idx); |
902 | } |
903 | |
904 | //===------------------------------------------------------------------===// |
905 | // Single Region Utilities |
906 | //===------------------------------------------------------------------===// |
907 | |
908 | /// The following are a set of methods only enabled when the parent |
909 | /// operation has a single region. Each of these methods take an additional |
910 | /// template parameter that represents the concrete operation so that we |
911 | /// can use SFINAE to disable the methods for non-single region operations. |
912 | template <typename OpT, typename T = void> |
913 | using enable_if_single_region = |
914 | std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; |
915 | |
916 | template <typename OpT = ConcreteType> |
917 | enable_if_single_region<OpT, Block::iterator> begin() { |
918 | return getBody()->begin(); |
919 | } |
920 | template <typename OpT = ConcreteType> |
921 | enable_if_single_region<OpT, Block::iterator> end() { |
922 | return getBody()->end(); |
923 | } |
924 | template <typename OpT = ConcreteType> |
925 | enable_if_single_region<OpT, Operation &> front() { |
926 | return *begin(); |
927 | } |
928 | |
929 | /// Insert the operation into the back of the body. |
930 | template <typename OpT = ConcreteType> |
931 | enable_if_single_region<OpT> push_back(Operation *op) { |
932 | insert(Block::iterator(getBody()->end()), op); |
933 | } |
934 | |
935 | /// Insert the operation at the given insertion point. |
936 | template <typename OpT = ConcreteType> |
937 | enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { |
938 | insert(Block::iterator(insertPt), op); |
939 | } |
940 | template <typename OpT = ConcreteType> |
941 | enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) { |
942 | getBody()->getOperations().insert(insertPt, op); |
943 | } |
944 | }; |
945 | |
946 | //===----------------------------------------------------------------------===// |
947 | // SingleBlockImplicitTerminator |
948 | |
949 | /// This class provides APIs and verifiers for ops with regions having a single |
950 | /// block that must terminate with `TerminatorOpType`. |
951 | template <typename TerminatorOpType> |
952 | struct SingleBlockImplicitTerminator { |
953 | template <typename ConcreteType> |
954 | class Impl : public TraitBase<ConcreteType, SingleBlockImplicitTerminator< |
955 | TerminatorOpType>::Impl> { |
956 | private: |
957 | /// Builds a terminator operation without relying on OpBuilder APIs to avoid |
958 | /// cyclic header inclusion. |
959 | static Operation *buildTerminator(OpBuilder &builder, Location loc) { |
960 | OperationState state(loc, TerminatorOpType::getOperationName()); |
961 | TerminatorOpType::build(builder, state); |
962 | return Operation::create(state); |
963 | } |
964 | |
965 | public: |
966 | /// The type of the operation used as the implicit terminator type. |
967 | using ImplicitTerminatorOpT = TerminatorOpType; |
968 | |
969 | static LogicalResult verifyRegionTrait(Operation *op) { |
970 | for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { |
971 | Region ®ion = op->getRegion(index: i); |
972 | // Empty regions are fine. |
973 | if (region.empty()) |
974 | continue; |
975 | Operation &terminator = region.front().back(); |
976 | if (isa<TerminatorOpType>(terminator)) |
977 | continue; |
978 | |
979 | return op->emitOpError(message: "expects regions to end with '" + |
980 | TerminatorOpType::getOperationName() + |
981 | "', found '" + |
982 | terminator.getName().getStringRef() + "'" ) |
983 | .attachNote() |
984 | << "in custom textual format, the absence of terminator implies " |
985 | "'" |
986 | << TerminatorOpType::getOperationName() << '\''; |
987 | } |
988 | |
989 | return success(); |
990 | } |
991 | |
992 | /// Ensure that the given region has the terminator required by this trait. |
993 | /// If OpBuilder is provided, use it to build the terminator and notify the |
994 | /// OpBuilder listeners accordingly. If only a Builder is provided, locally |
995 | /// construct an OpBuilder with no listeners; this should only be used if no |
996 | /// OpBuilder is available at the call site, e.g., in the parser. |
997 | static void ensureTerminator(Region ®ion, Builder &builder, |
998 | Location loc) { |
999 | ::mlir::impl::ensureRegionTerminator(region, builder, loc, |
1000 | buildTerminatorOp: buildTerminator); |
1001 | } |
1002 | static void ensureTerminator(Region ®ion, OpBuilder &builder, |
1003 | Location loc) { |
1004 | ::mlir::impl::ensureRegionTerminator(region, builder, loc, |
1005 | buildTerminatorOp: buildTerminator); |
1006 | } |
1007 | }; |
1008 | }; |
1009 | |
/// Check if an op defines the `ImplicitTerminatorOpT` member. This is intended
/// to be used with `llvm::is_detected`.
template <class T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
1014 | |
1015 | /// Support to check if an operation has the SingleBlockImplicitTerminator |
1016 | /// trait. We can't just use `hasTrait` because this class is templated on a |
1017 | /// specific terminator op. |
1018 | template <class Op, bool hasTerminator = |
1019 | llvm::is_detected<has_implicit_terminator_t, Op>::value> |
1020 | struct hasSingleBlockImplicitTerminator { |
1021 | static constexpr bool value = std::is_base_of< |
1022 | typename OpTrait::SingleBlockImplicitTerminator< |
1023 | typename Op::ImplicitTerminatorOpT>::template Impl<Op>, |
1024 | Op>::value; |
1025 | }; |
1026 | template <class Op> |
1027 | struct hasSingleBlockImplicitTerminator<Op, false> { |
1028 | static constexpr bool value = false; |
1029 | }; |
1030 | |
1031 | //===----------------------------------------------------------------------===// |
1032 | // Misc Traits |
1033 | |
1034 | /// This class provides verification for ops that are known to have the same |
1035 | /// operand shape: all operands are scalars, vectors/tensors of the same |
1036 | /// shape. |
1037 | template <typename ConcreteType> |
1038 | class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> { |
1039 | public: |
1040 | static LogicalResult verifyTrait(Operation *op) { |
1041 | return impl::verifySameOperandsShape(op); |
1042 | } |
1043 | }; |
1044 | |
1045 | /// This class provides verification for ops that are known to have the same |
1046 | /// operand and result shape: both are scalars, vectors/tensors of the same |
1047 | /// shape. |
1048 | template <typename ConcreteType> |
1049 | class SameOperandsAndResultShape |
1050 | : public TraitBase<ConcreteType, SameOperandsAndResultShape> { |
1051 | public: |
1052 | static LogicalResult verifyTrait(Operation *op) { |
1053 | return impl::verifySameOperandsAndResultShape(op); |
1054 | } |
1055 | }; |
1056 | |
1057 | /// This class provides verification for ops that are known to have the same |
1058 | /// operand element type (or the type itself if it is scalar). |
1059 | /// |
1060 | template <typename ConcreteType> |
1061 | class SameOperandsElementType |
1062 | : public TraitBase<ConcreteType, SameOperandsElementType> { |
1063 | public: |
1064 | static LogicalResult verifyTrait(Operation *op) { |
1065 | return impl::verifySameOperandsElementType(op); |
1066 | } |
1067 | }; |
1068 | |
1069 | /// This class provides verification for ops that are known to have the same |
1070 | /// operand and result element type (or the type itself if it is scalar). |
1071 | /// |
1072 | template <typename ConcreteType> |
1073 | class SameOperandsAndResultElementType |
1074 | : public TraitBase<ConcreteType, SameOperandsAndResultElementType> { |
1075 | public: |
1076 | static LogicalResult verifyTrait(Operation *op) { |
1077 | return impl::verifySameOperandsAndResultElementType(op); |
1078 | } |
1079 | }; |
1080 | |
1081 | /// This class provides verification for ops that are known to have the same |
1082 | /// operand and result type. |
1083 | /// |
1084 | /// Note: this trait subsumes the SameOperandsAndResultShape and |
1085 | /// SameOperandsAndResultElementType traits. |
1086 | template <typename ConcreteType> |
1087 | class SameOperandsAndResultType |
1088 | : public TraitBase<ConcreteType, SameOperandsAndResultType> { |
1089 | public: |
1090 | static LogicalResult verifyTrait(Operation *op) { |
1091 | return impl::verifySameOperandsAndResultType(op); |
1092 | } |
1093 | }; |
1094 | |
1095 | /// This class verifies that op has same ranks for all |
1096 | /// operands and results types, if known. |
1097 | template <typename ConcreteType> |
1098 | class SameOperandsAndResultRank |
1099 | : public TraitBase<ConcreteType, SameOperandsAndResultRank> { |
1100 | public: |
1101 | static LogicalResult verifyTrait(Operation *op) { |
1102 | return impl::verifySameOperandsAndResultRank(op); |
1103 | } |
1104 | }; |
1105 | |
1106 | /// This class verifies that any results of the specified op have a boolean |
1107 | /// type, a vector thereof, or a tensor thereof. |
1108 | template <typename ConcreteType> |
1109 | class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> { |
1110 | public: |
1111 | static LogicalResult verifyTrait(Operation *op) { |
1112 | return impl::verifyResultsAreBoolLike(op); |
1113 | } |
1114 | }; |
1115 | |
1116 | /// This class verifies that any results of the specified op have a floating |
1117 | /// point type, a vector thereof, or a tensor thereof. |
1118 | template <typename ConcreteType> |
1119 | class ResultsAreFloatLike |
1120 | : public TraitBase<ConcreteType, ResultsAreFloatLike> { |
1121 | public: |
1122 | static LogicalResult verifyTrait(Operation *op) { |
1123 | return impl::verifyResultsAreFloatLike(op); |
1124 | } |
1125 | }; |
1126 | |
1127 | /// This class verifies that any results of the specified op have a signless |
1128 | /// integer or index type, a vector thereof, or a tensor thereof. |
1129 | template <typename ConcreteType> |
1130 | class ResultsAreSignlessIntegerLike |
1131 | : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> { |
1132 | public: |
1133 | static LogicalResult verifyTrait(Operation *op) { |
1134 | return impl::verifyResultsAreSignlessIntegerLike(op); |
1135 | } |
1136 | }; |
1137 | |
1138 | /// This class adds property that the operation is commutative. |
1139 | template <typename ConcreteType> |
1140 | class IsCommutative : public TraitBase<ConcreteType, IsCommutative> { |
1141 | public: |
1142 | static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands, |
1143 | SmallVectorImpl<OpFoldResult> &results) { |
1144 | return impl::foldCommutative(op, operands, results); |
1145 | } |
1146 | }; |
1147 | |
1148 | /// This class adds property that the operation is an involution. |
1149 | /// This means a unary to unary operation "f" that satisfies f(f(x)) = x |
1150 | template <typename ConcreteType> |
1151 | class IsInvolution : public TraitBase<ConcreteType, IsInvolution> { |
1152 | public: |
1153 | static LogicalResult verifyTrait(Operation *op) { |
1154 | static_assert(ConcreteType::template hasTrait<OneResult>(), |
1155 | "expected operation to produce one result" ); |
1156 | static_assert(ConcreteType::template hasTrait<OneOperand>(), |
1157 | "expected operation to take one operand" ); |
1158 | static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
1159 | "expected operation to preserve type" ); |
1160 | // Involution requires the operation to be side effect free as well |
1161 | // but currently this check is under a FIXME and is not actually done. |
1162 | return impl::verifyIsInvolution(op); |
1163 | } |
1164 | |
1165 | static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
1166 | return impl::foldInvolution(op); |
1167 | } |
1168 | }; |
1169 | |
1170 | /// This class adds property that the operation is idempotent. |
1171 | /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x), |
1172 | /// or a binary operation "g" that satisfies g(x, x) = x. |
1173 | template <typename ConcreteType> |
1174 | class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> { |
1175 | public: |
1176 | static LogicalResult verifyTrait(Operation *op) { |
1177 | static_assert(ConcreteType::template hasTrait<OneResult>(), |
1178 | "expected operation to produce one result" ); |
1179 | static_assert(ConcreteType::template hasTrait<OneOperand>() || |
1180 | ConcreteType::template hasTrait<NOperands<2>::Impl>(), |
1181 | "expected operation to take one or two operands" ); |
1182 | static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
1183 | "expected operation to preserve type" ); |
1184 | // Idempotent requires the operation to be side effect free as well |
1185 | // but currently this check is under a FIXME and is not actually done. |
1186 | return impl::verifyIsIdempotent(op); |
1187 | } |
1188 | |
1189 | static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
1190 | return impl::foldIdempotent(op); |
1191 | } |
1192 | }; |
1193 | |
1194 | /// This class verifies that all operands of the specified op have a float type, |
1195 | /// a vector thereof, or a tensor thereof. |
1196 | template <typename ConcreteType> |
1197 | class OperandsAreFloatLike |
1198 | : public TraitBase<ConcreteType, OperandsAreFloatLike> { |
1199 | public: |
1200 | static LogicalResult verifyTrait(Operation *op) { |
1201 | return impl::verifyOperandsAreFloatLike(op); |
1202 | } |
1203 | }; |
1204 | |
1205 | /// This class verifies that all operands of the specified op have a signless |
1206 | /// integer or index type, a vector thereof, or a tensor thereof. |
1207 | template <typename ConcreteType> |
1208 | class OperandsAreSignlessIntegerLike |
1209 | : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> { |
1210 | public: |
1211 | static LogicalResult verifyTrait(Operation *op) { |
1212 | return impl::verifyOperandsAreSignlessIntegerLike(op); |
1213 | } |
1214 | }; |
1215 | |
1216 | /// This class verifies that all operands of the specified op have the same |
1217 | /// type. |
1218 | template <typename ConcreteType> |
1219 | class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> { |
1220 | public: |
1221 | static LogicalResult verifyTrait(Operation *op) { |
1222 | return impl::verifySameTypeOperands(op); |
1223 | } |
1224 | }; |
1225 | |
1226 | /// This class provides the API for a sub-set of ops that are known to be |
1227 | /// constant-like. These are non-side effecting operations with one result and |
1228 | /// zero operands that can always be folded to a specific attribute value. |
1229 | template <typename ConcreteType> |
1230 | class ConstantLike : public TraitBase<ConcreteType, ConstantLike> { |
1231 | public: |
1232 | static LogicalResult verifyTrait(Operation *op) { |
1233 | static_assert(ConcreteType::template hasTrait<OneResult>(), |
1234 | "expected operation to produce one result" ); |
1235 | static_assert(ConcreteType::template hasTrait<ZeroOperands>(), |
1236 | "expected operation to take zero operands" ); |
1237 | // TODO: We should verify that the operation can always be folded, but this |
1238 | // requires that the attributes of the op already be verified. We should add |
1239 | // support for verifying traits "after" the operation to enable this use |
1240 | // case. |
1241 | return success(); |
1242 | } |
1243 | }; |
1244 | |
1245 | /// This class provides the API for ops that are known to be isolated from |
1246 | /// above. |
1247 | template <typename ConcreteType> |
1248 | class IsIsolatedFromAbove |
1249 | : public TraitBase<ConcreteType, IsIsolatedFromAbove> { |
1250 | public: |
1251 | static LogicalResult verifyRegionTrait(Operation *op) { |
1252 | return impl::verifyIsIsolatedFromAbove(op); |
1253 | } |
1254 | }; |
1255 | |
1256 | /// A trait of region holding operations that defines a new scope for polyhedral |
1257 | /// optimization purposes. Any SSA values of 'index' type that either dominate |
1258 | /// such an operation or are used at the top-level of such an operation |
1259 | /// automatically become valid symbols for the polyhedral scope defined by that |
1260 | /// operation. For more details, see `Traits.md#AffineScope`. |
1261 | template <typename ConcreteType> |
1262 | class AffineScope : public TraitBase<ConcreteType, AffineScope> { |
1263 | public: |
1264 | static LogicalResult verifyTrait(Operation *op) { |
1265 | static_assert(!ConcreteType::template hasTrait<ZeroRegions>(), |
1266 | "expected operation to have one or more regions" ); |
1267 | return success(); |
1268 | } |
1269 | }; |
1270 | |
1271 | /// A trait of region holding operations that define a new scope for automatic |
1272 | /// allocations, i.e., allocations that are freed when control is transferred |
1273 | /// back from the operation's region. Any operations performing such allocations |
1274 | /// (for eg. memref.alloca) will have their allocations automatically freed at |
1275 | /// their closest enclosing operation with this trait. |
1276 | template <typename ConcreteType> |
1277 | class AutomaticAllocationScope |
1278 | : public TraitBase<ConcreteType, AutomaticAllocationScope> { |
1279 | public: |
1280 | static LogicalResult verifyTrait(Operation *op) { |
1281 | static_assert(!ConcreteType::template hasTrait<ZeroRegions>(), |
1282 | "expected operation to have one or more regions" ); |
1283 | return success(); |
1284 | } |
1285 | }; |
1286 | |
1287 | /// This class provides a verifier for ops that are expecting their parent |
1288 | /// to be one of the given parent ops |
1289 | template <typename... ParentOpTypes> |
1290 | struct HasParent { |
1291 | template <typename ConcreteType> |
1292 | class Impl : public TraitBase<ConcreteType, Impl> { |
1293 | public: |
1294 | static LogicalResult verifyTrait(Operation *op) { |
1295 | if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp())) |
1296 | return success(); |
1297 | |
1298 | return op->emitOpError() |
1299 | << "expects parent op " |
1300 | << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'" ) |
1301 | << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'" ; |
1302 | } |
1303 | |
1304 | template <typename ParentOpType = |
1305 | std::tuple_element_t<0, std::tuple<ParentOpTypes...>>> |
1306 | std::enable_if_t<sizeof...(ParentOpTypes) == 1, ParentOpType> |
1307 | getParentOp() { |
1308 | Operation *parent = this->getOperation()->getParentOp(); |
1309 | return llvm::cast<ParentOpType>(parent); |
1310 | } |
1311 | }; |
1312 | }; |
1313 | |
1314 | /// A trait for operations that have an attribute specifying operand segments. |
1315 | /// |
1316 | /// Certain operations can have multiple variadic operands and their size |
1317 | /// relationship is not always known statically. For such cases, we need |
1318 | /// a per-op-instance specification to divide the operands into logical groups |
1319 | /// or segments. This can be modeled by attributes. The attribute will be named |
1320 | /// as `operandSegmentSizes`. |
1321 | /// |
1322 | /// This trait verifies the attribute for specifying operand segments has |
1323 | /// the correct type (1D vector) and values (non-negative), etc. |
1324 | template <typename ConcreteType> |
1325 | class AttrSizedOperandSegments |
1326 | : public TraitBase<ConcreteType, AttrSizedOperandSegments> { |
1327 | public: |
1328 | static StringRef getOperandSegmentSizeAttr() { return "operandSegmentSizes" ; } |
1329 | |
1330 | static LogicalResult verifyTrait(Operation *op) { |
1331 | return ::mlir::OpTrait::impl::verifyOperandSizeAttr( |
1332 | op, sizeAttrName: getOperandSegmentSizeAttr()); |
1333 | } |
1334 | }; |
1335 | |
1336 | /// Similar to AttrSizedOperandSegments but used for results. |
1337 | template <typename ConcreteType> |
1338 | class AttrSizedResultSegments |
1339 | : public TraitBase<ConcreteType, AttrSizedResultSegments> { |
1340 | public: |
1341 | static StringRef getResultSegmentSizeAttr() { return "resultSegmentSizes" ; } |
1342 | |
1343 | static LogicalResult verifyTrait(Operation *op) { |
1344 | return ::mlir::OpTrait::impl::verifyResultSizeAttr( |
1345 | op, sizeAttrName: getResultSegmentSizeAttr()); |
1346 | } |
1347 | }; |
1348 | |
1349 | /// This trait provides a verifier for ops that are expecting their regions to |
1350 | /// not have any arguments |
1351 | template <typename ConcrentType> |
1352 | struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> { |
1353 | static LogicalResult verifyTrait(Operation *op) { |
1354 | return ::mlir::OpTrait::impl::verifyNoRegionArguments(op); |
1355 | } |
1356 | }; |
1357 | |
1358 | // This trait is used to flag operations that consume or produce |
1359 | // values of `MemRef` type where those references can be 'normalized'. |
1360 | // TODO: Right now, the operands of an operation are either all normalizable, |
1361 | // or not. In the future, we may want to allow some of the operands to be |
1362 | // normalizable. |
1363 | template <typename ConcrentType> |
1364 | struct MemRefsNormalizable |
1365 | : public TraitBase<ConcrentType, MemRefsNormalizable> {}; |
1366 | |
1367 | /// This trait tags element-wise ops on vectors or tensors. |
1368 | /// |
1369 | /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this |
1370 | /// trait. In particular, broadcasting behavior is not allowed. |
1371 | /// |
1372 | /// An `Elementwise` op must satisfy the following properties: |
1373 | /// |
1374 | /// 1. If any result is a vector/tensor then at least one operand must also be a |
1375 | /// vector/tensor. |
1376 | /// 2. If any operand is a vector/tensor then there must be at least one result |
1377 | /// and all results must be vectors/tensors. |
1378 | /// 3. All operand and result vector/tensor types must be of the same shape. The |
1379 | /// shape may be dynamic in which case the op's behaviour is undefined for |
1380 | /// non-matching shapes. |
1381 | /// 4. The operation must be elementwise on its vector/tensor operands and |
1382 | /// results. When applied to single-element vectors/tensors, the result must |
1383 | /// be the same per elememnt. |
1384 | /// |
1385 | /// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new |
1386 | /// interface `ElementwiseTypeInterface` that describes the container types for |
1387 | /// which the operation is elementwise. |
1388 | /// |
1389 | /// Rationale: |
1390 | /// - 1. and 2. guarantee a well-defined iteration space and exclude the cases |
1391 | /// of 0 non-scalar operands or 0 non-scalar results, which complicate a |
1392 | /// generic definition of the iteration space. |
1393 | /// - 3. guarantees that folding can be done across scalars/vectors/tensors with |
1394 | /// the same pattern, as otherwise lots of special handling for type |
1395 | /// mismatches would be needed. |
1396 | /// - 4. guarantees that no error handling is needed. Higher-level dialects |
1397 | /// should reify any needed guards or error handling code before lowering to |
1398 | /// an `Elementwise` op. |
1399 | template <typename ConcreteType> |
1400 | struct Elementwise : public TraitBase<ConcreteType, Elementwise> { |
1401 | static LogicalResult verifyTrait(Operation *op) { |
1402 | return ::mlir::OpTrait::impl::verifyElementwise(op); |
1403 | } |
1404 | }; |
1405 | |
1406 | /// This trait tags `Elementwise` operatons that can be systematically |
1407 | /// scalarized. All vector/tensor operands and results are then replaced by |
1408 | /// scalars of the respective element type. Semantically, this is the operation |
1409 | /// on a single element of the vector/tensor. |
1410 | /// |
1411 | /// Rationale: |
1412 | /// Allow to define the vector/tensor semantics of elementwise operations based |
1413 | /// on the same op's behavior on scalars. This provides a constructive procedure |
1414 | /// for IR transformations to, e.g., create scalar loop bodies from tensor ops. |
1415 | /// |
1416 | /// Example: |
1417 | /// ``` |
1418 | /// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val) |
1419 | /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) |
1420 | /// -> tensor<?xf32> |
1421 | /// ``` |
1422 | /// can be scalarized to |
1423 | /// |
1424 | /// ``` |
1425 | /// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar) |
1426 | /// : (i1, f32, f32) -> f32 |
1427 | /// ``` |
1428 | template <typename ConcreteType> |
1429 | struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> { |
1430 | static LogicalResult verifyTrait(Operation *op) { |
1431 | static_assert( |
1432 | ConcreteType::template hasTrait<Elementwise>(), |
1433 | "`Scalarizable` trait is only applicable to `Elementwise` ops." ); |
1434 | return success(); |
1435 | } |
1436 | }; |
1437 | |
1438 | /// This trait tags `Elementwise` operatons that can be systematically |
1439 | /// vectorized. All scalar operands and results are then replaced by vectors |
1440 | /// with the respective element type. Semantically, this is the operation on |
1441 | /// multiple elements simultaneously. See also `Tensorizable`. |
1442 | /// |
1443 | /// Rationale: |
1444 | /// Provide the reverse to `Scalarizable` which, when chained together, allows |
1445 | /// reasoning about the relationship between the tensor and vector case. |
1446 | /// Additionally, it permits reasoning about promoting scalars to vectors via |
1447 | /// broadcasting in cases like `%select_scalar_pred` below. |
1448 | template <typename ConcreteType> |
1449 | struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> { |
1450 | static LogicalResult verifyTrait(Operation *op) { |
1451 | static_assert( |
1452 | ConcreteType::template hasTrait<Elementwise>(), |
1453 | "`Vectorizable` trait is only applicable to `Elementwise` ops." ); |
1454 | return success(); |
1455 | } |
1456 | }; |
1457 | |
1458 | /// This trait tags `Elementwise` operatons that can be systematically |
1459 | /// tensorized. All scalar operands and results are then replaced by tensors |
1460 | /// with the respective element type. Semantically, this is the operation on |
1461 | /// multiple elements simultaneously. See also `Vectorizable`. |
1462 | /// |
1463 | /// Rationale: |
1464 | /// Provide the reverse to `Scalarizable` which, when chained together, allows |
1465 | /// reasoning about the relationship between the tensor and vector case. |
1466 | /// Additionally, it permits reasoning about promoting scalars to tensors via |
1467 | /// broadcasting in cases like `%select_scalar_pred` below. |
1468 | /// |
1469 | /// Examples: |
1470 | /// ``` |
1471 | /// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 |
1472 | /// ``` |
1473 | /// can be tensorized to |
1474 | /// ``` |
1475 | /// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>) |
1476 | /// -> tensor<?xf32> |
1477 | /// ``` |
1478 | /// |
1479 | /// ``` |
1480 | /// %scalar_pred = "arith.select"(%pred, %true_val, %false_val) |
1481 | /// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
1482 | /// ``` |
1483 | /// can be tensorized to |
1484 | /// ``` |
1485 | /// %tensor_pred = "arith.select"(%pred, %true_val, %false_val) |
1486 | /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) |
1487 | /// -> tensor<?xf32> |
1488 | /// ``` |
1489 | template <typename ConcreteType> |
1490 | struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> { |
1491 | static LogicalResult verifyTrait(Operation *op) { |
1492 | static_assert( |
1493 | ConcreteType::template hasTrait<Elementwise>(), |
1494 | "`Tensorizable` trait is only applicable to `Elementwise` ops." ); |
1495 | return success(); |
1496 | } |
1497 | }; |
1498 | |
/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
/// provide an easy way for scalar operations to conveniently generalize their
/// behavior to vectors/tensors, and systematize conversion between these forms.
///
/// Returns whether `op` has the traits that make it "elementwise mappable".
/// NOTE(review): only the declaration is visible here; presumably this checks
/// that `op` carries all four traits listed above — confirm against the
/// out-of-line definition.
bool hasElementwiseMappableTraits(Operation *op);
1503 | |
1504 | } // namespace OpTrait |
1505 | |
1506 | //===----------------------------------------------------------------------===// |
1507 | // Internal Trait Utilities |
1508 | //===----------------------------------------------------------------------===// |
1509 | |
1510 | namespace op_definition_impl { |
1511 | //===----------------------------------------------------------------------===// |
1512 | // Trait Existence |
1513 | |
1514 | /// Returns true if this given Trait ID matches the IDs of any of the provided |
1515 | /// trait types `Traits`. |
1516 | template <template <typename T> class... Traits> |
1517 | inline bool hasTrait(TypeID traitID) { |
1518 | TypeID traitIDs[] = {TypeID::get<Traits>()...}; |
1519 | for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) |
1520 | if (traitIDs[i] == traitID) |
1521 | return true; |
1522 | return false; |
1523 | } |
1524 | template <> |
1525 | inline bool hasTrait<>(TypeID traitID) { |
1526 | return false; |
1527 | } |
1528 | |
1529 | //===----------------------------------------------------------------------===// |
1530 | // Trait Folding |
1531 | |
1532 | /// Trait to check if T provides a 'foldTrait' method for single result |
1533 | /// operations. |
1534 | template <typename T, typename... Args> |
1535 | using has_single_result_fold_trait = decltype(T::foldTrait( |
1536 | std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>())); |
1537 | template <typename T> |
1538 | using detect_has_single_result_fold_trait = |
1539 | llvm::is_detected<has_single_result_fold_trait, T>; |
1540 | /// Trait to check if T provides a general 'foldTrait' method. |
1541 | template <typename T, typename... Args> |
1542 | using has_fold_trait = |
1543 | decltype(T::foldTrait(std::declval<Operation *>(), |
1544 | std::declval<ArrayRef<Attribute>>(), |
1545 | std::declval<SmallVectorImpl<OpFoldResult> &>())); |
1546 | template <typename T> |
1547 | using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>; |
1548 | /// Trait to check if T provides any `foldTrait` method. |
1549 | template <typename T> |
1550 | using detect_has_any_fold_trait = |
1551 | std::disjunction<detect_has_fold_trait<T>, |
1552 | detect_has_single_result_fold_trait<T>>; |
1553 | |
1554 | /// Returns the result of folding a trait that implements a `foldTrait` function |
1555 | /// that is specialized for operations that have a single result. |
1556 | template <typename Trait> |
1557 | static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value, |
1558 | LogicalResult> |
1559 | foldTrait(Operation *op, ArrayRef<Attribute> operands, |
1560 | SmallVectorImpl<OpFoldResult> &results) { |
1561 | assert(op->hasTrait<OpTrait::OneResult>() && |
1562 | "expected trait on non single-result operation to implement the " |
1563 | "general `foldTrait` method" ); |
1564 | // If a previous trait has already been folded and replaced this operation, we |
1565 | // fail to fold this trait. |
1566 | if (!results.empty()) |
1567 | return failure(); |
1568 | |
1569 | if (OpFoldResult result = Trait::foldTrait(op, operands)) { |
1570 | if (llvm::dyn_cast_if_present<Value>(Val&: result) != op->getResult(idx: 0)) |
1571 | results.push_back(Elt: result); |
1572 | return success(); |
1573 | } |
1574 | return failure(); |
1575 | } |
1576 | /// Returns the result of folding a trait that implements a generalized |
1577 | /// `foldTrait` function that is supports any operation type. |
1578 | template <typename Trait> |
1579 | static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult> |
1580 | foldTrait(Operation *op, ArrayRef<Attribute> operands, |
1581 | SmallVectorImpl<OpFoldResult> &results) { |
1582 | // If a previous trait has already been folded and replaced this operation, we |
1583 | // fail to fold this trait. |
1584 | return results.empty() ? Trait::foldTrait(op, operands, results) : failure(); |
1585 | } |
1586 | template <typename Trait> |
1587 | static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value, |
1588 | LogicalResult> |
1589 | foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) { |
1590 | return failure(); |
1591 | } |
1592 | |
1593 | /// Given a tuple type containing a set of traits, return the result of folding |
1594 | /// the given operation. |
1595 | template <typename... Ts> |
1596 | static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands, |
1597 | SmallVectorImpl<OpFoldResult> &results) { |
1598 | return success((succeeded(foldTrait<Ts>(op, operands, results)) || ...)); |
1599 | } |
1600 | |
1601 | //===----------------------------------------------------------------------===// |
1602 | // Trait Verification |
1603 | |
1604 | /// Trait to check if T provides a `verifyTrait` method. |
1605 | template <typename T, typename... Args> |
1606 | using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>())); |
1607 | template <typename T> |
1608 | using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>; |
1609 | |
1610 | /// Trait to check if T provides a `verifyTrait` method. |
1611 | template <typename T, typename... Args> |
1612 | using has_verify_region_trait = |
1613 | decltype(T::verifyRegionTrait(std::declval<Operation *>())); |
1614 | template <typename T> |
1615 | using detect_has_verify_region_trait = |
1616 | llvm::is_detected<has_verify_region_trait, T>; |
1617 | |
1618 | /// Verify the given trait if it provides a verifier. |
1619 | template <typename T> |
1620 | std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult> |
1621 | verifyTrait(Operation *op) { |
1622 | return T::verifyTrait(op); |
1623 | } |
1624 | template <typename T> |
1625 | inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult> |
1626 | verifyTrait(Operation *) { |
1627 | return success(); |
1628 | } |
1629 | |
1630 | /// Given a set of traits, return the result of verifying the given operation. |
1631 | template <typename... Ts> |
1632 | LogicalResult verifyTraits(Operation *op) { |
1633 | return success((succeeded(verifyTrait<Ts>(op)) && ...)); |
1634 | } |
1635 | |
1636 | /// Verify the given trait if it provides a region verifier. |
1637 | template <typename T> |
1638 | std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult> |
1639 | verifyRegionTrait(Operation *op) { |
1640 | return T::verifyRegionTrait(op); |
1641 | } |
1642 | template <typename T> |
1643 | inline std::enable_if_t<!detect_has_verify_region_trait<T>::value, |
1644 | LogicalResult> |
1645 | verifyRegionTrait(Operation *) { |
1646 | return success(); |
1647 | } |
1648 | |
1649 | /// Given a set of traits, return the result of verifying the regions of the |
1650 | /// given operation. |
1651 | template <typename... Ts> |
1652 | LogicalResult verifyRegionTraits(Operation *op) { |
1653 | return success((succeeded(verifyRegionTrait<Ts>(op)) && ...)); |
1654 | } |
1655 | } // namespace op_definition_impl |
1656 | |
1657 | //===----------------------------------------------------------------------===// |
1658 | // Operation Definition classes |
1659 | //===----------------------------------------------------------------------===// |
1660 | |
1661 | /// This provides public APIs that all operations should have. The template |
1662 | /// argument 'ConcreteType' should be the concrete type by CRTP and the others |
1663 | /// are base classes by the policy pattern. |
1664 | template <typename ConcreteType, template <typename T> class... Traits> |
1665 | class Op : public OpState, public Traits<ConcreteType>... { |
1666 | public: |
1667 | /// Inherit getOperation from `OpState`. |
1668 | using OpState::getOperation; |
1669 | using OpState::verify; |
1670 | using OpState::verifyRegions; |
1671 | |
1672 | /// Return if this operation contains the provided trait. |
1673 | template <template <typename T> class Trait> |
1674 | static constexpr bool hasTrait() { |
1675 | return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value; |
1676 | } |
1677 | |
1678 | /// Create a deep copy of this operation. |
1679 | ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); } |
1680 | |
1681 | /// Create a partial copy of this operation without traversing into attached |
1682 | /// regions. The new operation will have the same number of regions as the |
1683 | /// original one, but they will be left empty. |
1684 | ConcreteType cloneWithoutRegions() { |
1685 | return cast<ConcreteType>(getOperation()->cloneWithoutRegions()); |
1686 | } |
1687 | |
1688 | /// Return true if this "op class" can match against the specified operation. |
1689 | static bool classof(Operation *op) { |
1690 | if (auto info = op->getRegisteredInfo()) |
1691 | return TypeID::get<ConcreteType>() == info->getTypeID(); |
1692 | #ifndef NDEBUG |
1693 | if (op->getName().getStringRef() == ConcreteType::getOperationName()) |
1694 | llvm::report_fatal_error( |
1695 | "classof on '" + ConcreteType::getOperationName() + |
1696 | "' failed due to the operation not being registered" ); |
1697 | #endif |
1698 | return false; |
1699 | } |
1700 | /// Provide `classof` support for other OpBase derived classes, such as |
1701 | /// Interfaces. |
1702 | template <typename T> |
1703 | static std::enable_if_t<std::is_base_of<OpState, T>::value, bool> |
1704 | classof(const T *op) { |
1705 | return classof(const_cast<T *>(op)->getOperation()); |
1706 | } |
1707 | |
1708 | /// Expose the type we are instantiated on to template machinery that may want |
1709 | /// to introspect traits on this operation. |
1710 | using ConcreteOpType = ConcreteType; |
1711 | |
1712 | /// This is a public constructor. Any op can be initialized to null. |
1713 | explicit Op() : OpState(nullptr) {} |
1714 | Op(std::nullptr_t) : OpState(nullptr) {} |
1715 | |
1716 | /// This is a public constructor to enable access via the llvm::cast family of |
1717 | /// methods. This should not be used directly. |
1718 | explicit Op(Operation *state) : OpState(state) {} |
1719 | |
1720 | /// Methods for supporting PointerLikeTypeTraits. |
1721 | const void *getAsOpaquePointer() const { |
1722 | return static_cast<const void *>((Operation *)*this); |
1723 | } |
1724 | static ConcreteOpType getFromOpaquePointer(const void *pointer) { |
1725 | return ConcreteOpType( |
1726 | reinterpret_cast<Operation *>(const_cast<void *>(pointer))); |
1727 | } |
1728 | |
1729 | /// Attach the given models as implementations of the corresponding |
1730 | /// interfaces for the concrete operation. |
1731 | template <typename... Models> |
1732 | static void attachInterface(MLIRContext &context) { |
1733 | std::optional<RegisteredOperationName> info = |
1734 | RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context); |
1735 | if (!info) |
1736 | llvm::report_fatal_error( |
1737 | "Attempting to attach an interface to an unregistered operation " + |
1738 | ConcreteType::getOperationName() + "." ); |
1739 | (checkInterfaceTarget<Models>(), ...); |
1740 | info->attachInterface<Models...>(); |
1741 | } |
1742 | /// Convert the provided attribute to a property and assigned it to the |
1743 | /// provided properties. This default implementation forwards to a free |
1744 | /// function `setPropertiesFromAttribute` that can be looked up with ADL in |
1745 | /// the namespace where the properties are defined. It can also be overridden |
1746 | /// in the derived ConcreteOp. |
1747 | template <typename PropertiesTy> |
1748 | static LogicalResult |
1749 | setPropertiesFromAttr(PropertiesTy &prop, Attribute attr, |
1750 | function_ref<InFlightDiagnostic()> emitError) { |
1751 | return setPropertiesFromAttribute(prop, attr, emitError); |
1752 | } |
1753 | /// Convert the provided properties to an attribute. This default |
1754 | /// implementation forwards to a free function `getPropertiesAsAttribute` that |
1755 | /// can be looked up with ADL in the namespace where the properties are |
1756 | /// defined. It can also be overridden in the derived ConcreteOp. |
1757 | template <typename PropertiesTy> |
1758 | static Attribute getPropertiesAsAttr(MLIRContext *ctx, |
1759 | const PropertiesTy &prop) { |
1760 | return getPropertiesAsAttribute(ctx, prop); |
1761 | } |
1762 | /// Hash the provided properties. This default implementation forwards to a |
1763 | /// free function `computeHash` that can be looked up with ADL in the |
1764 | /// namespace where the properties are defined. It can also be overridden in |
1765 | /// the derived ConcreteOp. |
1766 | template <typename PropertiesTy> |
1767 | static llvm::hash_code computePropertiesHash(const PropertiesTy &prop) { |
1768 | return computeHash(prop); |
1769 | } |
1770 | |
1771 | private: |
1772 | /// Trait to check if T provides a 'fold' method for a single result op. |
1773 | template <typename T, typename... Args> |
1774 | using has_single_result_fold_t = |
1775 | decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>())); |
1776 | template <typename T> |
1777 | constexpr static bool has_single_result_fold_v = |
1778 | llvm::is_detected<has_single_result_fold_t, T>::value; |
1779 | /// Trait to check if T provides a general 'fold' method. |
1780 | template <typename T, typename... Args> |
1781 | using has_fold_t = decltype(std::declval<T>().fold( |
1782 | std::declval<ArrayRef<Attribute>>(), |
1783 | std::declval<SmallVectorImpl<OpFoldResult> &>())); |
1784 | template <typename T> |
1785 | constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value; |
1786 | /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a |
1787 | /// single result op. |
1788 | template <typename T, typename... Args> |
1789 | using has_fold_adaptor_single_result_fold_t = |
1790 | decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>())); |
1791 | template <class T> |
1792 | constexpr static bool has_fold_adaptor_single_result_v = |
1793 | llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value; |
1794 | /// Trait to check if T provides a general 'fold' method with a FoldAdaptor. |
1795 | template <typename T, typename... Args> |
1796 | using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold( |
1797 | std::declval<typename T::FoldAdaptor>(), |
1798 | std::declval<SmallVectorImpl<OpFoldResult> &>())); |
1799 | template <class T> |
1800 | constexpr static bool has_fold_adaptor_v = |
1801 | llvm::is_detected<has_fold_adaptor_fold_t, T>::value; |
1802 | |
1803 | /// Trait to check if T provides a 'print' method. |
1804 | template <typename T, typename... Args> |
1805 | using has_print = |
1806 | decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>())); |
1807 | template <typename T> |
1808 | using detect_has_print = llvm::is_detected<has_print, T>; |
1809 | |
1810 | /// Trait to check if printProperties(OpAsmPrinter, T, ArrayRef<StringRef>) |
1811 | /// exist |
1812 | template <typename T, typename... Args> |
1813 | using has_print_properties = |
1814 | decltype(printProperties(std::declval<OpAsmPrinter &>(), |
1815 | std::declval<T>(), |
1816 | std::declval<ArrayRef<StringRef>>())); |
1817 | template <typename T> |
1818 | using detect_has_print_properties = |
1819 | llvm::is_detected<has_print_properties, T>; |
1820 | |
1821 | /// Trait to check if parseProperties(OpAsmParser, T) exist |
1822 | template <typename T, typename... Args> |
1823 | using has_parse_properties = decltype(parseProperties( |
1824 | std::declval<OpAsmParser &>(), std::declval<T &>())); |
1825 | template <typename T> |
1826 | using detect_has_parse_properties = |
1827 | llvm::is_detected<has_parse_properties, T>; |
1828 | |
1829 | /// Trait to check if T provides a 'ConcreteEntity' type alias. |
1830 | template <typename T> |
1831 | using has_concrete_entity_t = typename T::ConcreteEntity; |
1832 | |
1833 | public: |
1834 | /// Returns true if this operation defines a `Properties` inner type. |
1835 | static constexpr bool hasProperties() { |
1836 | return !std::is_same_v< |
1837 | typename ConcreteType::template InferredProperties<ConcreteType>, |
1838 | EmptyProperties>; |
1839 | } |
1840 | |
1841 | private: |
1842 | /// A struct-wrapped type alias to T::ConcreteEntity if provided and to |
1843 | /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail |
1844 | /// on the missing typedef. Useful for checking if the interface is targeting |
1845 | /// the right class. |
1846 | template <typename T, |
1847 | bool = llvm::is_detected<has_concrete_entity_t, T>::value> |
1848 | struct InterfaceTargetOrOpT { |
1849 | using type = typename T::ConcreteEntity; |
1850 | }; |
1851 | template <typename T> |
1852 | struct InterfaceTargetOrOpT<T, false> { |
1853 | using type = ConcreteType; |
1854 | }; |
1855 | |
1856 | /// A hook for static assertion that the external interface model T is |
1857 | /// targeting the concrete type of this op. The model can also be a fallback |
1858 | /// model that works for every op. |
1859 | template <typename T> |
1860 | static void checkInterfaceTarget() { |
1861 | static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type, |
1862 | ConcreteType>::value, |
1863 | "attaching an interface to the wrong op kind" ); |
1864 | } |
1865 | |
1866 | /// Returns an interface map containing the interfaces registered to this |
1867 | /// operation. |
1868 | static detail::InterfaceMap getInterfaceMap() { |
1869 | return detail::InterfaceMap::template get<Traits<ConcreteType>...>(); |
1870 | } |
1871 | |
1872 | /// Return the internal implementations of each of the OperationName |
1873 | /// hooks. |
1874 | /// Implementation of `FoldHookFn` OperationName hook. |
1875 | static OperationName::FoldHookFn getFoldHookFn() { |
1876 | // If the operation is single result and defines a `fold` method. |
1877 | if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>, |
1878 | Traits<ConcreteType>...>::value && |
1879 | (has_single_result_fold_v<ConcreteType> || |
1880 | has_fold_adaptor_single_result_v<ConcreteType>)) |
1881 | return [](Operation *op, ArrayRef<Attribute> operands, |
1882 | SmallVectorImpl<OpFoldResult> &results) { |
1883 | return foldSingleResultHook<ConcreteType>(op, operands, results); |
1884 | }; |
1885 | // The operation is not single result and defines a `fold` method. |
1886 | if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>) |
1887 | return [](Operation *op, ArrayRef<Attribute> operands, |
1888 | SmallVectorImpl<OpFoldResult> &results) { |
1889 | return foldHook<ConcreteType>(op, operands, results); |
1890 | }; |
1891 | // The operation does not define a `fold` method. |
1892 | return [](Operation *op, ArrayRef<Attribute> operands, |
1893 | SmallVectorImpl<OpFoldResult> &results) { |
1894 | // In this case, we only need to fold the traits of the operation. |
1895 | return op_definition_impl::foldTraits<Traits<ConcreteType>...>( |
1896 | op, operands, results); |
1897 | }; |
1898 | } |
1899 | /// Return the result of folding a single result operation that defines a |
1900 | /// `fold` method. |
1901 | template <typename ConcreteOpT> |
1902 | static LogicalResult |
1903 | foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands, |
1904 | SmallVectorImpl<OpFoldResult> &results) { |
1905 | OpFoldResult result; |
1906 | if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) { |
1907 | result = cast<ConcreteOpT>(op).fold( |
1908 | typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op))); |
1909 | } else { |
1910 | result = cast<ConcreteOpT>(op).fold(operands); |
1911 | } |
1912 | |
1913 | // If the fold failed or was in-place, try to fold the traits of the |
1914 | // operation. |
1915 | if (!result || |
1916 | llvm::dyn_cast_if_present<Value>(Val&: result) == op->getResult(idx: 0)) { |
1917 | if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>( |
1918 | op, operands, results))) |
1919 | return success(); |
1920 | return success(isSuccess: static_cast<bool>(result)); |
1921 | } |
1922 | results.push_back(Elt: result); |
1923 | return success(); |
1924 | } |
1925 | /// Return the result of folding an operation that defines a `fold` method. |
1926 | template <typename ConcreteOpT> |
1927 | static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands, |
1928 | SmallVectorImpl<OpFoldResult> &results) { |
1929 | auto result = LogicalResult::failure(); |
1930 | if constexpr (has_fold_adaptor_v<ConcreteOpT>) { |
1931 | result = cast<ConcreteOpT>(op).fold( |
1932 | typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)), |
1933 | results); |
1934 | } else { |
1935 | result = cast<ConcreteOpT>(op).fold(operands, results); |
1936 | } |
1937 | |
1938 | // If the fold failed or was in-place, try to fold the traits of the |
1939 | // operation. |
1940 | if (failed(result) || results.empty()) { |
1941 | if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>( |
1942 | op, operands, results))) |
1943 | return success(); |
1944 | } |
1945 | return result; |
1946 | } |
1947 | |
1948 | /// Implementation of `GetHasTraitFn` |
1949 | static OperationName::HasTraitFn getHasTraitFn() { |
1950 | return |
1951 | [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); }; |
1952 | } |
1953 | /// Implementation of `PrintAssemblyFn` OperationName hook. |
1954 | static OperationName::PrintAssemblyFn getPrintAssemblyFn() { |
1955 | if constexpr (detect_has_print<ConcreteType>::value) |
1956 | return [](Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { |
1957 | OpState::printOpName(op, p, defaultDialect); |
1958 | return cast<ConcreteType>(op).print(p); |
1959 | }; |
1960 | return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) { |
1961 | return OpState::print(op, p&: printer, defaultDialect); |
1962 | }; |
1963 | } |
1964 | |
1965 | public: |
1966 | template <typename T> |
1967 | using InferredProperties = typename PropertiesSelector<T>::type; |
1968 | template <typename T = ConcreteType> |
1969 | InferredProperties<T> &getProperties() { |
1970 | if constexpr (!hasProperties()) |
1971 | return getEmptyProperties(); |
1972 | return *getOperation() |
1973 | ->getPropertiesStorageUnsafe() |
1974 | .template as<InferredProperties<T> *>(); |
1975 | } |
1976 | |
1977 | /// This hook populates any unset default attrs when mapped to properties. |
1978 | template <typename T = ConcreteType> |
1979 | static void populateDefaultProperties(OperationName opName, |
1980 | InferredProperties<T> &properties) {} |
1981 | |
1982 | /// Print the operation properties with names not included within |
1983 | /// 'elidedProps'. Unless overridden, this method will try to dispatch to a |
1984 | /// `printProperties` free-function if it exists, and otherwise by converting |
1985 | /// the properties to an Attribute. |
1986 | template <typename T> |
1987 | static void printProperties(MLIRContext *ctx, OpAsmPrinter &p, |
1988 | const T &properties, |
1989 | ArrayRef<StringRef> elidedProps = {}) { |
1990 | if constexpr (detect_has_print_properties<T>::value) |
1991 | return printProperties(p, properties, elidedProps); |
1992 | genericPrintProperties( |
1993 | p, properties: ConcreteType::getPropertiesAsAttr(ctx, properties), elidedProps); |
1994 | } |
1995 | |
1996 | /// Parses 'prop-dict' for the operation. Unless overridden, the method will |
1997 | /// parse the properties using the generic property dictionary using the |
1998 | /// '<{ ... }>' syntax. The resulting properties are stored within the |
1999 | /// property structure of 'result', accessible via 'getOrAddProperties'. |
2000 | template <typename T = ConcreteType> |
2001 | static ParseResult parseProperties(OpAsmParser &parser, |
2002 | OperationState &result) { |
2003 | if constexpr (detect_has_parse_properties<InferredProperties<T>>::value) { |
2004 | return parseProperties( |
2005 | parser, result.getOrAddProperties<InferredProperties<T>>()); |
2006 | } |
2007 | |
2008 | Attribute propertyDictionary; |
2009 | if (genericParseProperties(parser, result&: propertyDictionary)) |
2010 | return failure(); |
2011 | |
2012 | // The generated 'setPropertiesFromParsedAttr', like |
2013 | // 'setPropertiesFromAttr', expects a 'DictionaryAttr' that is not null. |
2014 | // Use an empty dictionary in the case that the whole dictionary is |
2015 | // optional. |
2016 | if (!propertyDictionary) |
2017 | propertyDictionary = DictionaryAttr::get(result.getContext()); |
2018 | |
2019 | auto emitError = [&]() { |
2020 | return mlir::emitError(loc: result.location, message: "invalid properties " ) |
2021 | << propertyDictionary << " for op " << result.name.getStringRef() |
2022 | << ": " ; |
2023 | }; |
2024 | |
2025 | // Copy the data from the dictionary attribute into the property struct of |
2026 | // the operation. This method is generated by ODS by default if there are |
2027 | // any occurrences of 'prop-dict' in the assembly format and should set |
2028 | // any properties that aren't parsed elsewhere. |
2029 | return ConcreteOpType::setPropertiesFromParsedAttr( |
2030 | result.getOrAddProperties<InferredProperties<T>>(), propertyDictionary, |
2031 | emitError); |
2032 | } |
2033 | |
2034 | private: |
2035 | /// Implementation of `PopulateDefaultAttrsFn` OperationName hook. |
2036 | static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() { |
2037 | return ConcreteType::populateDefaultAttrs; |
2038 | } |
2039 | /// Implementation of `VerifyInvariantsFn` OperationName hook. |
2040 | static LogicalResult verifyInvariants(Operation *op) { |
2041 | static_assert(hasNoDataMembers(), |
2042 | "Op class shouldn't define new data members" ); |
2043 | return failure( |
2044 | failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) || |
2045 | failed(cast<ConcreteType>(op).verify())); |
2046 | } |
2047 | static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() { |
2048 | return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants); |
2049 | } |
2050 | /// Implementation of `VerifyRegionInvariantsFn` OperationName hook. |
2051 | static LogicalResult verifyRegionInvariants(Operation *op) { |
2052 | static_assert(hasNoDataMembers(), |
2053 | "Op class shouldn't define new data members" ); |
2054 | return failure( |
2055 | failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>( |
2056 | op)) || |
2057 | failed(cast<ConcreteType>(op).verifyRegions())); |
2058 | } |
2059 | static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() { |
2060 | return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants); |
2061 | } |
2062 | |
2063 | static constexpr bool hasNoDataMembers() { |
2064 | // Checking that the derived class does not define any member by comparing |
2065 | // its size to an ad-hoc EmptyOp. |
2066 | class EmptyOp : public Op<EmptyOp, Traits...> {}; |
2067 | return sizeof(ConcreteType) == sizeof(EmptyOp); |
2068 | } |
2069 | |
2070 | /// Allow access to internal implementation methods. |
2071 | friend RegisteredOperationName; |
2072 | }; |
2073 | |
2074 | /// This class represents the base of an operation interface. See the definition |
2075 | /// of `detail::Interface` for requirements on the `Traits` type. |
2076 | template <typename ConcreteType, typename Traits> |
2077 | class OpInterface |
2078 | : public detail::Interface<ConcreteType, Operation *, Traits, |
2079 | Op<ConcreteType>, OpTrait::TraitBase> { |
2080 | public: |
2081 | using Base = OpInterface<ConcreteType, Traits>; |
2082 | using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits, |
2083 | Op<ConcreteType>, OpTrait::TraitBase>; |
2084 | |
2085 | /// Inherit the base class constructor. |
2086 | using InterfaceBase::InterfaceBase; |
2087 | |
2088 | protected: |
2089 | /// Returns the impl interface instance for the given operation. |
2090 | static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { |
2091 | OperationName name = op->getName(); |
2092 | |
2093 | #ifndef NDEBUG |
2094 | // Check that the current interface isn't an unresolved promise for the |
2095 | // given operation. |
2096 | if (Dialect *dialect = name.getDialect()) { |
2097 | dialect_extension_detail::handleUseOfUndefinedPromisedInterface( |
2098 | dialect&: *dialect, interfaceRequestorID: name.getTypeID(), interfaceID: ConcreteType::getInterfaceID(), |
2099 | interfaceName: llvm::getTypeName<ConcreteType>()); |
2100 | } |
2101 | #endif |
2102 | |
2103 | // Access the raw interface from the operation info. |
2104 | if (std::optional<RegisteredOperationName> rInfo = |
2105 | name.getRegisteredInfo()) { |
2106 | if (auto *opIface = rInfo->getInterface<ConcreteType>()) |
2107 | return opIface; |
2108 | // Fallback to the dialect to provide it with a chance to implement this |
2109 | // interface for this operation. |
2110 | return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>( |
2111 | op->getName()); |
2112 | } |
2113 | // Fallback to the dialect to provide it with a chance to implement this |
2114 | // interface for this operation. |
2115 | if (Dialect *dialect = name.getDialect()) |
2116 | return dialect->getRegisteredInterfaceForOp<ConcreteType>(name); |
2117 | return nullptr; |
2118 | } |
2119 | |
2120 | /// Allow access to `getInterfaceFor`. |
2121 | friend InterfaceBase; |
2122 | }; |
2123 | |
2124 | } // namespace mlir |
2125 | |
2126 | namespace llvm { |
2127 | |
2128 | template <typename T> |
2129 | struct DenseMapInfo<T, |
2130 | std::enable_if_t<std::is_base_of<mlir::OpState, T>::value && |
2131 | !mlir::detail::IsInterface<T>::value>> { |
2132 | static inline T getEmptyKey() { |
2133 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
2134 | return T::getFromOpaquePointer(pointer); |
2135 | } |
2136 | static inline T getTombstoneKey() { |
2137 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
2138 | return T::getFromOpaquePointer(pointer); |
2139 | } |
2140 | static unsigned getHashValue(T val) { |
2141 | return hash_value(val.getAsOpaquePointer()); |
2142 | } |
2143 | static bool isEqual(T lhs, T rhs) { return lhs == rhs; } |
2144 | }; |
2145 | } // namespace llvm |
2146 | |
2147 | #endif |
2148 | |