1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements helper classes for implementing the "Op" types. This
10// includes the Op type, which is the base class for Op class definitions,
11// as well as number of traits in the OpTrait namespace that provide a
12// declarative way to specify properties of Ops.
13//
14// The purpose of these types are to allow light-weight implementation of
15// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/ODSSupport.h"
24#include "mlir/IR/Operation.h"
25#include "llvm/Support/PointerLikeTypeTraits.h"
26
27#include <optional>
28#include <type_traits>
29
30namespace mlir {
31class Builder;
32class OpBuilder;
33
34/// This class implements `Optional` functionality for ParseResult. We don't
35/// directly use Optional here, because it provides an implicit conversion
36/// to 'bool' which we want to avoid. This class is used to implement tri-state
37/// 'parseOptional' functions that may have a failure mode when parsing that
38/// shouldn't be attributed to "not present".
39class OptionalParseResult {
40public:
41 OptionalParseResult() = default;
42 OptionalParseResult(LogicalResult result) : impl(result) {}
43 OptionalParseResult(ParseResult result) : impl(result) {}
44 OptionalParseResult(const InFlightDiagnostic &)
45 : OptionalParseResult(failure()) {}
46 OptionalParseResult(std::nullopt_t) : impl(std::nullopt) {}
47
48 /// Returns true if we contain a valid ParseResult value.
49 bool has_value() const { return impl.has_value(); }
50
51 /// Access the internal ParseResult value.
52 ParseResult value() const { return *impl; }
53 ParseResult operator*() const { return value(); }
54
55private:
56 std::optional<ParseResult> impl;
57};
58
// These functions are out-of-line utilities, which avoids them being template
// instantiated/duplicated.
namespace impl {
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
/// region's only block if it does not have a terminator already. If the region
/// is empty, insert a new block first. `buildTerminatorOp` should return the
/// terminator operation to insert.
void ensureRegionTerminator(
    Region &region, OpBuilder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
/// Overload of the above that takes a plain `Builder`. The callback still
/// receives an `OpBuilder` (see its signature). NOTE(review): presumably an
/// OpBuilder is derived from `builder` internally; confirm in the out-of-line
/// definition.
void ensureRegionTerminator(
    Region &region, Builder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);

} // namespace impl
74
/// Marker type standing in for an Operation's properties when the op does not
/// declare a `Properties` type of its own.
struct EmptyProperties {};

/// Type-level selector: `type` is `OpT::Properties` when the op class declares
/// a nested `Properties` type, and `EmptyProperties` otherwise.
template <class OpT, class = void>
struct PropertiesSelector {
  using type = EmptyProperties;
};

/// Picked only when `typename OpT::Properties` is well-formed (detected via
/// std::void_t); substitution failure falls back to the primary template.
template <class OpT>
struct PropertiesSelector<OpT, std::void_t<typename OpT::Properties>> {
  using type = typename OpT::Properties;
};
89
90/// This is the concrete base class that holds the operation pointer and has
91/// non-generic methods that only depend on State (to avoid having them
92/// instantiated on template types that don't affect them.
93///
94/// This also has the fallback implementations of customization hooks for when
95/// they aren't customized.
96class OpState {
97public:
98 /// Ops are pointer-like, so we allow conversion to bool.
99 explicit operator bool() { return getOperation() != nullptr; }
100
101 /// This implicitly converts to Operation*.
102 operator Operation *() const { return state; }
103
104 /// Shortcut of `->` to access a member of Operation.
105 Operation *operator->() const { return state; }
106
107 /// Return the operation that this refers to.
108 Operation *getOperation() { return state; }
109
110 /// Return the context this operation belongs to.
111 MLIRContext *getContext() { return getOperation()->getContext(); }
112
113 /// Print the operation to the given stream.
114 void print(raw_ostream &os, OpPrintingFlags flags = std::nullopt) {
115 state->print(os, flags);
116 }
117 void print(raw_ostream &os, AsmState &asmState) {
118 state->print(os, state&: asmState);
119 }
120
121 /// Dump this operation.
122 void dump() { state->dump(); }
123
124 /// The source location the operation was defined or derived from.
125 Location getLoc() { return state->getLoc(); }
126
127 /// Return true if there are no users of any results of this operation.
128 bool use_empty() { return state->use_empty(); }
129
130 /// Remove this operation from its parent block and delete it.
131 void erase() { state->erase(); }
132
133 /// Emit an error with the op name prefixed, like "'dim' op " which is
134 /// convenient for verifiers.
135 InFlightDiagnostic emitOpError(const Twine &message = {});
136
137 /// Emit an error about fatal conditions with this operation, reporting up to
138 /// any diagnostic handlers that may be listening.
139 InFlightDiagnostic emitError(const Twine &message = {});
140
141 /// Emit a warning about this operation, reporting up to any diagnostic
142 /// handlers that may be listening.
143 InFlightDiagnostic emitWarning(const Twine &message = {});
144
145 /// Emit a remark about this operation, reporting up to any diagnostic
146 /// handlers that may be listening.
147 InFlightDiagnostic emitRemark(const Twine &message = {});
148
149 /// Walk the operation by calling the callback for each nested operation
150 /// (including this one), block or region, depending on the callback provided.
151 /// The order in which regions, blocks and operations the same nesting level
152 /// are visited (e.g., lexicographical or reverse lexicographical order) is
153 /// determined by 'Iterator'. The walk order for enclosing regions, blocks
154 /// and operations with respect to their nested ones is specified by 'Order'
155 /// (post-order by default). A callback on a block or operation is allowed to
156 /// erase that block or operation if either:
157 /// * the walk is in post-order, or
158 /// * the walk is in pre-order and the walk is skipped after the erasure.
159 /// See Operation::walk for more details.
160 template <WalkOrder Order = WalkOrder::PostOrder,
161 typename Iterator = ForwardIterator, typename FnT,
162 typename RetT = detail::walkResultType<FnT>>
163 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1,
164 RetT>
165 walk(FnT &&callback) {
166 return state->walk<Order, Iterator>(std::forward<FnT>(callback));
167 }
168
169 /// Generic walker with a stage aware callback. Walk the operation by calling
170 /// the callback for each nested operation (including this one) N+1 times,
171 /// where N is the number of regions attached to that operation.
172 ///
173 /// The callback method can take any of the following forms:
174 /// void(Operation *, const WalkStage &) : Walk all operation opaquely
175 /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
176 /// void(OpT, const WalkStage &) : Walk all operations of the given derived
177 /// type.
178 /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
179 /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
180 /// but allow for interruption/skipping.
181 /// * op.walk([](... op, const WalkStage &stage) {
182 /// // Skip the walk of this op based on some invariant.
183 /// if (some_invariant)
184 /// return WalkResult::skip();
185 /// // Interrupt, i.e cancel, the walk based on some invariant.
186 /// if (another_invariant)
187 /// return WalkResult::interrupt();
188 /// return WalkResult::advance();
189 /// });
190 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
191 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 2,
192 RetT>
193 walk(FnT &&callback) {
194 return state->walk(std::forward<FnT>(callback));
195 }
196
197 // These are default implementations of customization hooks.
198public:
199 /// This hook returns any canonicalization pattern rewrites that the operation
200 /// supports, for use by the canonicalization pass.
201 static void getCanonicalizationPatterns(RewritePatternSet &results,
202 MLIRContext *context) {}
203
204 /// This hook populates any unset default attrs.
205 static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {}
206
207protected:
208 /// If the concrete type didn't implement a custom verifier hook, just fall
209 /// back to this one which accepts everything.
210 LogicalResult verify() { return success(); }
211 LogicalResult verifyRegions() { return success(); }
212
213 /// Parse the custom form of an operation. Unless overridden, this method will
214 /// first try to get an operation parser from the op's dialect. Otherwise the
215 /// custom assembly form of an op is always rejected. Op implementations
216 /// should implement this to return failure. On success, they should fill in
217 /// result with the fields to use.
218 static ParseResult parse(OpAsmParser &parser, OperationState &result);
219
220 /// Print the operation. Unless overridden, this method will first try to get
221 /// an operation printer from the dialect. Otherwise, it prints the operation
222 /// in generic form.
223 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
224
225 /// Parse properties as a Attribute.
226 static ParseResult genericParseProperties(OpAsmParser &parser,
227 Attribute &result);
228
229 /// Print the properties as a Attribute with names not included within
230 /// 'elidedProps'
231 static void genericPrintProperties(OpAsmPrinter &p, Attribute properties,
232 ArrayRef<StringRef> elidedProps = {});
233
234 /// Print an operation name, eliding the dialect prefix if necessary.
235 static void printOpName(Operation *op, OpAsmPrinter &p,
236 StringRef defaultDialect);
237
238 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
239 /// so we can cast it away here.
240 explicit OpState(Operation *state) : state(state) {}
241
242 /// For all op which don't have properties, we keep a single instance of
243 /// `EmptyProperties` to be used where a reference to a properties is needed:
244 /// this allow to bind a pointer to the reference without triggering UB.
245 static EmptyProperties &getEmptyProperties() {
246 static EmptyProperties emptyProperties;
247 return emptyProperties;
248 }
249
250private:
251 Operation *state;
252
253 /// Allow access to internal hook implementation methods.
254 friend RegisteredOperationName;
255};
256
257// Allow comparing operators.
258inline bool operator==(OpState lhs, OpState rhs) {
259 return lhs.getOperation() == rhs.getOperation();
260}
261inline bool operator!=(OpState lhs, OpState rhs) {
262 return lhs.getOperation() != rhs.getOperation();
263}
264
265raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
266
267/// This class represents a single result from folding an operation.
268class OpFoldResult : public PointerUnion<Attribute, Value> {
269 using PointerUnion<Attribute, Value>::PointerUnion;
270
271public:
272 void dump() const { llvm::errs() << *this << "\n"; }
273
274 MLIRContext *getContext() const {
275 return is<Attribute>() ? get<Attribute>().getContext()
276 : get<Value>().getContext();
277 }
278};
279
// Temporarily exit the MLIR namespace to add casting support, as later code in
// this file uses it. The CastInfo specializations must come after the
// OpFoldResult definition and before any cast function calls that depend on
// them.
283
284} // namespace mlir
285
286namespace llvm {
287
288// Allow llvm::cast style functions.
289template <typename To>
290struct CastInfo<To, mlir::OpFoldResult>
291 : public CastInfo<To, mlir::OpFoldResult::PointerUnion> {};
292
293template <typename To>
294struct CastInfo<To, const mlir::OpFoldResult>
295 : public CastInfo<To, const mlir::OpFoldResult::PointerUnion> {};
296
297} // namespace llvm
298
299namespace mlir {
300
301/// Allow printing to a stream.
302inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
303 if (Value value = llvm::dyn_cast_if_present<Value>(Val&: ofr))
304 value.print(os);
305 else
306 llvm::dyn_cast_if_present<Attribute>(Val&: ofr).print(os);
307 return os;
308}
309/// Allow printing to a stream.
310inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
311 op.print(os, flags: OpPrintingFlags().useLocalScope());
312 return os;
313}
314
315//===----------------------------------------------------------------------===//
316// Operation Trait Types
317//===----------------------------------------------------------------------===//
318
319namespace OpTrait {
320
321// These functions are out-of-line implementations of the methods in the
322// corresponding trait classes. This avoids them being template
323// instantiated/duplicated.
324namespace impl {
325LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
326 SmallVectorImpl<OpFoldResult> &results);
327OpFoldResult foldIdempotent(Operation *op);
328OpFoldResult foldInvolution(Operation *op);
329LogicalResult verifyZeroOperands(Operation *op);
330LogicalResult verifyOneOperand(Operation *op);
331LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
332LogicalResult verifyIsIdempotent(Operation *op);
333LogicalResult verifyIsInvolution(Operation *op);
334LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
335LogicalResult verifyOperandsAreFloatLike(Operation *op);
336LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
337LogicalResult verifySameTypeOperands(Operation *op);
338LogicalResult verifyZeroRegions(Operation *op);
339LogicalResult verifyOneRegion(Operation *op);
340LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
341LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
342LogicalResult verifyZeroResults(Operation *op);
343LogicalResult verifyOneResult(Operation *op);
344LogicalResult verifyNResults(Operation *op, unsigned numOperands);
345LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
346LogicalResult verifySameOperandsShape(Operation *op);
347LogicalResult verifySameOperandsAndResultShape(Operation *op);
348LogicalResult verifySameOperandsElementType(Operation *op);
349LogicalResult verifySameOperandsAndResultElementType(Operation *op);
350LogicalResult verifySameOperandsAndResultType(Operation *op);
351LogicalResult verifySameOperandsAndResultRank(Operation *op);
352LogicalResult verifyResultsAreBoolLike(Operation *op);
353LogicalResult verifyResultsAreFloatLike(Operation *op);
354LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
355LogicalResult verifyIsTerminator(Operation *op);
356LogicalResult verifyZeroSuccessors(Operation *op);
357LogicalResult verifyOneSuccessor(Operation *op);
358LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
359LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
360LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
361 StringRef valueGroupName,
362 size_t expectedCount);
363LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
364LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
365LogicalResult verifyNoRegionArguments(Operation *op);
366LogicalResult verifyElementwise(Operation *op);
367LogicalResult verifyIsIsolatedFromAbove(Operation *op);
368} // namespace impl
369
370/// Helper class for implementing traits. Clients are not expected to interact
371/// with this directly, so its members are all protected.
372template <typename ConcreteType, template <typename> class TraitType>
373class TraitBase {
374protected:
375 /// Return the ultimate Operation being worked on.
376 Operation *getOperation() {
377 auto *concrete = static_cast<ConcreteType *>(this);
378 return concrete->getOperation();
379 }
380};
381
382//===----------------------------------------------------------------------===//
383// Operand Traits
384
385namespace detail {
386/// Utility trait base that provides accessors for derived traits that have
387/// multiple operands.
388template <typename ConcreteType, template <typename> class TraitType>
389struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
390 using operand_iterator = Operation::operand_iterator;
391 using operand_range = Operation::operand_range;
392 using operand_type_iterator = Operation::operand_type_iterator;
393 using operand_type_range = Operation::operand_type_range;
394
395 /// Return the number of operands.
396 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
397
398 /// Return the operand at index 'i'.
399 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
400
401 /// Set the operand at index 'i' to 'value'.
402 void setOperand(unsigned i, Value value) {
403 this->getOperation()->setOperand(i, value);
404 }
405
406 /// Operand iterator access.
407 operand_iterator operand_begin() {
408 return this->getOperation()->operand_begin();
409 }
410 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
411 operand_range getOperands() { return this->getOperation()->getOperands(); }
412
413 /// Operand type access.
414 operand_type_iterator operand_type_begin() {
415 return this->getOperation()->operand_type_begin();
416 }
417 operand_type_iterator operand_type_end() {
418 return this->getOperation()->operand_type_end();
419 }
420 operand_type_range getOperandTypes() {
421 return this->getOperation()->getOperandTypes();
422 }
423};
424} // namespace detail
425
426/// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc.
427/// It should be run after core traits and before any other user defined traits.
428/// In order to run it in the correct order, wrap it with OpInvariants trait so
429/// that tblgen will be able to put it in the right order.
430template <typename ConcreteType>
431class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
432public:
433 static LogicalResult verifyTrait(Operation *op) {
434 return cast<ConcreteType>(op).verifyInvariantsImpl();
435 }
436};
437
438/// This class provides the API for ops that are known to have no
439/// SSA operand.
440template <typename ConcreteType>
441class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
442public:
443 static LogicalResult verifyTrait(Operation *op) {
444 return impl::verifyZeroOperands(op);
445 }
446
447private:
448 // Disable these.
449 void getOperand() {}
450 void setOperand() {}
451};
452
453/// This class provides the API for ops that are known to have exactly one
454/// SSA operand.
455template <typename ConcreteType>
456class OneOperand : public TraitBase<ConcreteType, OneOperand> {
457public:
458 Value getOperand() { return this->getOperation()->getOperand(0); }
459
460 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
461
462 static LogicalResult verifyTrait(Operation *op) {
463 return impl::verifyOneOperand(op);
464 }
465};
466
467/// This class provides the API for ops that are known to have a specified
468/// number of operands. This is used as a trait like this:
469///
470/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
471///
472template <unsigned N>
473class NOperands {
474public:
475 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
476
477 template <typename ConcreteType>
478 class Impl
479 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
480 public:
481 static LogicalResult verifyTrait(Operation *op) {
482 return impl::verifyNOperands(op, numOperands: N);
483 }
484 };
485};
486
487/// This class provides the API for ops that are known to have a at least a
488/// specified number of operands. This is used as a trait like this:
489///
490/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
491///
492template <unsigned N>
493class AtLeastNOperands {
494public:
495 template <typename ConcreteType>
496 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
497 AtLeastNOperands<N>::Impl> {
498 public:
499 static LogicalResult verifyTrait(Operation *op) {
500 return impl::verifyAtLeastNOperands(op, numOperands: N);
501 }
502 };
503};
504
505/// This class provides the API for ops which have an unknown number of
506/// SSA operands.
507template <typename ConcreteType>
508class VariadicOperands
509 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
510
511//===----------------------------------------------------------------------===//
512// Region Traits
513
514/// This class provides verification for ops that are known to have zero
515/// regions.
516template <typename ConcreteType>
517class ZeroRegions : public TraitBase<ConcreteType, ZeroRegions> {
518public:
519 static LogicalResult verifyTrait(Operation *op) {
520 return impl::verifyZeroRegions(op);
521 }
522};
523
524namespace detail {
525/// Utility trait base that provides accessors for derived traits that have
526/// multiple regions.
527template <typename ConcreteType, template <typename> class TraitType>
528struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
529 using region_iterator = MutableArrayRef<Region>;
530 using region_range = RegionRange;
531
532 /// Return the number of regions.
533 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
534
535 /// Return the region at `index`.
536 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
537
538 /// Region iterator access.
539 region_iterator region_begin() {
540 return this->getOperation()->region_begin();
541 }
542 region_iterator region_end() { return this->getOperation()->region_end(); }
543 region_range getRegions() { return this->getOperation()->getRegions(); }
544};
545} // namespace detail
546
547/// This class provides APIs for ops that are known to have a single region.
548template <typename ConcreteType>
549class OneRegion : public TraitBase<ConcreteType, OneRegion> {
550public:
551 Region &getRegion() { return this->getOperation()->getRegion(0); }
552
553 /// Returns a range of operations within the region of this operation.
554 auto getOps() { return getRegion().getOps(); }
555 template <typename OpT>
556 auto getOps() {
557 return getRegion().template getOps<OpT>();
558 }
559
560 static LogicalResult verifyTrait(Operation *op) {
561 return impl::verifyOneRegion(op);
562 }
563};
564
565/// This class provides the API for ops that are known to have a specified
566/// number of regions.
567template <unsigned N>
568class NRegions {
569public:
570 static_assert(N > 1, "use ZeroRegions/OneRegion for N < 2");
571
572 template <typename ConcreteType>
573 class Impl
574 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
575 public:
576 static LogicalResult verifyTrait(Operation *op) {
577 return impl::verifyNRegions(op, numRegions: N);
578 }
579 };
580};
581
582/// This class provides APIs for ops that are known to have at least a specified
583/// number of regions.
584template <unsigned N>
585class AtLeastNRegions {
586public:
587 template <typename ConcreteType>
588 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
589 AtLeastNRegions<N>::Impl> {
590 public:
591 static LogicalResult verifyTrait(Operation *op) {
592 return impl::verifyAtLeastNRegions(op, numRegions: N);
593 }
594 };
595};
596
597/// This class provides the API for ops which have an unknown number of
598/// regions.
599template <typename ConcreteType>
600class VariadicRegions
601 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
602
603//===----------------------------------------------------------------------===//
604// Result Traits
605
606/// This class provides return value APIs for ops that are known to have
607/// zero results.
608template <typename ConcreteType>
609class ZeroResults : public TraitBase<ConcreteType, ZeroResults> {
610public:
611 static LogicalResult verifyTrait(Operation *op) {
612 return impl::verifyZeroResults(op);
613 }
614};
615
616namespace detail {
617/// Utility trait base that provides accessors for derived traits that have
618/// multiple results.
619template <typename ConcreteType, template <typename> class TraitType>
620struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
621 using result_iterator = Operation::result_iterator;
622 using result_range = Operation::result_range;
623 using result_type_iterator = Operation::result_type_iterator;
624 using result_type_range = Operation::result_type_range;
625
626 /// Return the number of results.
627 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
628
629 /// Return the result at index 'i'.
630 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
631
632 /// Replace all uses of results of this operation with the provided 'values'.
633 /// 'values' may correspond to an existing operation, or a range of 'Value'.
634 template <typename ValuesT>
635 void replaceAllUsesWith(ValuesT &&values) {
636 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
637 }
638
639 /// Return the type of the `i`-th result.
640 Type getType(unsigned i) { return getResult(i).getType(); }
641
642 /// Result iterator access.
643 result_iterator result_begin() {
644 return this->getOperation()->result_begin();
645 }
646 result_iterator result_end() { return this->getOperation()->result_end(); }
647 result_range getResults() { return this->getOperation()->getResults(); }
648
649 /// Result type access.
650 result_type_iterator result_type_begin() {
651 return this->getOperation()->result_type_begin();
652 }
653 result_type_iterator result_type_end() {
654 return this->getOperation()->result_type_end();
655 }
656 result_type_range getResultTypes() {
657 return this->getOperation()->getResultTypes();
658 }
659};
660} // namespace detail
661
662/// This class provides return value APIs for ops that are known to have a
663/// single result. ResultType is the concrete type returned by getType().
664template <typename ConcreteType>
665class OneResult : public TraitBase<ConcreteType, OneResult> {
666public:
667 /// Replace all uses of 'this' value with the new value, updating anything
668 /// in the IR that uses 'this' to use the other value instead. When this
669 /// returns there are zero uses of 'this'.
670 void replaceAllUsesWith(Value newValue) {
671 this->getOperation()->getResult(0).replaceAllUsesWith(newValue);
672 }
673
674 /// Replace all uses of 'this' value with the result of 'op'.
675 void replaceAllUsesWith(Operation *op) {
676 this->getOperation()->replaceAllUsesWith(op);
677 }
678
679 static LogicalResult verifyTrait(Operation *op) {
680 return impl::verifyOneResult(op);
681 }
682};
683
684/// This trait is used for return value APIs for ops that are known to have a
685/// specific type other than `Type`. This allows the "getType()" member to be
686/// more specific for an op. This should be used in conjunction with OneResult,
687/// and occur in the trait list before OneResult.
688template <typename ResultType>
689class OneTypedResult {
690public:
691 /// This class provides return value APIs for ops that are known to have a
692 /// single result. ResultType is the concrete type returned by getType().
693 template <typename ConcreteType>
694 class Impl
695 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
696 public:
697 mlir::TypedValue<ResultType> getResult() {
698 return cast<mlir::TypedValue<ResultType>>(
699 this->getOperation()->getResult(0));
700 }
701
702 /// If the operation returns a single value, then the Op can be implicitly
703 /// converted to a Value. This yields the value of the only result.
704 operator mlir::TypedValue<ResultType>() { return getResult(); }
705
706 ResultType getType() { return getResult().getType(); }
707 };
708};
709
710/// This class provides the API for ops that are known to have a specified
711/// number of results. This is used as a trait like this:
712///
713/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
714///
715template <unsigned N>
716class NResults {
717public:
718 static_assert(N > 1, "use ZeroResults/OneResult for N < 2");
719
720 template <typename ConcreteType>
721 class Impl
722 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
723 public:
724 static LogicalResult verifyTrait(Operation *op) {
725 return impl::verifyNResults(op, numOperands: N);
726 }
727 };
728};
729
730/// This class provides the API for ops that are known to have at least a
731/// specified number of results. This is used as a trait like this:
732///
733/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
734///
735template <unsigned N>
736class AtLeastNResults {
737public:
738 template <typename ConcreteType>
739 class Impl : public detail::MultiResultTraitBase<ConcreteType,
740 AtLeastNResults<N>::Impl> {
741 public:
742 static LogicalResult verifyTrait(Operation *op) {
743 return impl::verifyAtLeastNResults(op, numOperands: N);
744 }
745 };
746};
747
748/// This class provides the API for ops which have an unknown number of
749/// results.
750template <typename ConcreteType>
751class VariadicResults
752 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
753
754//===----------------------------------------------------------------------===//
755// Terminator Traits
756
757/// This class indicates that the regions associated with this op don't have
758/// terminators.
759template <typename ConcreteType>
760class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
761
762/// This class provides the API for ops that are known to be terminators.
763template <typename ConcreteType>
764class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
765public:
766 static LogicalResult verifyTrait(Operation *op) {
767 return impl::verifyIsTerminator(op);
768 }
769};
770
771/// This class provides verification for ops that are known to have zero
772/// successors.
773template <typename ConcreteType>
774class ZeroSuccessors : public TraitBase<ConcreteType, ZeroSuccessors> {
775public:
776 static LogicalResult verifyTrait(Operation *op) {
777 return impl::verifyZeroSuccessors(op);
778 }
779};
780
781namespace detail {
782/// Utility trait base that provides accessors for derived traits that have
783/// multiple successors.
784template <typename ConcreteType, template <typename> class TraitType>
785struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
786 using succ_iterator = Operation::succ_iterator;
787 using succ_range = SuccessorRange;
788
789 /// Return the number of successors.
790 unsigned getNumSuccessors() {
791 return this->getOperation()->getNumSuccessors();
792 }
793
794 /// Return the successor at `index`.
795 Block *getSuccessor(unsigned i) {
796 return this->getOperation()->getSuccessor(i);
797 }
798
799 /// Set the successor at `index`.
800 void setSuccessor(Block *block, unsigned i) {
801 return this->getOperation()->setSuccessor(block, i);
802 }
803
804 /// Successor iterator access.
805 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
806 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
807 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
808};
809} // namespace detail
810
811/// This class provides APIs for ops that are known to have a single successor.
812template <typename ConcreteType>
813class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
814public:
815 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
816 void setSuccessor(Block *succ) {
817 this->getOperation()->setSuccessor(succ, 0);
818 }
819
820 static LogicalResult verifyTrait(Operation *op) {
821 return impl::verifyOneSuccessor(op);
822 }
823};
824
825/// This class provides the API for ops that are known to have a specified
826/// number of successors.
827template <unsigned N>
828class NSuccessors {
829public:
830 static_assert(N > 1, "use ZeroSuccessors/OneSuccessor for N < 2");
831
832 template <typename ConcreteType>
833 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
834 NSuccessors<N>::Impl> {
835 public:
836 static LogicalResult verifyTrait(Operation *op) {
837 return impl::verifyNSuccessors(op, numSuccessors: N);
838 }
839 };
840};
841
842/// This class provides APIs for ops that are known to have at least a specified
843/// number of successors.
844template <unsigned N>
845class AtLeastNSuccessors {
846public:
847 template <typename ConcreteType>
848 class Impl
849 : public detail::MultiSuccessorTraitBase<ConcreteType,
850 AtLeastNSuccessors<N>::Impl> {
851 public:
852 static LogicalResult verifyTrait(Operation *op) {
853 return impl::verifyAtLeastNSuccessors(op, numSuccessors: N);
854 }
855 };
856};
857
858/// This class provides the API for ops which have an unknown number of
859/// successors.
860template <typename ConcreteType>
861class VariadicSuccessors
862 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
863};
864
865//===----------------------------------------------------------------------===//
866// SingleBlock
867
868/// This class provides APIs and verifiers for ops with regions having a single
869/// block.
870template <typename ConcreteType>
871struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
872public:
873 static LogicalResult verifyTrait(Operation *op) {
874 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
875 Region &region = op->getRegion(index: i);
876
877 // Empty regions are fine.
878 if (region.empty())
879 continue;
880
881 // Non-empty regions must contain a single basic block.
882 if (!llvm::hasSingleElement(C&: region))
883 return op->emitOpError(message: "expects region #")
884 << i << " to have 0 or 1 blocks";
885
886 if (!ConcreteType::template hasTrait<NoTerminator>()) {
887 Block &block = region.front();
888 if (block.empty())
889 return op->emitOpError() << "expects a non-empty block";
890 }
891 }
892 return success();
893 }
894
895 Block *getBody(unsigned idx = 0) {
896 Region &region = this->getOperation()->getRegion(idx);
897 assert(!region.empty() && "unexpected empty region");
898 return &region.front();
899 }
900 Region &getBodyRegion(unsigned idx = 0) {
901 return this->getOperation()->getRegion(idx);
902 }
903
904 //===------------------------------------------------------------------===//
905 // Single Region Utilities
906 //===------------------------------------------------------------------===//
907
908 /// The following are a set of methods only enabled when the parent
909 /// operation has a single region. Each of these methods take an additional
910 /// template parameter that represents the concrete operation so that we
911 /// can use SFINAE to disable the methods for non-single region operations.
912 template <typename OpT, typename T = void>
913 using enable_if_single_region =
914 std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
915
916 template <typename OpT = ConcreteType>
917 enable_if_single_region<OpT, Block::iterator> begin() {
918 return getBody()->begin();
919 }
920 template <typename OpT = ConcreteType>
921 enable_if_single_region<OpT, Block::iterator> end() {
922 return getBody()->end();
923 }
924 template <typename OpT = ConcreteType>
925 enable_if_single_region<OpT, Operation &> front() {
926 return *begin();
927 }
928
929 /// Insert the operation into the back of the body.
930 template <typename OpT = ConcreteType>
931 enable_if_single_region<OpT> push_back(Operation *op) {
932 insert(Block::iterator(getBody()->end()), op);
933 }
934
935 /// Insert the operation at the given insertion point.
936 template <typename OpT = ConcreteType>
937 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
938 insert(Block::iterator(insertPt), op);
939 }
940 template <typename OpT = ConcreteType>
941 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
942 getBody()->getOperations().insert(insertPt, op);
943 }
944};
945
946//===----------------------------------------------------------------------===//
947// SingleBlockImplicitTerminator
948
949/// This class provides APIs and verifiers for ops with regions having a single
950/// block that must terminate with `TerminatorOpType`.
951template <typename TerminatorOpType>
952struct SingleBlockImplicitTerminator {
953 template <typename ConcreteType>
954 class Impl : public TraitBase<ConcreteType, SingleBlockImplicitTerminator<
955 TerminatorOpType>::Impl> {
956 private:
957 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
958 /// cyclic header inclusion.
959 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
960 OperationState state(loc, TerminatorOpType::getOperationName());
961 TerminatorOpType::build(builder, state);
962 return Operation::create(state);
963 }
964
965 public:
966 /// The type of the operation used as the implicit terminator type.
967 using ImplicitTerminatorOpT = TerminatorOpType;
968
969 static LogicalResult verifyRegionTrait(Operation *op) {
970 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
971 Region &region = op->getRegion(index: i);
972 // Empty regions are fine.
973 if (region.empty())
974 continue;
975 Operation &terminator = region.front().back();
976 if (isa<TerminatorOpType>(terminator))
977 continue;
978
979 return op->emitOpError(message: "expects regions to end with '" +
980 TerminatorOpType::getOperationName() +
981 "', found '" +
982 terminator.getName().getStringRef() + "'")
983 .attachNote()
984 << "in custom textual format, the absence of terminator implies "
985 "'"
986 << TerminatorOpType::getOperationName() << '\'';
987 }
988
989 return success();
990 }
991
992 /// Ensure that the given region has the terminator required by this trait.
993 /// If OpBuilder is provided, use it to build the terminator and notify the
994 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
995 /// construct an OpBuilder with no listeners; this should only be used if no
996 /// OpBuilder is available at the call site, e.g., in the parser.
997 static void ensureTerminator(Region &region, Builder &builder,
998 Location loc) {
999 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
1000 buildTerminatorOp: buildTerminator);
1001 }
1002 static void ensureTerminator(Region &region, OpBuilder &builder,
1003 Location loc) {
1004 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
1005 buildTerminatorOp: buildTerminator);
1006 }
1007 };
1008};
1009
1010/// Check is an op defines the `ImplicitTerminatorOpT` member. This is intended
1011/// to be used with `llvm::is_detected`.
1012template <class T>
1013using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
1014
1015/// Support to check if an operation has the SingleBlockImplicitTerminator
1016/// trait. We can't just use `hasTrait` because this class is templated on a
1017/// specific terminator op.
1018template <class Op, bool hasTerminator =
1019 llvm::is_detected<has_implicit_terminator_t, Op>::value>
1020struct hasSingleBlockImplicitTerminator {
1021 static constexpr bool value = std::is_base_of<
1022 typename OpTrait::SingleBlockImplicitTerminator<
1023 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
1024 Op>::value;
1025};
1026template <class Op>
1027struct hasSingleBlockImplicitTerminator<Op, false> {
1028 static constexpr bool value = false;
1029};
1030
1031//===----------------------------------------------------------------------===//
1032// Misc Traits
1033
1034/// This class provides verification for ops that are known to have the same
1035/// operand shape: all operands are scalars, vectors/tensors of the same
1036/// shape.
1037template <typename ConcreteType>
1038class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
1039public:
1040 static LogicalResult verifyTrait(Operation *op) {
1041 return impl::verifySameOperandsShape(op);
1042 }
1043};
1044
1045/// This class provides verification for ops that are known to have the same
1046/// operand and result shape: both are scalars, vectors/tensors of the same
1047/// shape.
1048template <typename ConcreteType>
1049class SameOperandsAndResultShape
1050 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
1051public:
1052 static LogicalResult verifyTrait(Operation *op) {
1053 return impl::verifySameOperandsAndResultShape(op);
1054 }
1055};
1056
1057/// This class provides verification for ops that are known to have the same
1058/// operand element type (or the type itself if it is scalar).
1059///
1060template <typename ConcreteType>
1061class SameOperandsElementType
1062 : public TraitBase<ConcreteType, SameOperandsElementType> {
1063public:
1064 static LogicalResult verifyTrait(Operation *op) {
1065 return impl::verifySameOperandsElementType(op);
1066 }
1067};
1068
1069/// This class provides verification for ops that are known to have the same
1070/// operand and result element type (or the type itself if it is scalar).
1071///
1072template <typename ConcreteType>
1073class SameOperandsAndResultElementType
1074 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1075public:
1076 static LogicalResult verifyTrait(Operation *op) {
1077 return impl::verifySameOperandsAndResultElementType(op);
1078 }
1079};
1080
1081/// This class provides verification for ops that are known to have the same
1082/// operand and result type.
1083///
1084/// Note: this trait subsumes the SameOperandsAndResultShape and
1085/// SameOperandsAndResultElementType traits.
1086template <typename ConcreteType>
1087class SameOperandsAndResultType
1088 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1089public:
1090 static LogicalResult verifyTrait(Operation *op) {
1091 return impl::verifySameOperandsAndResultType(op);
1092 }
1093};
1094
1095/// This class verifies that op has same ranks for all
1096/// operands and results types, if known.
1097template <typename ConcreteType>
1098class SameOperandsAndResultRank
1099 : public TraitBase<ConcreteType, SameOperandsAndResultRank> {
1100public:
1101 static LogicalResult verifyTrait(Operation *op) {
1102 return impl::verifySameOperandsAndResultRank(op);
1103 }
1104};
1105
1106/// This class verifies that any results of the specified op have a boolean
1107/// type, a vector thereof, or a tensor thereof.
1108template <typename ConcreteType>
1109class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1110public:
1111 static LogicalResult verifyTrait(Operation *op) {
1112 return impl::verifyResultsAreBoolLike(op);
1113 }
1114};
1115
1116/// This class verifies that any results of the specified op have a floating
1117/// point type, a vector thereof, or a tensor thereof.
1118template <typename ConcreteType>
1119class ResultsAreFloatLike
1120 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1121public:
1122 static LogicalResult verifyTrait(Operation *op) {
1123 return impl::verifyResultsAreFloatLike(op);
1124 }
1125};
1126
1127/// This class verifies that any results of the specified op have a signless
1128/// integer or index type, a vector thereof, or a tensor thereof.
1129template <typename ConcreteType>
1130class ResultsAreSignlessIntegerLike
1131 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1132public:
1133 static LogicalResult verifyTrait(Operation *op) {
1134 return impl::verifyResultsAreSignlessIntegerLike(op);
1135 }
1136};
1137
1138/// This class adds property that the operation is commutative.
1139template <typename ConcreteType>
1140class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
1141public:
1142 static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
1143 SmallVectorImpl<OpFoldResult> &results) {
1144 return impl::foldCommutative(op, operands, results);
1145 }
1146};
1147
1148/// This class adds property that the operation is an involution.
1149/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1150template <typename ConcreteType>
1151class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1152public:
1153 static LogicalResult verifyTrait(Operation *op) {
1154 static_assert(ConcreteType::template hasTrait<OneResult>(),
1155 "expected operation to produce one result");
1156 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1157 "expected operation to take one operand");
1158 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1159 "expected operation to preserve type");
1160 // Involution requires the operation to be side effect free as well
1161 // but currently this check is under a FIXME and is not actually done.
1162 return impl::verifyIsInvolution(op);
1163 }
1164
1165 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1166 return impl::foldInvolution(op);
1167 }
1168};
1169
1170/// This class adds property that the operation is idempotent.
1171/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
1172/// or a binary operation "g" that satisfies g(x, x) = x.
1173template <typename ConcreteType>
1174class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1175public:
1176 static LogicalResult verifyTrait(Operation *op) {
1177 static_assert(ConcreteType::template hasTrait<OneResult>(),
1178 "expected operation to produce one result");
1179 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1180 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1181 "expected operation to take one or two operands");
1182 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1183 "expected operation to preserve type");
1184 // Idempotent requires the operation to be side effect free as well
1185 // but currently this check is under a FIXME and is not actually done.
1186 return impl::verifyIsIdempotent(op);
1187 }
1188
1189 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1190 return impl::foldIdempotent(op);
1191 }
1192};
1193
1194/// This class verifies that all operands of the specified op have a float type,
1195/// a vector thereof, or a tensor thereof.
1196template <typename ConcreteType>
1197class OperandsAreFloatLike
1198 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1199public:
1200 static LogicalResult verifyTrait(Operation *op) {
1201 return impl::verifyOperandsAreFloatLike(op);
1202 }
1203};
1204
1205/// This class verifies that all operands of the specified op have a signless
1206/// integer or index type, a vector thereof, or a tensor thereof.
1207template <typename ConcreteType>
1208class OperandsAreSignlessIntegerLike
1209 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1210public:
1211 static LogicalResult verifyTrait(Operation *op) {
1212 return impl::verifyOperandsAreSignlessIntegerLike(op);
1213 }
1214};
1215
1216/// This class verifies that all operands of the specified op have the same
1217/// type.
1218template <typename ConcreteType>
1219class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1220public:
1221 static LogicalResult verifyTrait(Operation *op) {
1222 return impl::verifySameTypeOperands(op);
1223 }
1224};
1225
1226/// This class provides the API for a sub-set of ops that are known to be
1227/// constant-like. These are non-side effecting operations with one result and
1228/// zero operands that can always be folded to a specific attribute value.
1229template <typename ConcreteType>
1230class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1231public:
1232 static LogicalResult verifyTrait(Operation *op) {
1233 static_assert(ConcreteType::template hasTrait<OneResult>(),
1234 "expected operation to produce one result");
1235 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1236 "expected operation to take zero operands");
1237 // TODO: We should verify that the operation can always be folded, but this
1238 // requires that the attributes of the op already be verified. We should add
1239 // support for verifying traits "after" the operation to enable this use
1240 // case.
1241 return success();
1242 }
1243};
1244
1245/// This class provides the API for ops that are known to be isolated from
1246/// above.
1247template <typename ConcreteType>
1248class IsIsolatedFromAbove
1249 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1250public:
1251 static LogicalResult verifyRegionTrait(Operation *op) {
1252 return impl::verifyIsIsolatedFromAbove(op);
1253 }
1254};
1255
1256/// A trait of region holding operations that defines a new scope for polyhedral
1257/// optimization purposes. Any SSA values of 'index' type that either dominate
1258/// such an operation or are used at the top-level of such an operation
1259/// automatically become valid symbols for the polyhedral scope defined by that
1260/// operation. For more details, see `Traits.md#AffineScope`.
1261template <typename ConcreteType>
1262class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1263public:
1264 static LogicalResult verifyTrait(Operation *op) {
1265 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
1266 "expected operation to have one or more regions");
1267 return success();
1268 }
1269};
1270
1271/// A trait of region holding operations that define a new scope for automatic
1272/// allocations, i.e., allocations that are freed when control is transferred
1273/// back from the operation's region. Any operations performing such allocations
1274/// (for eg. memref.alloca) will have their allocations automatically freed at
1275/// their closest enclosing operation with this trait.
1276template <typename ConcreteType>
1277class AutomaticAllocationScope
1278 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1279public:
1280 static LogicalResult verifyTrait(Operation *op) {
1281 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
1282 "expected operation to have one or more regions");
1283 return success();
1284 }
1285};
1286
1287/// This class provides a verifier for ops that are expecting their parent
1288/// to be one of the given parent ops
1289template <typename... ParentOpTypes>
1290struct HasParent {
1291 template <typename ConcreteType>
1292 class Impl : public TraitBase<ConcreteType, Impl> {
1293 public:
1294 static LogicalResult verifyTrait(Operation *op) {
1295 if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
1296 return success();
1297
1298 return op->emitOpError()
1299 << "expects parent op "
1300 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1301 << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
1302 }
1303
1304 template <typename ParentOpType =
1305 std::tuple_element_t<0, std::tuple<ParentOpTypes...>>>
1306 std::enable_if_t<sizeof...(ParentOpTypes) == 1, ParentOpType>
1307 getParentOp() {
1308 Operation *parent = this->getOperation()->getParentOp();
1309 return llvm::cast<ParentOpType>(parent);
1310 }
1311 };
1312};
1313
1314/// A trait for operations that have an attribute specifying operand segments.
1315///
1316/// Certain operations can have multiple variadic operands and their size
1317/// relationship is not always known statically. For such cases, we need
1318/// a per-op-instance specification to divide the operands into logical groups
1319/// or segments. This can be modeled by attributes. The attribute will be named
1320/// as `operandSegmentSizes`.
1321///
1322/// This trait verifies the attribute for specifying operand segments has
1323/// the correct type (1D vector) and values (non-negative), etc.
1324template <typename ConcreteType>
1325class AttrSizedOperandSegments
1326 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1327public:
1328 static StringRef getOperandSegmentSizeAttr() { return "operandSegmentSizes"; }
1329
1330 static LogicalResult verifyTrait(Operation *op) {
1331 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1332 op, sizeAttrName: getOperandSegmentSizeAttr());
1333 }
1334};
1335
1336/// Similar to AttrSizedOperandSegments but used for results.
1337template <typename ConcreteType>
1338class AttrSizedResultSegments
1339 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1340public:
1341 static StringRef getResultSegmentSizeAttr() { return "resultSegmentSizes"; }
1342
1343 static LogicalResult verifyTrait(Operation *op) {
1344 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1345 op, sizeAttrName: getResultSegmentSizeAttr());
1346 }
1347};
1348
1349/// This trait provides a verifier for ops that are expecting their regions to
1350/// not have any arguments
1351template <typename ConcrentType>
1352struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
1353 static LogicalResult verifyTrait(Operation *op) {
1354 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1355 }
1356};
1357
1358// This trait is used to flag operations that consume or produce
1359// values of `MemRef` type where those references can be 'normalized'.
1360// TODO: Right now, the operands of an operation are either all normalizable,
1361// or not. In the future, we may want to allow some of the operands to be
1362// normalizable.
1363template <typename ConcrentType>
1364struct MemRefsNormalizable
1365 : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1366
1367/// This trait tags element-wise ops on vectors or tensors.
1368///
1369/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1370/// trait. In particular, broadcasting behavior is not allowed.
1371///
1372/// An `Elementwise` op must satisfy the following properties:
1373///
1374/// 1. If any result is a vector/tensor then at least one operand must also be a
1375/// vector/tensor.
1376/// 2. If any operand is a vector/tensor then there must be at least one result
1377/// and all results must be vectors/tensors.
1378/// 3. All operand and result vector/tensor types must be of the same shape. The
1379/// shape may be dynamic in which case the op's behaviour is undefined for
1380/// non-matching shapes.
1381/// 4. The operation must be elementwise on its vector/tensor operands and
1382/// results. When applied to single-element vectors/tensors, the result must
1383/// be the same per elememnt.
1384///
1385/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1386/// interface `ElementwiseTypeInterface` that describes the container types for
1387/// which the operation is elementwise.
1388///
1389/// Rationale:
1390/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1391/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1392/// generic definition of the iteration space.
1393/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1394/// the same pattern, as otherwise lots of special handling for type
1395/// mismatches would be needed.
1396/// - 4. guarantees that no error handling is needed. Higher-level dialects
1397/// should reify any needed guards or error handling code before lowering to
1398/// an `Elementwise` op.
1399template <typename ConcreteType>
1400struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1401 static LogicalResult verifyTrait(Operation *op) {
1402 return ::mlir::OpTrait::impl::verifyElementwise(op);
1403 }
1404};
1405
1406/// This trait tags `Elementwise` operatons that can be systematically
1407/// scalarized. All vector/tensor operands and results are then replaced by
1408/// scalars of the respective element type. Semantically, this is the operation
1409/// on a single element of the vector/tensor.
1410///
1411/// Rationale:
1412/// Allow to define the vector/tensor semantics of elementwise operations based
1413/// on the same op's behavior on scalars. This provides a constructive procedure
1414/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1415///
1416/// Example:
1417/// ```
1418/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
1419/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1420/// -> tensor<?xf32>
1421/// ```
1422/// can be scalarized to
1423///
1424/// ```
1425/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
1426/// : (i1, f32, f32) -> f32
1427/// ```
1428template <typename ConcreteType>
1429struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1430 static LogicalResult verifyTrait(Operation *op) {
1431 static_assert(
1432 ConcreteType::template hasTrait<Elementwise>(),
1433 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1434 return success();
1435 }
1436};
1437
1438/// This trait tags `Elementwise` operatons that can be systematically
1439/// vectorized. All scalar operands and results are then replaced by vectors
1440/// with the respective element type. Semantically, this is the operation on
1441/// multiple elements simultaneously. See also `Tensorizable`.
1442///
1443/// Rationale:
1444/// Provide the reverse to `Scalarizable` which, when chained together, allows
1445/// reasoning about the relationship between the tensor and vector case.
1446/// Additionally, it permits reasoning about promoting scalars to vectors via
1447/// broadcasting in cases like `%select_scalar_pred` below.
1448template <typename ConcreteType>
1449struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1450 static LogicalResult verifyTrait(Operation *op) {
1451 static_assert(
1452 ConcreteType::template hasTrait<Elementwise>(),
1453 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1454 return success();
1455 }
1456};
1457
1458/// This trait tags `Elementwise` operatons that can be systematically
1459/// tensorized. All scalar operands and results are then replaced by tensors
1460/// with the respective element type. Semantically, this is the operation on
1461/// multiple elements simultaneously. See also `Vectorizable`.
1462///
1463/// Rationale:
1464/// Provide the reverse to `Scalarizable` which, when chained together, allows
1465/// reasoning about the relationship between the tensor and vector case.
1466/// Additionally, it permits reasoning about promoting scalars to tensors via
1467/// broadcasting in cases like `%select_scalar_pred` below.
1468///
1469/// Examples:
1470/// ```
1471/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1472/// ```
1473/// can be tensorized to
1474/// ```
1475/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1476/// -> tensor<?xf32>
1477/// ```
1478///
1479/// ```
1480/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
1481/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1482/// ```
1483/// can be tensorized to
1484/// ```
1485/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
1486/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1487/// -> tensor<?xf32>
1488/// ```
1489template <typename ConcreteType>
1490struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1491 static LogicalResult verifyTrait(Operation *op) {
1492 static_assert(
1493 ConcreteType::template hasTrait<Elementwise>(),
1494 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1495 return success();
1496 }
1497};
1498
/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
/// provide an easy way for scalar operations to conveniently generalize their
/// behavior to vectors/tensors, and systematize conversion between these forms.
///
/// Query helper for the traits above; only this declaration is visible here.
/// NOTE(review): presumably returns true when `op` carries all four traits —
/// confirm against the out-of-line definition.
bool hasElementwiseMappableTraits(Operation *op);
1503
1504} // namespace OpTrait
1505
1506//===----------------------------------------------------------------------===//
1507// Internal Trait Utilities
1508//===----------------------------------------------------------------------===//
1509
1510namespace op_definition_impl {
1511//===----------------------------------------------------------------------===//
1512// Trait Existence
1513
1514/// Returns true if this given Trait ID matches the IDs of any of the provided
1515/// trait types `Traits`.
1516template <template <typename T> class... Traits>
1517inline bool hasTrait(TypeID traitID) {
1518 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1519 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1520 if (traitIDs[i] == traitID)
1521 return true;
1522 return false;
1523}
1524template <>
1525inline bool hasTrait<>(TypeID traitID) {
1526 return false;
1527}
1528
1529//===----------------------------------------------------------------------===//
1530// Trait Folding
1531
/// Trait to check if T provides a 'foldTrait' method for single result
/// operations, i.e. one callable as `foldTrait(Operation *, ArrayRef<Attribute>)`.
/// The trailing `Args` pack is not used by the alias.
template <typename T, typename... Args>
using has_single_result_fold_trait = decltype(T::foldTrait(
    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
/// `::value` is true when `has_single_result_fold_trait<T>` is well-formed.
template <typename T>
using detect_has_single_result_fold_trait =
    llvm::is_detected<has_single_result_fold_trait, T>;
/// Trait to check if T provides a general 'foldTrait' method, i.e. one callable
/// as `foldTrait(Operation *, ArrayRef<Attribute>,
/// SmallVectorImpl<OpFoldResult> &)`. The trailing `Args` pack is not used.
template <typename T, typename... Args>
using has_fold_trait =
    decltype(T::foldTrait(std::declval<Operation *>(),
                          std::declval<ArrayRef<Attribute>>(),
                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
/// `::value` is true when `has_fold_trait<T>` is well-formed.
template <typename T>
using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
/// Trait to check if T provides any `foldTrait` method (either of the two
/// signatures above).
template <typename T>
using detect_has_any_fold_trait =
    std::disjunction<detect_has_fold_trait<T>,
                     detect_has_single_result_fold_trait<T>>;
1553
1554/// Returns the result of folding a trait that implements a `foldTrait` function
1555/// that is specialized for operations that have a single result.
1556template <typename Trait>
1557static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1558 LogicalResult>
1559foldTrait(Operation *op, ArrayRef<Attribute> operands,
1560 SmallVectorImpl<OpFoldResult> &results) {
1561 assert(op->hasTrait<OpTrait::OneResult>() &&
1562 "expected trait on non single-result operation to implement the "
1563 "general `foldTrait` method");
1564 // If a previous trait has already been folded and replaced this operation, we
1565 // fail to fold this trait.
1566 if (!results.empty())
1567 return failure();
1568
1569 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1570 if (llvm::dyn_cast_if_present<Value>(Val&: result) != op->getResult(idx: 0))
1571 results.push_back(Elt: result);
1572 return success();
1573 }
1574 return failure();
1575}
1576/// Returns the result of folding a trait that implements a generalized
1577/// `foldTrait` function that is supports any operation type.
1578template <typename Trait>
1579static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1580foldTrait(Operation *op, ArrayRef<Attribute> operands,
1581 SmallVectorImpl<OpFoldResult> &results) {
1582 // If a previous trait has already been folded and replaced this operation, we
1583 // fail to fold this trait.
1584 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1585}
1586template <typename Trait>
1587static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
1588 LogicalResult>
1589foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
1590 return failure();
1591}
1592
1593/// Given a tuple type containing a set of traits, return the result of folding
1594/// the given operation.
1595template <typename... Ts>
1596static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1597 SmallVectorImpl<OpFoldResult> &results) {
1598 return success((succeeded(foldTrait<Ts>(op, operands, results)) || ...));
1599}
1600
1601//===----------------------------------------------------------------------===//
1602// Trait Verification
1603
/// Trait to check if T provides a static `verifyTrait(Operation *)` method.
/// `has_verify_trait` is the detection expression (the type of that call);
/// the trailing `Args` pack is unused by the expression.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
/// `::value` is true when `has_verify_trait<T>` is well-formed.
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1609
/// Trait to check if T provides a static `verifyRegionTrait(Operation *)`
/// method. `has_verify_region_trait` is the detection expression (the type of
/// that call); the trailing `Args` pack is unused by the expression.
template <typename T, typename... Args>
using has_verify_region_trait =
    decltype(T::verifyRegionTrait(std::declval<Operation *>()));
/// `::value` is true when `has_verify_region_trait<T>` is well-formed.
template <typename T>
using detect_has_verify_region_trait =
    llvm::is_detected<has_verify_region_trait, T>;
1617
1618/// Verify the given trait if it provides a verifier.
1619template <typename T>
1620std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
1621verifyTrait(Operation *op) {
1622 return T::verifyTrait(op);
1623}
1624template <typename T>
1625inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
1626verifyTrait(Operation *) {
1627 return success();
1628}
1629
1630/// Given a set of traits, return the result of verifying the given operation.
1631template <typename... Ts>
1632LogicalResult verifyTraits(Operation *op) {
1633 return success((succeeded(verifyTrait<Ts>(op)) && ...));
1634}
1635
1636/// Verify the given trait if it provides a region verifier.
1637template <typename T>
1638std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
1639verifyRegionTrait(Operation *op) {
1640 return T::verifyRegionTrait(op);
1641}
1642template <typename T>
1643inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
1644 LogicalResult>
1645verifyRegionTrait(Operation *) {
1646 return success();
1647}
1648
1649/// Given a set of traits, return the result of verifying the regions of the
1650/// given operation.
1651template <typename... Ts>
1652LogicalResult verifyRegionTraits(Operation *op) {
1653 return success((succeeded(verifyRegionTrait<Ts>(op)) && ...));
1654}
1655} // namespace op_definition_impl
1656
1657//===----------------------------------------------------------------------===//
1658// Operation Definition classes
1659//===----------------------------------------------------------------------===//
1660
1661/// This provides public APIs that all operations should have. The template
1662/// argument 'ConcreteType' should be the concrete type by CRTP and the others
1663/// are base classes by the policy pattern.
1664template <typename ConcreteType, template <typename T> class... Traits>
1665class Op : public OpState, public Traits<ConcreteType>... {
1666public:
1667 /// Inherit getOperation from `OpState`.
1668 using OpState::getOperation;
1669 using OpState::verify;
1670 using OpState::verifyRegions;
1671
1672 /// Return if this operation contains the provided trait.
1673 template <template <typename T> class Trait>
1674 static constexpr bool hasTrait() {
1675 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1676 }
1677
1678 /// Create a deep copy of this operation.
1679 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1680
1681 /// Create a partial copy of this operation without traversing into attached
1682 /// regions. The new operation will have the same number of regions as the
1683 /// original one, but they will be left empty.
1684 ConcreteType cloneWithoutRegions() {
1685 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1686 }
1687
1688 /// Return true if this "op class" can match against the specified operation.
1689 static bool classof(Operation *op) {
1690 if (auto info = op->getRegisteredInfo())
1691 return TypeID::get<ConcreteType>() == info->getTypeID();
1692#ifndef NDEBUG
1693 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1694 llvm::report_fatal_error(
1695 "classof on '" + ConcreteType::getOperationName() +
1696 "' failed due to the operation not being registered");
1697#endif
1698 return false;
1699 }
1700 /// Provide `classof` support for other OpBase derived classes, such as
1701 /// Interfaces.
1702 template <typename T>
1703 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1704 classof(const T *op) {
1705 return classof(const_cast<T *>(op)->getOperation());
1706 }
1707
1708 /// Expose the type we are instantiated on to template machinery that may want
1709 /// to introspect traits on this operation.
1710 using ConcreteOpType = ConcreteType;
1711
1712 /// This is a public constructor. Any op can be initialized to null.
1713 explicit Op() : OpState(nullptr) {}
1714 Op(std::nullptr_t) : OpState(nullptr) {}
1715
1716 /// This is a public constructor to enable access via the llvm::cast family of
1717 /// methods. This should not be used directly.
1718 explicit Op(Operation *state) : OpState(state) {}
1719
1720 /// Methods for supporting PointerLikeTypeTraits.
1721 const void *getAsOpaquePointer() const {
1722 return static_cast<const void *>((Operation *)*this);
1723 }
1724 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1725 return ConcreteOpType(
1726 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1727 }
1728
1729 /// Attach the given models as implementations of the corresponding
1730 /// interfaces for the concrete operation.
1731 template <typename... Models>
1732 static void attachInterface(MLIRContext &context) {
1733 std::optional<RegisteredOperationName> info =
1734 RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
1735 if (!info)
1736 llvm::report_fatal_error(
1737 "Attempting to attach an interface to an unregistered operation " +
1738 ConcreteType::getOperationName() + ".");
1739 (checkInterfaceTarget<Models>(), ...);
1740 info->attachInterface<Models...>();
1741 }
1742 /// Convert the provided attribute to a property and assigned it to the
1743 /// provided properties. This default implementation forwards to a free
1744 /// function `setPropertiesFromAttribute` that can be looked up with ADL in
1745 /// the namespace where the properties are defined. It can also be overridden
1746 /// in the derived ConcreteOp.
1747 template <typename PropertiesTy>
1748 static LogicalResult
1749 setPropertiesFromAttr(PropertiesTy &prop, Attribute attr,
1750 function_ref<InFlightDiagnostic()> emitError) {
1751 return setPropertiesFromAttribute(prop, attr, emitError);
1752 }
1753 /// Convert the provided properties to an attribute. This default
1754 /// implementation forwards to a free function `getPropertiesAsAttribute` that
1755 /// can be looked up with ADL in the namespace where the properties are
1756 /// defined. It can also be overridden in the derived ConcreteOp.
1757 template <typename PropertiesTy>
1758 static Attribute getPropertiesAsAttr(MLIRContext *ctx,
1759 const PropertiesTy &prop) {
1760 return getPropertiesAsAttribute(ctx, prop);
1761 }
1762 /// Hash the provided properties. This default implementation forwards to a
1763 /// free function `computeHash` that can be looked up with ADL in the
1764 /// namespace where the properties are defined. It can also be overridden in
1765 /// the derived ConcreteOp.
1766 template <typename PropertiesTy>
1767 static llvm::hash_code computePropertiesHash(const PropertiesTy &prop) {
1768 return computeHash(prop);
1769 }
1770
1771private:
1772 /// Trait to check if T provides a 'fold' method for a single result op.
1773 template <typename T, typename... Args>
1774 using has_single_result_fold_t =
1775 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1776 template <typename T>
1777 constexpr static bool has_single_result_fold_v =
1778 llvm::is_detected<has_single_result_fold_t, T>::value;
1779 /// Trait to check if T provides a general 'fold' method.
1780 template <typename T, typename... Args>
1781 using has_fold_t = decltype(std::declval<T>().fold(
1782 std::declval<ArrayRef<Attribute>>(),
1783 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1784 template <typename T>
1785 constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
1786 /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
1787 /// single result op.
1788 template <typename T, typename... Args>
1789 using has_fold_adaptor_single_result_fold_t =
1790 decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
1791 template <class T>
1792 constexpr static bool has_fold_adaptor_single_result_v =
1793 llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
1794 /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
1795 template <typename T, typename... Args>
1796 using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
1797 std::declval<typename T::FoldAdaptor>(),
1798 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1799 template <class T>
1800 constexpr static bool has_fold_adaptor_v =
1801 llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
1802
1803 /// Trait to check if T provides a 'print' method.
1804 template <typename T, typename... Args>
1805 using has_print =
1806 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1807 template <typename T>
1808 using detect_has_print = llvm::is_detected<has_print, T>;
1809
1810 /// Trait to check if printProperties(OpAsmPrinter, T, ArrayRef<StringRef>)
1811 /// exist
1812 template <typename T, typename... Args>
1813 using has_print_properties =
1814 decltype(printProperties(std::declval<OpAsmPrinter &>(),
1815 std::declval<T>(),
1816 std::declval<ArrayRef<StringRef>>()));
1817 template <typename T>
1818 using detect_has_print_properties =
1819 llvm::is_detected<has_print_properties, T>;
1820
1821 /// Trait to check if parseProperties(OpAsmParser, T) exist
1822 template <typename T, typename... Args>
1823 using has_parse_properties = decltype(parseProperties(
1824 std::declval<OpAsmParser &>(), std::declval<T &>()));
1825 template <typename T>
1826 using detect_has_parse_properties =
1827 llvm::is_detected<has_parse_properties, T>;
1828
1829 /// Trait to check if T provides a 'ConcreteEntity' type alias.
1830 template <typename T>
1831 using has_concrete_entity_t = typename T::ConcreteEntity;
1832
1833public:
1834 /// Returns true if this operation defines a `Properties` inner type.
1835 static constexpr bool hasProperties() {
1836 return !std::is_same_v<
1837 typename ConcreteType::template InferredProperties<ConcreteType>,
1838 EmptyProperties>;
1839 }
1840
1841private:
1842 /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
1843 /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail
1844 /// on the missing typedef. Useful for checking if the interface is targeting
1845 /// the right class.
1846 template <typename T,
1847 bool = llvm::is_detected<has_concrete_entity_t, T>::value>
1848 struct InterfaceTargetOrOpT {
1849 using type = typename T::ConcreteEntity;
1850 };
1851 template <typename T>
1852 struct InterfaceTargetOrOpT<T, false> {
1853 using type = ConcreteType;
1854 };
1855
1856 /// A hook for static assertion that the external interface model T is
1857 /// targeting the concrete type of this op. The model can also be a fallback
1858 /// model that works for every op.
1859 template <typename T>
1860 static void checkInterfaceTarget() {
1861 static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type,
1862 ConcreteType>::value,
1863 "attaching an interface to the wrong op kind");
1864 }
1865
1866 /// Returns an interface map containing the interfaces registered to this
1867 /// operation.
1868 static detail::InterfaceMap getInterfaceMap() {
1869 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1870 }
1871
1872 /// Return the internal implementations of each of the OperationName
1873 /// hooks.
1874 /// Implementation of `FoldHookFn` OperationName hook.
1875 static OperationName::FoldHookFn getFoldHookFn() {
1876 // If the operation is single result and defines a `fold` method.
1877 if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
1878 Traits<ConcreteType>...>::value &&
1879 (has_single_result_fold_v<ConcreteType> ||
1880 has_fold_adaptor_single_result_v<ConcreteType>))
1881 return [](Operation *op, ArrayRef<Attribute> operands,
1882 SmallVectorImpl<OpFoldResult> &results) {
1883 return foldSingleResultHook<ConcreteType>(op, operands, results);
1884 };
1885 // The operation is not single result and defines a `fold` method.
1886 if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
1887 return [](Operation *op, ArrayRef<Attribute> operands,
1888 SmallVectorImpl<OpFoldResult> &results) {
1889 return foldHook<ConcreteType>(op, operands, results);
1890 };
1891 // The operation does not define a `fold` method.
1892 return [](Operation *op, ArrayRef<Attribute> operands,
1893 SmallVectorImpl<OpFoldResult> &results) {
1894 // In this case, we only need to fold the traits of the operation.
1895 return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1896 op, operands, results);
1897 };
1898 }
1899 /// Return the result of folding a single result operation that defines a
1900 /// `fold` method.
1901 template <typename ConcreteOpT>
1902 static LogicalResult
1903 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1904 SmallVectorImpl<OpFoldResult> &results) {
1905 OpFoldResult result;
1906 if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
1907 result = cast<ConcreteOpT>(op).fold(
1908 typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)));
1909 } else {
1910 result = cast<ConcreteOpT>(op).fold(operands);
1911 }
1912
1913 // If the fold failed or was in-place, try to fold the traits of the
1914 // operation.
1915 if (!result ||
1916 llvm::dyn_cast_if_present<Value>(Val&: result) == op->getResult(idx: 0)) {
1917 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1918 op, operands, results)))
1919 return success();
1920 return success(isSuccess: static_cast<bool>(result));
1921 }
1922 results.push_back(Elt: result);
1923 return success();
1924 }
1925 /// Return the result of folding an operation that defines a `fold` method.
1926 template <typename ConcreteOpT>
1927 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1928 SmallVectorImpl<OpFoldResult> &results) {
1929 auto result = LogicalResult::failure();
1930 if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
1931 result = cast<ConcreteOpT>(op).fold(
1932 typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)),
1933 results);
1934 } else {
1935 result = cast<ConcreteOpT>(op).fold(operands, results);
1936 }
1937
1938 // If the fold failed or was in-place, try to fold the traits of the
1939 // operation.
1940 if (failed(result) || results.empty()) {
1941 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1942 op, operands, results)))
1943 return success();
1944 }
1945 return result;
1946 }
1947
1948 /// Implementation of `GetHasTraitFn`
1949 static OperationName::HasTraitFn getHasTraitFn() {
1950 return
1951 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
1952 }
1953 /// Implementation of `PrintAssemblyFn` OperationName hook.
1954 static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
1955 if constexpr (detect_has_print<ConcreteType>::value)
1956 return [](Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
1957 OpState::printOpName(op, p, defaultDialect);
1958 return cast<ConcreteType>(op).print(p);
1959 };
1960 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
1961 return OpState::print(op, p&: printer, defaultDialect);
1962 };
1963 }
1964
1965public:
1966 template <typename T>
1967 using InferredProperties = typename PropertiesSelector<T>::type;
1968 template <typename T = ConcreteType>
1969 InferredProperties<T> &getProperties() {
1970 if constexpr (!hasProperties())
1971 return getEmptyProperties();
1972 return *getOperation()
1973 ->getPropertiesStorageUnsafe()
1974 .template as<InferredProperties<T> *>();
1975 }
1976
1977 /// This hook populates any unset default attrs when mapped to properties.
1978 template <typename T = ConcreteType>
1979 static void populateDefaultProperties(OperationName opName,
1980 InferredProperties<T> &properties) {}
1981
1982 /// Print the operation properties with names not included within
1983 /// 'elidedProps'. Unless overridden, this method will try to dispatch to a
1984 /// `printProperties` free-function if it exists, and otherwise by converting
1985 /// the properties to an Attribute.
1986 template <typename T>
1987 static void printProperties(MLIRContext *ctx, OpAsmPrinter &p,
1988 const T &properties,
1989 ArrayRef<StringRef> elidedProps = {}) {
1990 if constexpr (detect_has_print_properties<T>::value)
1991 return printProperties(p, properties, elidedProps);
1992 genericPrintProperties(
1993 p, properties: ConcreteType::getPropertiesAsAttr(ctx, properties), elidedProps);
1994 }
1995
1996 /// Parses 'prop-dict' for the operation. Unless overridden, the method will
1997 /// parse the properties using the generic property dictionary using the
1998 /// '<{ ... }>' syntax. The resulting properties are stored within the
1999 /// property structure of 'result', accessible via 'getOrAddProperties'.
2000 template <typename T = ConcreteType>
2001 static ParseResult parseProperties(OpAsmParser &parser,
2002 OperationState &result) {
2003 if constexpr (detect_has_parse_properties<InferredProperties<T>>::value) {
2004 return parseProperties(
2005 parser, result.getOrAddProperties<InferredProperties<T>>());
2006 }
2007
2008 Attribute propertyDictionary;
2009 if (genericParseProperties(parser, result&: propertyDictionary))
2010 return failure();
2011
2012 // The generated 'setPropertiesFromParsedAttr', like
2013 // 'setPropertiesFromAttr', expects a 'DictionaryAttr' that is not null.
2014 // Use an empty dictionary in the case that the whole dictionary is
2015 // optional.
2016 if (!propertyDictionary)
2017 propertyDictionary = DictionaryAttr::get(result.getContext());
2018
2019 auto emitError = [&]() {
2020 return mlir::emitError(loc: result.location, message: "invalid properties ")
2021 << propertyDictionary << " for op " << result.name.getStringRef()
2022 << ": ";
2023 };
2024
2025 // Copy the data from the dictionary attribute into the property struct of
2026 // the operation. This method is generated by ODS by default if there are
2027 // any occurrences of 'prop-dict' in the assembly format and should set
2028 // any properties that aren't parsed elsewhere.
2029 return ConcreteOpType::setPropertiesFromParsedAttr(
2030 result.getOrAddProperties<InferredProperties<T>>(), propertyDictionary,
2031 emitError);
2032 }
2033
2034private:
2035 /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
2036 static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
2037 return ConcreteType::populateDefaultAttrs;
2038 }
2039 /// Implementation of `VerifyInvariantsFn` OperationName hook.
2040 static LogicalResult verifyInvariants(Operation *op) {
2041 static_assert(hasNoDataMembers(),
2042 "Op class shouldn't define new data members");
2043 return failure(
2044 failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
2045 failed(cast<ConcreteType>(op).verify()));
2046 }
2047 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
2048 return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
2049 }
2050 /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
2051 static LogicalResult verifyRegionInvariants(Operation *op) {
2052 static_assert(hasNoDataMembers(),
2053 "Op class shouldn't define new data members");
2054 return failure(
2055 failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
2056 op)) ||
2057 failed(cast<ConcreteType>(op).verifyRegions()));
2058 }
2059 static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
2060 return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
2061 }
2062
2063 static constexpr bool hasNoDataMembers() {
2064 // Checking that the derived class does not define any member by comparing
2065 // its size to an ad-hoc EmptyOp.
2066 class EmptyOp : public Op<EmptyOp, Traits...> {};
2067 return sizeof(ConcreteType) == sizeof(EmptyOp);
2068 }
2069
2070 /// Allow access to internal implementation methods.
2071 friend RegisteredOperationName;
2072};
2073
2074/// This class represents the base of an operation interface. See the definition
2075/// of `detail::Interface` for requirements on the `Traits` type.
2076template <typename ConcreteType, typename Traits>
2077class OpInterface
2078 : public detail::Interface<ConcreteType, Operation *, Traits,
2079 Op<ConcreteType>, OpTrait::TraitBase> {
2080public:
2081 using Base = OpInterface<ConcreteType, Traits>;
2082 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
2083 Op<ConcreteType>, OpTrait::TraitBase>;
2084
2085 /// Inherit the base class constructor.
2086 using InterfaceBase::InterfaceBase;
2087
2088protected:
2089 /// Returns the impl interface instance for the given operation.
2090 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
2091 OperationName name = op->getName();
2092
2093#ifndef NDEBUG
2094 // Check that the current interface isn't an unresolved promise for the
2095 // given operation.
2096 if (Dialect *dialect = name.getDialect()) {
2097 dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
2098 dialect&: *dialect, interfaceRequestorID: name.getTypeID(), interfaceID: ConcreteType::getInterfaceID(),
2099 interfaceName: llvm::getTypeName<ConcreteType>());
2100 }
2101#endif
2102
2103 // Access the raw interface from the operation info.
2104 if (std::optional<RegisteredOperationName> rInfo =
2105 name.getRegisteredInfo()) {
2106 if (auto *opIface = rInfo->getInterface<ConcreteType>())
2107 return opIface;
2108 // Fallback to the dialect to provide it with a chance to implement this
2109 // interface for this operation.
2110 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
2111 op->getName());
2112 }
2113 // Fallback to the dialect to provide it with a chance to implement this
2114 // interface for this operation.
2115 if (Dialect *dialect = name.getDialect())
2116 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
2117 return nullptr;
2118 }
2119
2120 /// Allow access to `getInterfaceFor`.
2121 friend InterfaceBase;
2122};
2123
2124} // namespace mlir
2125
2126namespace llvm {
2127
2128template <typename T>
2129struct DenseMapInfo<T,
2130 std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&
2131 !mlir::detail::IsInterface<T>::value>> {
2132 static inline T getEmptyKey() {
2133 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2134 return T::getFromOpaquePointer(pointer);
2135 }
2136 static inline T getTombstoneKey() {
2137 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2138 return T::getFromOpaquePointer(pointer);
2139 }
2140 static unsigned getHashValue(T val) {
2141 return hash_value(val.getAsOpaquePointer());
2142 }
2143 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
2144};
2145} // namespace llvm
2146
2147#endif
2148

source code of mlir/include/mlir/IR/OpDefinition.h