1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
13//
// The purpose of these types is to allow lightweight implementation of
// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/ODSSupport.h"
24#include "mlir/IR/Operation.h"
25#include "llvm/Support/PointerLikeTypeTraits.h"
26
27#include <optional>
28#include <type_traits>
29
30namespace mlir {
31class Builder;
32class OpBuilder;
33
34/// This class implements `Optional` functionality for ParseResult. We don't
35/// directly use Optional here, because it provides an implicit conversion
36/// to 'bool' which we want to avoid. This class is used to implement tri-state
37/// 'parseOptional' functions that may have a failure mode when parsing that
38/// shouldn't be attributed to "not present".
39class OptionalParseResult {
40public:
41 OptionalParseResult() = default;
42 OptionalParseResult(LogicalResult result) : impl(result) {}
43 OptionalParseResult(ParseResult result) : impl(result) {}
44 OptionalParseResult(const InFlightDiagnostic &)
45 : OptionalParseResult(failure()) {}
46 OptionalParseResult(std::nullopt_t) : impl(std::nullopt) {}
47
48 /// Returns true if we contain a valid ParseResult value.
49 bool has_value() const { return impl.has_value(); }
50
51 /// Access the internal ParseResult value.
52 ParseResult value() const { return *impl; }
53 ParseResult operator*() const { return value(); }
54
55private:
56 std::optional<ParseResult> impl;
57};
58
// These functions are out-of-line utilities, which avoids them being template
// instantiated/duplicated.
namespace impl {
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
/// region's only block if it does not have a terminator already. If the region
/// is empty, insert a new block first. `buildTerminatorOp` should return the
/// terminator operation to insert.
///
/// \param region            the region to patch up.
/// \param builder           builder used to create the block/terminator.
/// \param loc               location attached to the created terminator.
/// \param buildTerminatorOp callback that creates and returns the terminator.
void ensureRegionTerminator(
    Region &region, OpBuilder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
/// Overload of the function above that takes a plain `Builder` instead of an
/// `OpBuilder`; the callback still receives an `OpBuilder &`.
/// NOTE(review): presumably an OpBuilder is derived from `builder` internally
/// before delegating — confirm against the out-of-line definition.
void ensureRegionTerminator(
    Region &region, Builder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);

} // namespace impl
74
/// Placeholder "Properties" storage for operations that declare no
/// properties. It carries no state, so any two instances compare equal.
struct EmptyProperties {
  bool operator==(const EmptyProperties &) const { return true; }
  bool operator!=(const EmptyProperties &other) const {
    return !(*this == other);
  }
};
81
82/// Traits to detect whether an Operation defined a `Properties` type, otherwise
83/// it'll default to `EmptyProperties`.
84template <class Op, class = void>
85struct PropertiesSelector {
86 using type = EmptyProperties;
87};
88template <class Op>
89struct PropertiesSelector<Op, std::void_t<typename Op::Properties>> {
90 using type = typename Op::Properties;
91};
92
93/// This is the concrete base class that holds the operation pointer and has
94/// non-generic methods that only depend on State (to avoid having them
95/// instantiated on template types that don't affect them.
96///
97/// This also has the fallback implementations of customization hooks for when
98/// they aren't customized.
99class OpState {
100public:
101 /// Ops are pointer-like, so we allow conversion to bool.
102 explicit operator bool() { return getOperation() != nullptr; }
103
104 /// This implicitly converts to Operation*.
105 operator Operation *() const { return state; }
106
107 /// Shortcut of `->` to access a member of Operation.
108 Operation *operator->() const { return state; }
109
110 /// Return the operation that this refers to.
111 Operation *getOperation() { return state; }
112
113 /// Return the context this operation belongs to.
114 MLIRContext *getContext() { return getOperation()->getContext(); }
115
116 /// Print the operation to the given stream.
117 void print(raw_ostream &os, OpPrintingFlags flags = std::nullopt) {
118 state->print(os, flags);
119 }
120 void print(raw_ostream &os, AsmState &asmState) {
121 state->print(os, state&: asmState);
122 }
123
124 /// Dump this operation.
125 void dump() { state->dump(); }
126
127 /// The source location the operation was defined or derived from.
128 Location getLoc() { return state->getLoc(); }
129
130 /// Return true if there are no users of any results of this operation.
131 bool use_empty() { return state->use_empty(); }
132
133 /// Remove this operation from its parent block and delete it.
134 void erase() { state->erase(); }
135
136 /// Emit an error with the op name prefixed, like "'dim' op " which is
137 /// convenient for verifiers.
138 InFlightDiagnostic emitOpError(const Twine &message = {});
139
140 /// Emit an error about fatal conditions with this operation, reporting up to
141 /// any diagnostic handlers that may be listening.
142 InFlightDiagnostic emitError(const Twine &message = {});
143
144 /// Emit a warning about this operation, reporting up to any diagnostic
145 /// handlers that may be listening.
146 InFlightDiagnostic emitWarning(const Twine &message = {});
147
148 /// Emit a remark about this operation, reporting up to any diagnostic
149 /// handlers that may be listening.
150 InFlightDiagnostic emitRemark(const Twine &message = {});
151
152 /// Walk the operation by calling the callback for each nested operation
153 /// (including this one), block or region, depending on the callback provided.
154 /// The order in which regions, blocks and operations the same nesting level
155 /// are visited (e.g., lexicographical or reverse lexicographical order) is
156 /// determined by 'Iterator'. The walk order for enclosing regions, blocks
157 /// and operations with respect to their nested ones is specified by 'Order'
158 /// (post-order by default). A callback on a block or operation is allowed to
159 /// erase that block or operation if either:
160 /// * the walk is in post-order, or
161 /// * the walk is in pre-order and the walk is skipped after the erasure.
162 /// See Operation::walk for more details.
163 template <WalkOrder Order = WalkOrder::PostOrder,
164 typename Iterator = ForwardIterator, typename FnT,
165 typename RetT = detail::walkResultType<FnT>>
166 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1,
167 RetT>
168 walk(FnT &&callback) {
169 return state->walk<Order, Iterator>(std::forward<FnT>(callback));
170 }
171
172 /// Generic walker with a stage aware callback. Walk the operation by calling
173 /// the callback for each nested operation (including this one) N+1 times,
174 /// where N is the number of regions attached to that operation.
175 ///
176 /// The callback method can take any of the following forms:
177 /// void(Operation *, const WalkStage &) : Walk all operation opaquely
178 /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
179 /// void(OpT, const WalkStage &) : Walk all operations of the given derived
180 /// type.
181 /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
182 /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
183 /// but allow for interruption/skipping.
184 /// * op.walk([](... op, const WalkStage &stage) {
185 /// // Skip the walk of this op based on some invariant.
186 /// if (some_invariant)
187 /// return WalkResult::skip();
188 /// // Interrupt, i.e cancel, the walk based on some invariant.
189 /// if (another_invariant)
190 /// return WalkResult::interrupt();
191 /// return WalkResult::advance();
192 /// });
193 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
194 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 2,
195 RetT>
196 walk(FnT &&callback) {
197 return state->walk(std::forward<FnT>(callback));
198 }
199
200 // These are default implementations of customization hooks.
201public:
202 /// This hook returns any canonicalization pattern rewrites that the operation
203 /// supports, for use by the canonicalization pass.
204 static void getCanonicalizationPatterns(RewritePatternSet &results,
205 MLIRContext *context) {}
206
207 /// This hook populates any unset default attrs.
208 static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {}
209
210protected:
211 /// If the concrete type didn't implement a custom verifier hook, just fall
212 /// back to this one which accepts everything.
213 LogicalResult verify() { return success(); }
214 LogicalResult verifyRegions() { return success(); }
215
216 /// Parse the custom form of an operation. Unless overridden, this method will
217 /// first try to get an operation parser from the op's dialect. Otherwise the
218 /// custom assembly form of an op is always rejected. Op implementations
219 /// should implement this to return failure. On success, they should fill in
220 /// result with the fields to use.
221 static ParseResult parse(OpAsmParser &parser, OperationState &result);
222
223 /// Print the operation. Unless overridden, this method will first try to get
224 /// an operation printer from the dialect. Otherwise, it prints the operation
225 /// in generic form.
226 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
227
228 /// Parse properties as a Attribute.
229 static ParseResult genericParseProperties(OpAsmParser &parser,
230 Attribute &result);
231
232 /// Print the properties as a Attribute with names not included within
233 /// 'elidedProps'
234 static void genericPrintProperties(OpAsmPrinter &p, Attribute properties,
235 ArrayRef<StringRef> elidedProps = {});
236
237 /// Print an operation name, eliding the dialect prefix if necessary.
238 static void printOpName(Operation *op, OpAsmPrinter &p,
239 StringRef defaultDialect);
240
241 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
242 /// so we can cast it away here.
243 explicit OpState(Operation *state) : state(state) {}
244
245 /// For all op which don't have properties, we keep a single instance of
246 /// `EmptyProperties` to be used where a reference to a properties is needed:
247 /// this allow to bind a pointer to the reference without triggering UB.
248 static EmptyProperties &getEmptyProperties() {
249 static EmptyProperties emptyProperties;
250 return emptyProperties;
251 }
252
253private:
254 Operation *state;
255
256 /// Allow access to internal hook implementation methods.
257 friend RegisteredOperationName;
258};
259
260// Allow comparing operators.
261inline bool operator==(OpState lhs, OpState rhs) {
262 return lhs.getOperation() == rhs.getOperation();
263}
264inline bool operator!=(OpState lhs, OpState rhs) {
265 return lhs.getOperation() != rhs.getOperation();
266}
267
268raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
269
270/// This class represents a single result from folding an operation.
271class OpFoldResult : public PointerUnion<Attribute, Value> {
272 using PointerUnion<Attribute, Value>::PointerUnion;
273
274public:
275 void dump() const { llvm::errs() << *this << "\n"; }
276
277 MLIRContext *getContext() const {
278 PointerUnion pu = *this;
279 return isa<Attribute>(Val: pu) ? cast<Attribute>(Val&: pu).getContext()
280 : cast<Value>(Val&: pu).getContext();
281 }
282};
283
// Temporarily exit the MLIR namespace to add casting support, as later code in
// this file uses it. The CastInfo specializations must come after the
// OpFoldResult definition and before any cast function calls that depend on them.
287
288} // namespace mlir
289
290namespace llvm {
291
292// Allow llvm::cast style functions.
293template <typename To>
294struct CastInfo<To, mlir::OpFoldResult>
295 : public CastInfo<To, mlir::OpFoldResult::PointerUnion> {};
296
297template <typename To>
298struct CastInfo<To, const mlir::OpFoldResult>
299 : public CastInfo<To, const mlir::OpFoldResult::PointerUnion> {};
300
301} // namespace llvm
302
303namespace mlir {
304
305/// Allow printing to a stream.
306inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
307 if (Value value = llvm::dyn_cast_if_present<Value>(Val&: ofr))
308 value.print(os);
309 else
310 llvm::dyn_cast_if_present<Attribute>(Val&: ofr).print(os);
311 return os;
312}
313/// Allow printing to a stream.
314inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
315 op.print(os, flags: OpPrintingFlags().useLocalScope());
316 return os;
317}
318
319//===----------------------------------------------------------------------===//
320// Operation Trait Types
321//===----------------------------------------------------------------------===//
322
323namespace OpTrait {
324
325// These functions are out-of-line implementations of the methods in the
326// corresponding trait classes. This avoids them being template
327// instantiated/duplicated.
328namespace impl {
329LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
330 SmallVectorImpl<OpFoldResult> &results);
331OpFoldResult foldIdempotent(Operation *op);
332OpFoldResult foldInvolution(Operation *op);
333LogicalResult verifyZeroOperands(Operation *op);
334LogicalResult verifyOneOperand(Operation *op);
335LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
336LogicalResult verifyIsIdempotent(Operation *op);
337LogicalResult verifyIsInvolution(Operation *op);
338LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
339LogicalResult verifyOperandsAreFloatLike(Operation *op);
340LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
341LogicalResult verifySameTypeOperands(Operation *op);
342LogicalResult verifyZeroRegions(Operation *op);
343LogicalResult verifyOneRegion(Operation *op);
344LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
345LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
346LogicalResult verifyZeroResults(Operation *op);
347LogicalResult verifyOneResult(Operation *op);
348LogicalResult verifyNResults(Operation *op, unsigned numOperands);
349LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
350LogicalResult verifySameOperandsShape(Operation *op);
351LogicalResult verifySameOperandsAndResultShape(Operation *op);
352LogicalResult verifySameOperandsElementType(Operation *op);
353LogicalResult verifySameOperandsAndResultElementType(Operation *op);
354LogicalResult verifySameOperandsAndResultType(Operation *op);
355LogicalResult verifySameOperandsAndResultRank(Operation *op);
356LogicalResult verifyResultsAreBoolLike(Operation *op);
357LogicalResult verifyResultsAreFloatLike(Operation *op);
358LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
359LogicalResult verifyIsTerminator(Operation *op);
360LogicalResult verifyZeroSuccessors(Operation *op);
361LogicalResult verifyOneSuccessor(Operation *op);
362LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
363LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
364LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
365 StringRef valueGroupName,
366 size_t expectedCount);
367LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
368LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
369LogicalResult verifyNoRegionArguments(Operation *op);
370LogicalResult verifyElementwise(Operation *op);
371LogicalResult verifyIsIsolatedFromAbove(Operation *op);
372} // namespace impl
373
374/// Helper class for implementing traits. Clients are not expected to interact
375/// with this directly, so its members are all protected.
376template <typename ConcreteType, template <typename> class TraitType>
377class TraitBase {
378protected:
379 /// Return the ultimate Operation being worked on.
380 Operation *getOperation() {
381 auto *concrete = static_cast<ConcreteType *>(this);
382 return concrete->getOperation();
383 }
384};
385
386//===----------------------------------------------------------------------===//
387// Operand Traits
388//===----------------------------------------------------------------------===//
389
390namespace detail {
391/// Utility trait base that provides accessors for derived traits that have
392/// multiple operands.
393template <typename ConcreteType, template <typename> class TraitType>
394struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
395 using operand_iterator = Operation::operand_iterator;
396 using operand_range = Operation::operand_range;
397 using operand_type_iterator = Operation::operand_type_iterator;
398 using operand_type_range = Operation::operand_type_range;
399
400 /// Return the number of operands.
401 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
402
403 /// Return the operand at index 'i'.
404 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
405
406 /// Set the operand at index 'i' to 'value'.
407 void setOperand(unsigned i, Value value) {
408 this->getOperation()->setOperand(i, value);
409 }
410
411 /// Operand iterator access.
412 operand_iterator operand_begin() {
413 return this->getOperation()->operand_begin();
414 }
415 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
416 operand_range getOperands() { return this->getOperation()->getOperands(); }
417
418 /// Operand type access.
419 operand_type_iterator operand_type_begin() {
420 return this->getOperation()->operand_type_begin();
421 }
422 operand_type_iterator operand_type_end() {
423 return this->getOperation()->operand_type_end();
424 }
425 operand_type_range getOperandTypes() {
426 return this->getOperation()->getOperandTypes();
427 }
428};
429} // namespace detail
430
431/// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc.
432/// It should be run after core traits and before any other user defined traits.
433/// In order to run it in the correct order, wrap it with OpInvariants trait so
434/// that tblgen will be able to put it in the right order.
435template <typename ConcreteType>
436class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
437public:
438 static LogicalResult verifyTrait(Operation *op) {
439 return cast<ConcreteType>(op).verifyInvariantsImpl();
440 }
441};
442
443/// This class provides the API for ops that are known to have no
444/// SSA operand.
445template <typename ConcreteType>
446class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
447public:
448 static LogicalResult verifyTrait(Operation *op) {
449 return impl::verifyZeroOperands(op);
450 }
451
452private:
453 // Disable these.
454 void getOperand() {}
455 void setOperand() {}
456};
457
458/// This class provides the API for ops that are known to have exactly one
459/// SSA operand.
460template <typename ConcreteType>
461class OneOperand : public TraitBase<ConcreteType, OneOperand> {
462public:
463 Value getOperand() { return this->getOperation()->getOperand(0); }
464
465 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
466
467 static LogicalResult verifyTrait(Operation *op) {
468 return impl::verifyOneOperand(op);
469 }
470};
471
472/// This class provides the API for ops that are known to have a specified
473/// number of operands. This is used as a trait like this:
474///
475/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
476///
477template <unsigned N>
478class NOperands {
479public:
480 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
481
482 template <typename ConcreteType>
483 class Impl
484 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
485 public:
486 static LogicalResult verifyTrait(Operation *op) {
487 return impl::verifyNOperands(op, numOperands: N);
488 }
489 };
490};
491
492/// This class provides the API for ops that are known to have a at least a
493/// specified number of operands. This is used as a trait like this:
494///
495/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
496///
497template <unsigned N>
498class AtLeastNOperands {
499public:
500 template <typename ConcreteType>
501 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
502 AtLeastNOperands<N>::Impl> {
503 public:
504 static LogicalResult verifyTrait(Operation *op) {
505 return impl::verifyAtLeastNOperands(op, numOperands: N);
506 }
507 };
508};
509
510/// This class provides the API for ops which have an unknown number of
511/// SSA operands.
512template <typename ConcreteType>
513class VariadicOperands
514 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
515
516//===----------------------------------------------------------------------===//
517// Region Traits
518//===----------------------------------------------------------------------===//
519
520/// This class provides verification for ops that are known to have zero
521/// regions.
522template <typename ConcreteType>
523class ZeroRegions : public TraitBase<ConcreteType, ZeroRegions> {
524public:
525 static LogicalResult verifyTrait(Operation *op) {
526 return impl::verifyZeroRegions(op);
527 }
528};
529
530namespace detail {
531/// Utility trait base that provides accessors for derived traits that have
532/// multiple regions.
533template <typename ConcreteType, template <typename> class TraitType>
534struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
535 using region_iterator = MutableArrayRef<Region>;
536 using region_range = RegionRange;
537
538 /// Return the number of regions.
539 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
540
541 /// Return the region at `index`.
542 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
543
544 /// Region iterator access.
545 region_iterator region_begin() {
546 return this->getOperation()->region_begin();
547 }
548 region_iterator region_end() { return this->getOperation()->region_end(); }
549 region_range getRegions() { return this->getOperation()->getRegions(); }
550};
551} // namespace detail
552
553/// This class provides APIs for ops that are known to have a single region.
554template <typename ConcreteType>
555class OneRegion : public TraitBase<ConcreteType, OneRegion> {
556public:
557 Region &getRegion() { return this->getOperation()->getRegion(0); }
558
559 /// Returns a range of operations within the region of this operation.
560 auto getOps() { return getRegion().getOps(); }
561 template <typename OpT>
562 auto getOps() {
563 return getRegion().template getOps<OpT>();
564 }
565
566 static LogicalResult verifyTrait(Operation *op) {
567 return impl::verifyOneRegion(op);
568 }
569};
570
571/// This class provides the API for ops that are known to have a specified
572/// number of regions.
573template <unsigned N>
574class NRegions {
575public:
576 static_assert(N > 1, "use ZeroRegions/OneRegion for N < 2");
577
578 template <typename ConcreteType>
579 class Impl
580 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
581 public:
582 static LogicalResult verifyTrait(Operation *op) {
583 return impl::verifyNRegions(op, numRegions: N);
584 }
585 };
586};
587
588/// This class provides APIs for ops that are known to have at least a specified
589/// number of regions.
590template <unsigned N>
591class AtLeastNRegions {
592public:
593 template <typename ConcreteType>
594 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
595 AtLeastNRegions<N>::Impl> {
596 public:
597 static LogicalResult verifyTrait(Operation *op) {
598 return impl::verifyAtLeastNRegions(op, numRegions: N);
599 }
600 };
601};
602
603/// This class provides the API for ops which have an unknown number of
604/// regions.
605template <typename ConcreteType>
606class VariadicRegions
607 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
608
609//===----------------------------------------------------------------------===//
610// Result Traits
611//===----------------------------------------------------------------------===//
612
613/// This class provides return value APIs for ops that are known to have
614/// zero results.
615template <typename ConcreteType>
616class ZeroResults : public TraitBase<ConcreteType, ZeroResults> {
617public:
618 static LogicalResult verifyTrait(Operation *op) {
619 return impl::verifyZeroResults(op);
620 }
621};
622
623namespace detail {
624/// Utility trait base that provides accessors for derived traits that have
625/// multiple results.
626template <typename ConcreteType, template <typename> class TraitType>
627struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
628 using result_iterator = Operation::result_iterator;
629 using result_range = Operation::result_range;
630 using result_type_iterator = Operation::result_type_iterator;
631 using result_type_range = Operation::result_type_range;
632
633 /// Return the number of results.
634 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
635
636 /// Return the result at index 'i'.
637 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
638
639 /// Replace all uses of results of this operation with the provided 'values'.
640 /// 'values' may correspond to an existing operation, or a range of 'Value'.
641 template <typename ValuesT>
642 void replaceAllUsesWith(ValuesT &&values) {
643 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
644 }
645
646 /// Return the type of the `i`-th result.
647 Type getType(unsigned i) { return getResult(i).getType(); }
648
649 /// Result iterator access.
650 result_iterator result_begin() {
651 return this->getOperation()->result_begin();
652 }
653 result_iterator result_end() { return this->getOperation()->result_end(); }
654 result_range getResults() { return this->getOperation()->getResults(); }
655
656 /// Result type access.
657 result_type_iterator result_type_begin() {
658 return this->getOperation()->result_type_begin();
659 }
660 result_type_iterator result_type_end() {
661 return this->getOperation()->result_type_end();
662 }
663 result_type_range getResultTypes() {
664 return this->getOperation()->getResultTypes();
665 }
666};
667} // namespace detail
668
669/// This class provides return value APIs for ops that are known to have a
670/// single result. ResultType is the concrete type returned by getType().
671template <typename ConcreteType>
672class OneResult : public TraitBase<ConcreteType, OneResult> {
673public:
674 /// Replace all uses of 'this' value with the new value, updating anything
675 /// in the IR that uses 'this' to use the other value instead. When this
676 /// returns there are zero uses of 'this'.
677 void replaceAllUsesWith(Value newValue) {
678 this->getOperation()->getResult(0).replaceAllUsesWith(newValue);
679 }
680
681 /// Replace all uses of 'this' value with the result of 'op'.
682 void replaceAllUsesWith(Operation *op) {
683 this->getOperation()->replaceAllUsesWith(op);
684 }
685
686 static LogicalResult verifyTrait(Operation *op) {
687 return impl::verifyOneResult(op);
688 }
689};
690
691/// This trait is used for return value APIs for ops that are known to have a
692/// specific type other than `Type`. This allows the "getType()" member to be
693/// more specific for an op. This should be used in conjunction with OneResult,
694/// and occur in the trait list before OneResult.
695template <typename ResultType>
696class OneTypedResult {
697public:
698 /// This class provides return value APIs for ops that are known to have a
699 /// single result. ResultType is the concrete type returned by getType().
700 template <typename ConcreteType>
701 class Impl
702 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
703 public:
704 mlir::TypedValue<ResultType> getResult() {
705 return cast<mlir::TypedValue<ResultType>>(
706 this->getOperation()->getResult(0));
707 }
708
709 /// If the operation returns a single value, then the Op can be implicitly
710 /// converted to a Value. This yields the value of the only result.
711 operator mlir::TypedValue<ResultType>() { return getResult(); }
712
713 ResultType getType() { return getResult().getType(); }
714 };
715};
716
717/// This class provides the API for ops that are known to have a specified
718/// number of results. This is used as a trait like this:
719///
720/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
721///
722template <unsigned N>
723class NResults {
724public:
725 static_assert(N > 1, "use ZeroResults/OneResult for N < 2");
726
727 template <typename ConcreteType>
728 class Impl
729 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
730 public:
731 static LogicalResult verifyTrait(Operation *op) {
732 return impl::verifyNResults(op, numOperands: N);
733 }
734 };
735};
736
737/// This class provides the API for ops that are known to have at least a
738/// specified number of results. This is used as a trait like this:
739///
740/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
741///
742template <unsigned N>
743class AtLeastNResults {
744public:
745 template <typename ConcreteType>
746 class Impl : public detail::MultiResultTraitBase<ConcreteType,
747 AtLeastNResults<N>::Impl> {
748 public:
749 static LogicalResult verifyTrait(Operation *op) {
750 return impl::verifyAtLeastNResults(op, numOperands: N);
751 }
752 };
753};
754
755/// This class provides the API for ops which have an unknown number of
756/// results.
757template <typename ConcreteType>
758class VariadicResults
759 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
760
761//===----------------------------------------------------------------------===//
762// Terminator Traits
763//===----------------------------------------------------------------------===//
764
765/// This class indicates that the regions associated with this op don't have
766/// terminators.
767template <typename ConcreteType>
768class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
769
770/// This class provides the API for ops that are known to be terminators.
771template <typename ConcreteType>
772class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
773public:
774 static LogicalResult verifyTrait(Operation *op) {
775 return impl::verifyIsTerminator(op);
776 }
777};
778
779/// This class provides verification for ops that are known to have zero
780/// successors.
781template <typename ConcreteType>
782class ZeroSuccessors : public TraitBase<ConcreteType, ZeroSuccessors> {
783public:
784 static LogicalResult verifyTrait(Operation *op) {
785 return impl::verifyZeroSuccessors(op);
786 }
787};
788
789namespace detail {
790/// Utility trait base that provides accessors for derived traits that have
791/// multiple successors.
792template <typename ConcreteType, template <typename> class TraitType>
793struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
794 using succ_iterator = Operation::succ_iterator;
795 using succ_range = SuccessorRange;
796
797 /// Return the number of successors.
798 unsigned getNumSuccessors() {
799 return this->getOperation()->getNumSuccessors();
800 }
801
802 /// Return the successor at `index`.
803 Block *getSuccessor(unsigned i) {
804 return this->getOperation()->getSuccessor(i);
805 }
806
807 /// Set the successor at `index`.
808 void setSuccessor(Block *block, unsigned i) {
809 return this->getOperation()->setSuccessor(block, i);
810 }
811
812 /// Successor iterator access.
813 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
814 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
815 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
816};
817} // namespace detail
818
819/// This class provides APIs for ops that are known to have a single successor.
820template <typename ConcreteType>
821class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
822public:
823 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
824 void setSuccessor(Block *succ) {
825 this->getOperation()->setSuccessor(succ, 0);
826 }
827
828 static LogicalResult verifyTrait(Operation *op) {
829 return impl::verifyOneSuccessor(op);
830 }
831};
832
833/// This class provides the API for ops that are known to have a specified
834/// number of successors.
835template <unsigned N>
836class NSuccessors {
837public:
838 static_assert(N > 1, "use ZeroSuccessors/OneSuccessor for N < 2");
839
840 template <typename ConcreteType>
841 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
842 NSuccessors<N>::Impl> {
843 public:
844 static LogicalResult verifyTrait(Operation *op) {
845 return impl::verifyNSuccessors(op, numSuccessors: N);
846 }
847 };
848};
849
850/// This class provides APIs for ops that are known to have at least a specified
851/// number of successors.
852template <unsigned N>
853class AtLeastNSuccessors {
854public:
855 template <typename ConcreteType>
856 class Impl
857 : public detail::MultiSuccessorTraitBase<ConcreteType,
858 AtLeastNSuccessors<N>::Impl> {
859 public:
860 static LogicalResult verifyTrait(Operation *op) {
861 return impl::verifyAtLeastNSuccessors(op, numSuccessors: N);
862 }
863 };
864};
865
866/// This class provides the API for ops which have an unknown number of
867/// successors.
868template <typename ConcreteType>
869class VariadicSuccessors
870 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
871};
872
873//===----------------------------------------------------------------------===//
874// SingleBlock
875//===----------------------------------------------------------------------===//
876
877/// This class provides APIs and verifiers for ops with regions having a single
878/// block.
879template <typename ConcreteType>
880struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
881public:
882 static LogicalResult verifyTrait(Operation *op) {
883 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
884 Region &region = op->getRegion(index: i);
885
886 // Empty regions are fine.
887 if (region.empty())
888 continue;
889
890 // Non-empty regions must contain a single basic block.
891 if (!llvm::hasSingleElement(C&: region))
892 return op->emitOpError(message: "expects region #")
893 << i << " to have 0 or 1 blocks";
894
895 if (!ConcreteType::template hasTrait<NoTerminator>()) {
896 Block &block = region.front();
897 if (block.empty())
898 return op->emitOpError() << "expects a non-empty block";
899 }
900 }
901 return success();
902 }
903
904 Block *getBody(unsigned idx = 0) {
905 Region &region = this->getOperation()->getRegion(idx);
906 assert(!region.empty() && "unexpected empty region");
907 return &region.front();
908 }
909 Region &getBodyRegion(unsigned idx = 0) {
910 return this->getOperation()->getRegion(idx);
911 }
912
913 //===------------------------------------------------------------------===//
914 // Single Region Utilities
915 //===------------------------------------------------------------------===//
916
917 /// The following are a set of methods only enabled when the parent
918 /// operation has a single region. Each of these methods take an additional
919 /// template parameter that represents the concrete operation so that we
920 /// can use SFINAE to disable the methods for non-single region operations.
921 template <typename OpT, typename T = void>
922 using enable_if_single_region =
923 std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
924
925 template <typename OpT = ConcreteType>
926 enable_if_single_region<OpT, Block::iterator> begin() {
927 return getBody()->begin();
928 }
929 template <typename OpT = ConcreteType>
930 enable_if_single_region<OpT, Block::iterator> end() {
931 return getBody()->end();
932 }
933 template <typename OpT = ConcreteType>
934 enable_if_single_region<OpT, Operation &> front() {
935 return *begin();
936 }
937
938 /// Insert the operation into the back of the body.
939 template <typename OpT = ConcreteType>
940 enable_if_single_region<OpT> push_back(Operation *op) {
941 insert(Block::iterator(getBody()->end()), op);
942 }
943
944 /// Insert the operation at the given insertion point.
945 template <typename OpT = ConcreteType>
946 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
947 insert(Block::iterator(insertPt), op);
948 }
949 template <typename OpT = ConcreteType>
950 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
951 getBody()->getOperations().insert(insertPt, op);
952 }
953};
954
955//===----------------------------------------------------------------------===//
956// SingleBlockImplicitTerminator
957//===----------------------------------------------------------------------===//
958
959/// This class provides APIs and verifiers for ops with regions having a single
960/// block that must terminate with `TerminatorOpType`.
961template <typename TerminatorOpType>
962struct SingleBlockImplicitTerminator {
963 template <typename ConcreteType>
964 class Impl : public TraitBase<ConcreteType, SingleBlockImplicitTerminator<
965 TerminatorOpType>::Impl> {
966 private:
967 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
968 /// cyclic header inclusion.
969 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
970 OperationState state(loc, TerminatorOpType::getOperationName());
971 TerminatorOpType::build(builder, state);
972 return Operation::create(state);
973 }
974
975 public:
976 /// The type of the operation used as the implicit terminator type.
977 using ImplicitTerminatorOpT = TerminatorOpType;
978
979 static LogicalResult verifyRegionTrait(Operation *op) {
980 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
981 Region &region = op->getRegion(index: i);
982 // Empty regions are fine.
983 if (region.empty())
984 continue;
985 Operation &terminator = region.front().back();
986 if (isa<TerminatorOpType>(terminator))
987 continue;
988
989 return op->emitOpError(message: "expects regions to end with '" +
990 TerminatorOpType::getOperationName() +
991 "', found '" +
992 terminator.getName().getStringRef() + "'")
993 .attachNote()
994 << "in custom textual format, the absence of terminator implies "
995 "'"
996 << TerminatorOpType::getOperationName() << '\'';
997 }
998
999 return success();
1000 }
1001
1002 /// Ensure that the given region has the terminator required by this trait.
1003 /// If OpBuilder is provided, use it to build the terminator and notify the
1004 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
1005 /// construct an OpBuilder with no listeners; this should only be used if no
1006 /// OpBuilder is available at the call site, e.g., in the parser.
1007 static void ensureTerminator(Region &region, Builder &builder,
1008 Location loc) {
1009 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
1010 buildTerminatorOp: buildTerminator);
1011 }
1012 static void ensureTerminator(Region &region, OpBuilder &builder,
1013 Location loc) {
1014 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
1015 buildTerminatorOp: buildTerminator);
1016 }
1017 };
1018};
1019
/// Detector alias naming `T::ImplicitTerminatorOpT`; it is well-formed only if
/// `T` declares that member type. Meant to be used with `llvm::is_detected`.
template <typename T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
1024
1025/// Support to check if an operation has the SingleBlockImplicitTerminator
1026/// trait. We can't just use `hasTrait` because this class is templated on a
1027/// specific terminator op.
1028template <class Op, bool hasTerminator =
1029 llvm::is_detected<has_implicit_terminator_t, Op>::value>
1030struct hasSingleBlockImplicitTerminator {
1031 static constexpr bool value = std::is_base_of<
1032 typename OpTrait::SingleBlockImplicitTerminator<
1033 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
1034 Op>::value;
1035};
1036template <class Op>
1037struct hasSingleBlockImplicitTerminator<Op, false> {
1038 static constexpr bool value = false;
1039};
1040
1041//===----------------------------------------------------------------------===//
1042// Misc Traits
1043//===----------------------------------------------------------------------===//
1044
1045/// This class provides verification for ops that are known to have the same
1046/// operand shape: all operands are scalars, vectors/tensors of the same
1047/// shape.
1048template <typename ConcreteType>
1049class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
1050public:
1051 static LogicalResult verifyTrait(Operation *op) {
1052 return impl::verifySameOperandsShape(op);
1053 }
1054};
1055
1056/// This class provides verification for ops that are known to have the same
1057/// operand and result shape: both are scalars, vectors/tensors of the same
1058/// shape.
1059template <typename ConcreteType>
1060class SameOperandsAndResultShape
1061 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
1062public:
1063 static LogicalResult verifyTrait(Operation *op) {
1064 return impl::verifySameOperandsAndResultShape(op);
1065 }
1066};
1067
1068/// This class provides verification for ops that are known to have the same
1069/// operand element type (or the type itself if it is scalar).
1070///
1071template <typename ConcreteType>
1072class SameOperandsElementType
1073 : public TraitBase<ConcreteType, SameOperandsElementType> {
1074public:
1075 static LogicalResult verifyTrait(Operation *op) {
1076 return impl::verifySameOperandsElementType(op);
1077 }
1078};
1079
1080/// This class provides verification for ops that are known to have the same
1081/// operand and result element type (or the type itself if it is scalar).
1082///
1083template <typename ConcreteType>
1084class SameOperandsAndResultElementType
1085 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1086public:
1087 static LogicalResult verifyTrait(Operation *op) {
1088 return impl::verifySameOperandsAndResultElementType(op);
1089 }
1090};
1091
1092/// This class provides verification for ops that are known to have the same
1093/// operand and result type.
1094///
1095/// Note: this trait subsumes the SameOperandsAndResultShape and
1096/// SameOperandsAndResultElementType traits.
1097template <typename ConcreteType>
1098class SameOperandsAndResultType
1099 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1100public:
1101 static LogicalResult verifyTrait(Operation *op) {
1102 return impl::verifySameOperandsAndResultType(op);
1103 }
1104};
1105
1106/// This class verifies that op has same ranks for all
1107/// operands and results types, if known.
1108template <typename ConcreteType>
1109class SameOperandsAndResultRank
1110 : public TraitBase<ConcreteType, SameOperandsAndResultRank> {
1111public:
1112 static LogicalResult verifyTrait(Operation *op) {
1113 return impl::verifySameOperandsAndResultRank(op);
1114 }
1115};
1116
1117/// This class verifies that any results of the specified op have a boolean
1118/// type, a vector thereof, or a tensor thereof.
1119template <typename ConcreteType>
1120class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1121public:
1122 static LogicalResult verifyTrait(Operation *op) {
1123 return impl::verifyResultsAreBoolLike(op);
1124 }
1125};
1126
1127/// This class verifies that any results of the specified op have a floating
1128/// point type, a vector thereof, or a tensor thereof.
1129template <typename ConcreteType>
1130class ResultsAreFloatLike
1131 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1132public:
1133 static LogicalResult verifyTrait(Operation *op) {
1134 return impl::verifyResultsAreFloatLike(op);
1135 }
1136};
1137
1138/// This class verifies that any results of the specified op have a signless
1139/// integer or index type, a vector thereof, or a tensor thereof.
1140template <typename ConcreteType>
1141class ResultsAreSignlessIntegerLike
1142 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1143public:
1144 static LogicalResult verifyTrait(Operation *op) {
1145 return impl::verifyResultsAreSignlessIntegerLike(op);
1146 }
1147};
1148
1149/// This class adds property that the operation is commutative.
1150template <typename ConcreteType>
1151class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
1152public:
1153 static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
1154 SmallVectorImpl<OpFoldResult> &results) {
1155 return impl::foldCommutative(op, operands, results);
1156 }
1157};
1158
1159/// This class adds property that the operation is an involution.
1160/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1161template <typename ConcreteType>
1162class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1163public:
1164 static LogicalResult verifyTrait(Operation *op) {
1165 static_assert(ConcreteType::template hasTrait<OneResult>(),
1166 "expected operation to produce one result");
1167 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1168 "expected operation to take one operand");
1169 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1170 "expected operation to preserve type");
1171 // Involution requires the operation to be side effect free as well
1172 // but currently this check is under a FIXME and is not actually done.
1173 return impl::verifyIsInvolution(op);
1174 }
1175
1176 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1177 return impl::foldInvolution(op);
1178 }
1179};
1180
1181/// This class adds property that the operation is idempotent.
1182/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
1183/// or a binary operation "g" that satisfies g(x, x) = x.
1184template <typename ConcreteType>
1185class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1186public:
1187 static LogicalResult verifyTrait(Operation *op) {
1188 static_assert(ConcreteType::template hasTrait<OneResult>(),
1189 "expected operation to produce one result");
1190 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1191 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1192 "expected operation to take one or two operands");
1193 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1194 "expected operation to preserve type");
1195 // Idempotent requires the operation to be side effect free as well
1196 // but currently this check is under a FIXME and is not actually done.
1197 return impl::verifyIsIdempotent(op);
1198 }
1199
1200 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1201 return impl::foldIdempotent(op);
1202 }
1203};
1204
1205/// This class verifies that all operands of the specified op have a float type,
1206/// a vector thereof, or a tensor thereof.
1207template <typename ConcreteType>
1208class OperandsAreFloatLike
1209 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1210public:
1211 static LogicalResult verifyTrait(Operation *op) {
1212 return impl::verifyOperandsAreFloatLike(op);
1213 }
1214};
1215
1216/// This class verifies that all operands of the specified op have a signless
1217/// integer or index type, a vector thereof, or a tensor thereof.
1218template <typename ConcreteType>
1219class OperandsAreSignlessIntegerLike
1220 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1221public:
1222 static LogicalResult verifyTrait(Operation *op) {
1223 return impl::verifyOperandsAreSignlessIntegerLike(op);
1224 }
1225};
1226
1227/// This class verifies that all operands of the specified op have the same
1228/// type.
1229template <typename ConcreteType>
1230class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1231public:
1232 static LogicalResult verifyTrait(Operation *op) {
1233 return impl::verifySameTypeOperands(op);
1234 }
1235};
1236
1237/// This class provides the API for a sub-set of ops that are known to be
1238/// constant-like. These are non-side effecting operations with one result and
1239/// zero operands that can always be folded to a specific attribute value.
1240template <typename ConcreteType>
1241class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1242public:
1243 static LogicalResult verifyTrait(Operation *op) {
1244 static_assert(ConcreteType::template hasTrait<OneResult>(),
1245 "expected operation to produce one result");
1246 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1247 "expected operation to take zero operands");
1248 // TODO: We should verify that the operation can always be folded, but this
1249 // requires that the attributes of the op already be verified. We should add
1250 // support for verifying traits "after" the operation to enable this use
1251 // case.
1252 return success();
1253 }
1254};
1255
1256/// This class provides the API for ops that are known to be isolated from
1257/// above.
1258template <typename ConcreteType>
1259class IsIsolatedFromAbove
1260 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1261public:
1262 static LogicalResult verifyRegionTrait(Operation *op) {
1263 return impl::verifyIsIsolatedFromAbove(op);
1264 }
1265};
1266
1267/// A trait of region holding operations that defines a new scope for polyhedral
1268/// optimization purposes. Any SSA values of 'index' type that either dominate
1269/// such an operation or are used at the top-level of such an operation
1270/// automatically become valid symbols for the polyhedral scope defined by that
1271/// operation. For more details, see `Traits.md#AffineScope`.
1272template <typename ConcreteType>
1273class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1274public:
1275 static LogicalResult verifyTrait(Operation *op) {
1276 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
1277 "expected operation to have one or more regions");
1278 return success();
1279 }
1280};
1281
1282/// A trait of region holding operations that define a new scope for automatic
1283/// allocations, i.e., allocations that are freed when control is transferred
1284/// back from the operation's region. Any operations performing such allocations
1285/// (for eg. memref.alloca) will have their allocations automatically freed at
1286/// their closest enclosing operation with this trait.
1287template <typename ConcreteType>
1288class AutomaticAllocationScope
1289 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1290public:
1291 static LogicalResult verifyTrait(Operation *op) {
1292 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
1293 "expected operation to have one or more regions");
1294 return success();
1295 }
1296};
1297
1298/// This class provides a verifier for ops that are expecting their parent
1299/// to be one of the given parent ops
1300template <typename... ParentOpTypes>
1301struct HasParent {
1302 template <typename ConcreteType>
1303 class Impl : public TraitBase<ConcreteType, Impl> {
1304 public:
1305 static LogicalResult verifyTrait(Operation *op) {
1306 if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
1307 return success();
1308
1309 return op->emitOpError()
1310 << "expects parent op "
1311 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1312 << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
1313 }
1314
1315 template <typename ParentOpType =
1316 std::tuple_element_t<0, std::tuple<ParentOpTypes...>>>
1317 std::enable_if_t<sizeof...(ParentOpTypes) == 1, ParentOpType>
1318 getParentOp() {
1319 Operation *parent = this->getOperation()->getParentOp();
1320 return llvm::cast<ParentOpType>(parent);
1321 }
1322 };
1323};
1324
1325/// A trait for operations that have an attribute specifying operand segments.
1326///
1327/// Certain operations can have multiple variadic operands and their size
1328/// relationship is not always known statically. For such cases, we need
1329/// a per-op-instance specification to divide the operands into logical groups
1330/// or segments. This can be modeled by attributes. The attribute will be named
1331/// as `operandSegmentSizes`.
1332///
1333/// This trait verifies the attribute for specifying operand segments has
1334/// the correct type (1D vector) and values (non-negative), etc.
1335template <typename ConcreteType>
1336class AttrSizedOperandSegments
1337 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1338public:
1339 static StringRef getOperandSegmentSizeAttr() { return "operandSegmentSizes"; }
1340
1341 static LogicalResult verifyTrait(Operation *op) {
1342 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1343 op, sizeAttrName: getOperandSegmentSizeAttr());
1344 }
1345};
1346
1347/// Similar to AttrSizedOperandSegments but used for results.
1348template <typename ConcreteType>
1349class AttrSizedResultSegments
1350 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1351public:
1352 static StringRef getResultSegmentSizeAttr() { return "resultSegmentSizes"; }
1353
1354 static LogicalResult verifyTrait(Operation *op) {
1355 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1356 op, sizeAttrName: getResultSegmentSizeAttr());
1357 }
1358};
1359
1360/// This trait provides a verifier for ops that are expecting their regions to
1361/// not have any arguments
1362template <typename ConcrentType>
1363struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
1364 static LogicalResult verifyTrait(Operation *op) {
1365 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1366 }
1367};
1368
1369// This trait is used to flag operations that consume or produce
1370// values of `MemRef` type where those references can be 'normalized'.
1371// TODO: Right now, the operands of an operation are either all normalizable,
1372// or not. In the future, we may want to allow some of the operands to be
1373// normalizable.
1374template <typename ConcrentType>
1375struct MemRefsNormalizable
1376 : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1377
1378/// This trait tags element-wise ops on vectors or tensors.
1379///
1380/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1381/// trait. In particular, broadcasting behavior is not allowed.
1382///
1383/// An `Elementwise` op must satisfy the following properties:
1384///
1385/// 1. If any result is a vector/tensor then at least one operand must also be a
1386/// vector/tensor.
1387/// 2. If any operand is a vector/tensor then there must be at least one result
1388/// and all results must be vectors/tensors.
1389/// 3. All operand and result vector/tensor types must be of the same shape. The
1390/// shape may be dynamic in which case the op's behaviour is undefined for
1391/// non-matching shapes.
1392/// 4. The operation must be elementwise on its vector/tensor operands and
1393/// results. When applied to single-element vectors/tensors, the result must
1394/// be the same per elememnt.
1395///
1396/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1397/// interface `ElementwiseTypeInterface` that describes the container types for
1398/// which the operation is elementwise.
1399///
1400/// Rationale:
1401/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1402/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1403/// generic definition of the iteration space.
1404/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1405/// the same pattern, as otherwise lots of special handling for type
1406/// mismatches would be needed.
1407/// - 4. guarantees that no error handling is needed. Higher-level dialects
1408/// should reify any needed guards or error handling code before lowering to
1409/// an `Elementwise` op.
1410template <typename ConcreteType>
1411struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1412 static LogicalResult verifyTrait(Operation *op) {
1413 return ::mlir::OpTrait::impl::verifyElementwise(op);
1414 }
1415};
1416
1417/// This trait tags `Elementwise` operatons that can be systematically
1418/// scalarized. All vector/tensor operands and results are then replaced by
1419/// scalars of the respective element type. Semantically, this is the operation
1420/// on a single element of the vector/tensor.
1421///
1422/// Rationale:
1423/// Allow to define the vector/tensor semantics of elementwise operations based
1424/// on the same op's behavior on scalars. This provides a constructive procedure
1425/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1426///
1427/// Example:
1428/// ```
1429/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
1430/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1431/// -> tensor<?xf32>
1432/// ```
1433/// can be scalarized to
1434///
1435/// ```
1436/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
1437/// : (i1, f32, f32) -> f32
1438/// ```
1439template <typename ConcreteType>
1440struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1441 static LogicalResult verifyTrait(Operation *op) {
1442 static_assert(
1443 ConcreteType::template hasTrait<Elementwise>(),
1444 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1445 return success();
1446 }
1447};
1448
1449/// This trait tags `Elementwise` operatons that can be systematically
1450/// vectorized. All scalar operands and results are then replaced by vectors
1451/// with the respective element type. Semantically, this is the operation on
1452/// multiple elements simultaneously. See also `Tensorizable`.
1453///
1454/// Rationale:
1455/// Provide the reverse to `Scalarizable` which, when chained together, allows
1456/// reasoning about the relationship between the tensor and vector case.
1457/// Additionally, it permits reasoning about promoting scalars to vectors via
1458/// broadcasting in cases like `%select_scalar_pred` below.
1459template <typename ConcreteType>
1460struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1461 static LogicalResult verifyTrait(Operation *op) {
1462 static_assert(
1463 ConcreteType::template hasTrait<Elementwise>(),
1464 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1465 return success();
1466 }
1467};
1468
1469/// This trait tags `Elementwise` operatons that can be systematically
1470/// tensorized. All scalar operands and results are then replaced by tensors
1471/// with the respective element type. Semantically, this is the operation on
1472/// multiple elements simultaneously. See also `Vectorizable`.
1473///
1474/// Rationale:
1475/// Provide the reverse to `Scalarizable` which, when chained together, allows
1476/// reasoning about the relationship between the tensor and vector case.
1477/// Additionally, it permits reasoning about promoting scalars to tensors via
1478/// broadcasting in cases like `%select_scalar_pred` below.
1479///
1480/// Examples:
1481/// ```
1482/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1483/// ```
1484/// can be tensorized to
1485/// ```
1486/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1487/// -> tensor<?xf32>
1488/// ```
1489///
1490/// ```
1491/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
1492/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1493/// ```
1494/// can be tensorized to
1495/// ```
1496/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
1497/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1498/// -> tensor<?xf32>
1499/// ```
1500template <typename ConcreteType>
1501struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1502 static LogicalResult verifyTrait(Operation *op) {
1503 static_assert(
1504 ConcreteType::template hasTrait<Elementwise>(),
1505 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1506 return success();
1507 }
1508};
1509
1510/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
1511/// provide an easy way for scalar operations to conveniently generalize their
1512/// behavior to vectors/tensors, and systematize conversion between these forms.
1513bool hasElementwiseMappableTraits(Operation *op);
1514
1515} // namespace OpTrait
1516
1517//===----------------------------------------------------------------------===//
1518// Internal Trait Utilities
1519//===----------------------------------------------------------------------===//
1520
1521namespace op_definition_impl {
1522//===----------------------------------------------------------------------===//
1523// Trait Existence
1524//===----------------------------------------------------------------------===//
1525
1526/// Returns true if this given Trait ID matches the IDs of any of the provided
1527/// trait types `Traits`.
1528template <template <typename T> class... Traits>
1529inline bool hasTrait(TypeID traitID) {
1530 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1531 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1532 if (traitIDs[i] == traitID)
1533 return true;
1534 return false;
1535}
1536template <>
1537inline bool hasTrait<>(TypeID traitID) {
1538 return false;
1539}
1540
1541//===----------------------------------------------------------------------===//
1542// Trait Folding
1543//===----------------------------------------------------------------------===//
1544
1545/// Trait to check if T provides a 'foldTrait' method for single result
1546/// operations.
1547template <typename T, typename... Args>
1548using has_single_result_fold_trait = decltype(T::foldTrait(
1549 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1550template <typename T>
1551using detect_has_single_result_fold_trait =
1552 llvm::is_detected<has_single_result_fold_trait, T>;
1553/// Trait to check if T provides a general 'foldTrait' method.
1554template <typename T, typename... Args>
1555using has_fold_trait =
1556 decltype(T::foldTrait(std::declval<Operation *>(),
1557 std::declval<ArrayRef<Attribute>>(),
1558 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1559template <typename T>
1560using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1561/// Trait to check if T provides any `foldTrait` method.
1562template <typename T>
1563using detect_has_any_fold_trait =
1564 std::disjunction<detect_has_fold_trait<T>,
1565 detect_has_single_result_fold_trait<T>>;
1566
1567/// Returns the result of folding a trait that implements a `foldTrait` function
1568/// that is specialized for operations that have a single result.
1569template <typename Trait>
1570static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1571 LogicalResult>
1572foldTrait(Operation *op, ArrayRef<Attribute> operands,
1573 SmallVectorImpl<OpFoldResult> &results) {
1574 assert(op->hasTrait<OpTrait::OneResult>() &&
1575 "expected trait on non single-result operation to implement the "
1576 "general `foldTrait` method");
1577 // If a previous trait has already been folded and replaced this operation, we
1578 // fail to fold this trait.
1579 if (!results.empty())
1580 return failure();
1581
1582 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1583 if (llvm::dyn_cast_if_present<Value>(Val&: result) != op->getResult(idx: 0))
1584 results.push_back(Elt: result);
1585 return success();
1586 }
1587 return failure();
1588}
1589/// Returns the result of folding a trait that implements a generalized
1590/// `foldTrait` function that is supports any operation type.
1591template <typename Trait>
1592static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1593foldTrait(Operation *op, ArrayRef<Attribute> operands,
1594 SmallVectorImpl<OpFoldResult> &results) {
1595 // If a previous trait has already been folded and replaced this operation, we
1596 // fail to fold this trait.
1597 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1598}
1599template <typename Trait>
1600static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
1601 LogicalResult>
1602foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
1603 return failure();
1604}
1605
1606/// Given a tuple type containing a set of traits, return the result of folding
1607/// the given operation.
1608template <typename... Ts>
1609static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1610 SmallVectorImpl<OpFoldResult> &results) {
1611 return success((succeeded(foldTrait<Ts>(op, operands, results)) || ...));
1612}
1613
1614//===----------------------------------------------------------------------===//
1615// Trait Verification
1616//===----------------------------------------------------------------------===//
1617
/// Trait to check if T provides a `verifyTrait` method, i.e. a static
/// `T::verifyTrait(Operation *)`. The trailing `Args` pack is unused.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
/// Exposes `::value`, which is true when `has_verify_trait<T>` is well-formed.
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;

/// Trait to check if T provides a `verifyRegionTrait` method, i.e. a static
/// `T::verifyRegionTrait(Operation *)`. The trailing `Args` pack is unused.
template <typename T, typename... Args>
using has_verify_region_trait =
    decltype(T::verifyRegionTrait(std::declval<Operation *>()));
/// Exposes `::value`, which is true when `has_verify_region_trait<T>` is
/// well-formed.
template <typename T>
using detect_has_verify_region_trait =
    llvm::is_detected<has_verify_region_trait, T>;
1631
1632/// Verify the given trait if it provides a verifier.
1633template <typename T>
1634LogicalResult verifyTrait(Operation *op) {
1635 if constexpr (detect_has_verify_trait<T>::value)
1636 return T::verifyTrait(op);
1637 else
1638 return success();
1639}
1640
1641/// Given a set of traits, return the result of verifying the given operation.
1642template <typename... Ts>
1643LogicalResult verifyTraits(Operation *op) {
1644 return success((succeeded(verifyTrait<Ts>(op)) && ...));
1645}
1646
1647/// Verify the given trait if it provides a region verifier.
1648template <typename T>
1649LogicalResult verifyRegionTrait(Operation *op) {
1650 if constexpr (detect_has_verify_region_trait<T>::value)
1651 return T::verifyRegionTrait(op);
1652 else
1653 return success();
1654}
1655
1656/// Given a set of traits, return the result of verifying the regions of the
1657/// given operation.
1658template <typename... Ts>
1659LogicalResult verifyRegionTraits(Operation *op) {
1660 return success((succeeded(verifyRegionTrait<Ts>(op)) && ...));
1661}
1662} // namespace op_definition_impl
1663
1664//===----------------------------------------------------------------------===//
1665// Operation Definition classes
1666//===----------------------------------------------------------------------===//
1667
/// This provides public APIs that all operations should have. The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
///
/// Each `Traits<ConcreteType>` is inherited as a base class. The private
/// `get*Fn` hooks below expose fold/print/verify implementations to
/// `RegisteredOperationName`, which is declared a friend at the end of the
/// class.
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState, public Traits<ConcreteType>... {
public:
  /// Inherit `getOperation`, `verify` and `verifyRegions` from `OpState`.
  using OpState::getOperation;
  using OpState::verify;
  using OpState::verifyRegions;

  /// Return if this operation contains the provided trait, i.e. whether
  /// `Trait<ConcreteType>` is one of `Traits<ConcreteType>...`. This is
  /// evaluated at compile time.
  template <template <typename T> class Trait>
  static constexpr bool hasTrait() {
    return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
  }

  /// Create a deep copy of this operation.
  ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }

  /// Create a partial copy of this operation without traversing into attached
  /// regions. The new operation will have the same number of regions as the
  /// original one, but they will be left empty.
  ConcreteType cloneWithoutRegions() {
    return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
  }

  /// Return true if this "op class" can match against the specified operation.
  /// Registered operations are matched by comparing TypeIDs. In builds without
  /// NDEBUG, an unregistered operation that carries this op's name is reported
  /// as a fatal error instead of silently returning false.
  static bool classof(Operation *op) {
    if (auto info = op->getRegisteredInfo())
      return TypeID::get<ConcreteType>() == info->getTypeID();
#ifndef NDEBUG
    if (op->getName().getStringRef() == ConcreteType::getOperationName())
      llvm::report_fatal_error(
          "classof on '" + ConcreteType::getOperationName() +
          "' failed due to the operation not being registered");
#endif
    return false;
  }
  /// Provide `classof` support for other OpState-derived classes, such as
  /// Interfaces, by forwarding to the `Operation *` overload.
  template <typename T>
  static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
  classof(const T *op) {
    return classof(const_cast<T *>(op)->getOperation());
  }

  /// Expose the type we are instantiated on to template machinery that may want
  /// to introspect traits on this operation.
  using ConcreteOpType = ConcreteType;

  /// This is a public constructor. Any op can be initialized to null.
  explicit Op() : OpState(nullptr) {}
  Op(std::nullptr_t) : OpState(nullptr) {}

  /// This is a public constructor to enable access via the llvm::cast family of
  /// methods. This should not be used directly.
  explicit Op(Operation *state) : OpState(state) {}

  /// Methods for supporting PointerLikeTypeTraits. The opaque pointer is the
  /// underlying `Operation *`.
  const void *getAsOpaquePointer() const {
    return static_cast<const void *>((Operation *)*this);
  }
  static ConcreteOpType getFromOpaquePointer(const void *pointer) {
    return ConcreteOpType(
        reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
  }

  /// Attach the given models as implementations of the corresponding
  /// interfaces for the concrete operation.
  ///
  /// Reports a fatal error if the operation is not registered in `context`.
  /// Each model is statically checked by `checkInterfaceTarget` to target this
  /// op class (or to be a fallback model).
  template <typename... Models>
  static void attachInterface(MLIRContext &context) {
    std::optional<RegisteredOperationName> info =
        RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
    if (!info)
      llvm::report_fatal_error(
          "Attempting to attach an interface to an unregistered operation " +
          ConcreteType::getOperationName() + ".");
    (checkInterfaceTarget<Models>(), ...);
    info->attachInterface<Models...>();
  }
  /// Convert the provided attribute to a property and assign it to the
  /// provided properties. This default implementation forwards to a free
  /// function `setPropertiesFromAttribute` that can be looked up with ADL in
  /// the namespace where the properties are defined. It can also be overridden
  /// in the derived ConcreteOp.
  template <typename PropertiesTy>
  static LogicalResult
  setPropertiesFromAttr(PropertiesTy &prop, Attribute attr,
                        function_ref<InFlightDiagnostic()> emitError) {
    return setPropertiesFromAttribute(prop, attr, emitError);
  }
  /// Convert the provided properties to an attribute. This default
  /// implementation forwards to a free function `getPropertiesAsAttribute` that
  /// can be looked up with ADL in the namespace where the properties are
  /// defined. It can also be overridden in the derived ConcreteOp.
  template <typename PropertiesTy>
  static Attribute getPropertiesAsAttr(MLIRContext *ctx,
                                       const PropertiesTy &prop) {
    return getPropertiesAsAttribute(ctx, prop);
  }
  /// Hash the provided properties. This default implementation forwards to a
  /// free function `computeHash` that can be looked up with ADL in the
  /// namespace where the properties are defined. It can also be overridden in
  /// the derived ConcreteOp.
  template <typename PropertiesTy>
  static llvm::hash_code computePropertiesHash(const PropertiesTy &prop) {
    return computeHash(prop);
  }

private:
  /// Trait to check if T provides a 'fold' method for a single result op.
  template <typename T, typename... Args>
  using has_single_result_fold_t =
      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
  template <typename T>
  constexpr static bool has_single_result_fold_v =
      llvm::is_detected<has_single_result_fold_t, T>::value;
  /// Trait to check if T provides a general 'fold' method.
  template <typename T, typename... Args>
  using has_fold_t = decltype(std::declval<T>().fold(
      std::declval<ArrayRef<Attribute>>(),
      std::declval<SmallVectorImpl<OpFoldResult> &>()));
  template <typename T>
  constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
  /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
  /// single result op.
  template <typename T, typename... Args>
  using has_fold_adaptor_single_result_fold_t =
      decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
  template <class T>
  constexpr static bool has_fold_adaptor_single_result_v =
      llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
  /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
  template <typename T, typename... Args>
  using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
      std::declval<typename T::FoldAdaptor>(),
      std::declval<SmallVectorImpl<OpFoldResult> &>()));
  template <class T>
  constexpr static bool has_fold_adaptor_v =
      llvm::is_detected<has_fold_adaptor_fold_t, T>::value;

  /// Trait to check if T provides a 'print' method.
  template <typename T, typename... Args>
  using has_print =
      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
  template <typename T>
  using detect_has_print = llvm::is_detected<has_print, T>;

  /// Trait to check if printProperties(OpAsmPrinter, T, ArrayRef<StringRef>)
  /// exists. This alias is declared before the static member
  /// `printProperties` below, so that member is not visible to it.
  template <typename T, typename... Args>
  using has_print_properties =
      decltype(printProperties(std::declval<OpAsmPrinter &>(),
                               std::declval<T>(),
                               std::declval<ArrayRef<StringRef>>()));
  template <typename T>
  using detect_has_print_properties =
      llvm::is_detected<has_print_properties, T>;

  /// Trait to check if parseProperties(OpAsmParser, T) exists.
  template <typename T, typename... Args>
  using has_parse_properties = decltype(parseProperties(
      std::declval<OpAsmParser &>(), std::declval<T &>()));
  template <typename T>
  using detect_has_parse_properties =
      llvm::is_detected<has_parse_properties, T>;

  /// Trait to check if T provides a 'ConcreteEntity' type alias.
  template <typename T>
  using has_concrete_entity_t = typename T::ConcreteEntity;

public:
  /// Returns true if this operation defines a `Properties` inner type, i.e. if
  /// the inferred properties type is not `EmptyProperties`.
  static constexpr bool hasProperties() {
    return !std::is_same_v<
        typename ConcreteType::template InferredProperties<ConcreteType>,
        EmptyProperties>;
  }

private:
  /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
  /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail
  /// on the missing typedef. Useful for checking if the interface is targeting
  /// the right class.
  template <typename T,
            bool = llvm::is_detected<has_concrete_entity_t, T>::value>
  struct InterfaceTargetOrOpT {
    using type = typename T::ConcreteEntity;
  };
  template <typename T>
  struct InterfaceTargetOrOpT<T, false> {
    using type = ConcreteType;
  };

  /// A hook for static assertion that the external interface model T is
  /// targeting the concrete type of this op. The model can also be a fallback
  /// model that works for every op.
  template <typename T>
  static void checkInterfaceTarget() {
    static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type,
                               ConcreteType>::value,
                  "attaching an interface to the wrong op kind");
  }

  /// Returns an interface map containing the interfaces registered to this
  /// operation.
  static detail::InterfaceMap getInterfaceMap() {
    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
  }

  /// Return the internal implementations of each of the OperationName
  /// hooks.
  /// Implementation of `FoldHookFn` OperationName hook. Selected at compile
  /// time: a single-result op with a single-result `fold` uses
  /// `foldSingleResultHook`; an op with a general `fold` uses `foldHook`;
  /// otherwise only the traits are folded.
  static OperationName::FoldHookFn getFoldHookFn() {
    // If the operation is single result and defines a `fold` method.
    if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
                                  Traits<ConcreteType>...>::value &&
                  (has_single_result_fold_v<ConcreteType> ||
                   has_fold_adaptor_single_result_v<ConcreteType>))
      return [](Operation *op, ArrayRef<Attribute> operands,
                SmallVectorImpl<OpFoldResult> &results) {
        return foldSingleResultHook<ConcreteType>(op, operands, results);
      };
    // The operation is not single result and defines a `fold` method.
    if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
      return [](Operation *op, ArrayRef<Attribute> operands,
                SmallVectorImpl<OpFoldResult> &results) {
        return foldHook<ConcreteType>(op, operands, results);
      };
    // The operation does not define a `fold` method.
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      // In this case, we only need to fold the traits of the operation.
      return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
          op, operands, results);
    };
  }
  /// Return the result of folding a single result operation that defines a
  /// `fold` method. The FoldAdaptor overload is preferred when present. If the
  /// op's fold returns null or the op's own result (an in-place fold), the
  /// traits are tried next. A successful trait fold wins; otherwise success
  /// is reported only for the in-place case. A non-null, non-in-place result
  /// is appended to `results`.
  template <typename ConcreteOpT>
  static LogicalResult
  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                       SmallVectorImpl<OpFoldResult> &results) {
    OpFoldResult result;
    if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
      result = cast<ConcreteOpT>(op).fold(
          typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)));
    } else {
      result = cast<ConcreteOpT>(op).fold(operands);
    }

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (!result ||
        llvm::dyn_cast_if_present<Value>(result) == op->getResult(0)) {
      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
              op, operands, results)))
        return success();
      return success(static_cast<bool>(result));
    }
    results.push_back(result);
    return success();
  }
  /// Return the result of folding an operation that defines a `fold` method.
  /// The FoldAdaptor overload is preferred when present. If the op's fold
  /// fails or leaves `results` empty (an in-place fold), the traits are tried
  /// next; a successful trait fold returns success, otherwise the op's own
  /// fold result is returned.
  template <typename ConcreteOpT>
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    auto result = LogicalResult::failure();
    if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
      result = cast<ConcreteOpT>(op).fold(
          typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)),
          results);
    } else {
      result = cast<ConcreteOpT>(op).fold(operands, results);
    }

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (failed(result) || results.empty()) {
      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
              op, operands, results)))
        return success();
    }
    return result;
  }

  /// Implementation of `HasTraitFn` OperationName hook.
  static OperationName::HasTraitFn getHasTraitFn() {
    return
        [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
  }
  /// Implementation of `PrintAssemblyFn` OperationName hook. If the concrete
  /// op defines `print(OpAsmPrinter &)`, the op name is printed first and then
  /// the op's `print` is called; otherwise `OpState::print` is used.
  static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
    if constexpr (detect_has_print<ConcreteType>::value)
      return [](Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
        OpState::printOpName(op, p, defaultDialect);
        return cast<ConcreteType>(op).print(p);
      };
    return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
      return OpState::print(op, printer, defaultDialect);
    };
  }

public:
  template <typename T>
  using InferredProperties = typename PropertiesSelector<T>::type;
  /// Return a reference to this operation's properties. When the op defines no
  /// properties, the object returned by `getEmptyProperties()` is returned;
  /// otherwise the operation's property storage is viewed as
  /// `InferredProperties<T>`.
  template <typename T = ConcreteType>
  InferredProperties<T> &getProperties() {
    if constexpr (!hasProperties())
      return getEmptyProperties();
    return *getOperation()
                ->getPropertiesStorageUnsafe()
                .template as<InferredProperties<T> *>();
  }

  /// This hook populates any unset default attrs when mapped to properties.
  /// The default implementation does nothing.
  template <typename T = ConcreteType>
  static void populateDefaultProperties(OperationName opName,
                                        InferredProperties<T> &properties) {}

  /// Print the operation properties with names not included within
  /// 'elidedProps'. Unless overridden, this method will try to dispatch to a
  /// `printProperties` free-function if it exists, and otherwise by converting
  /// the properties to an Attribute.
  template <typename T>
  static void printProperties(MLIRContext *ctx, OpAsmPrinter &p,
                              const T &properties,
                              ArrayRef<StringRef> elidedProps = {}) {
    // NOTE(review): this unqualified call sits in class scope, where ordinary
    // lookup finds this static member `printProperties`. Under the C++ ADL
    // rules that may suppress argument-dependent lookup of the free function,
    // so the 3-argument call could fail to resolve. Confirm with a properties
    // type that actually provides the free function.
    if constexpr (detect_has_print_properties<T>::value)
      return printProperties(p, properties, elidedProps);
    genericPrintProperties(
        p, ConcreteType::getPropertiesAsAttr(ctx, properties), elidedProps);
  }

  /// Parses 'prop-dict' for the operation. Unless overridden, the method will
  /// parse the properties using the generic property dictionary using the
  /// '<{ ... }>' syntax. The resulting properties are stored within the
  /// property structure of 'result', accessible via 'getOrAddProperties'.
  template <typename T = ConcreteType>
  static ParseResult parseProperties(OpAsmParser &parser,
                                     OperationState &result) {
    // NOTE(review): as in `printProperties`, this unqualified call may resolve
    // to this static member (which takes `OperationState &`) rather than to
    // a free `parseProperties` found by ADL. Confirm with a properties type
    // that provides the free function.
    if constexpr (detect_has_parse_properties<InferredProperties<T>>::value) {
      return parseProperties(
          parser, result.getOrAddProperties<InferredProperties<T>>());
    }

    Attribute propertyDictionary;
    if (genericParseProperties(parser, propertyDictionary))
      return failure();

    // The generated 'setPropertiesFromParsedAttr', like
    // 'setPropertiesFromAttr', expects a 'DictionaryAttr' that is not null.
    // Use an empty dictionary in the case that the whole dictionary is
    // optional.
    if (!propertyDictionary)
      propertyDictionary = DictionaryAttr::get(result.getContext());

    auto emitError = [&]() {
      return mlir::emitError(result.location, "invalid properties ")
             << propertyDictionary << " for op " << result.name.getStringRef()
             << ": ";
    };

    // Copy the data from the dictionary attribute into the property struct of
    // the operation. This method is generated by ODS by default if there are
    // any occurrences of 'prop-dict' in the assembly format and should set
    // any properties that aren't parsed elsewhere.
    return ConcreteOpType::setPropertiesFromParsedAttr(
        result.getOrAddProperties<InferredProperties<T>>(), propertyDictionary,
        emitError);
  }

private:
  /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
  static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
    return ConcreteType::populateDefaultAttrs;
  }
  /// Implementation of `VerifyInvariantsFn` OperationName hook. Trait
  /// verifiers run first; the op's own `verify()` runs only if they all
  /// succeed (the `||` short-circuits).
  static LogicalResult verifyInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    return failure(
        failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
        failed(cast<ConcreteType>(op).verify()));
  }
  static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
    return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
  }
  /// Implementation of `VerifyRegionInvariantsFn` OperationName hook. Trait
  /// region verifiers run first; the op's own `verifyRegions()` runs only if
  /// they all succeed (the `||` short-circuits).
  static LogicalResult verifyRegionInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    return failure(
        failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
            op)) ||
        failed(cast<ConcreteType>(op).verifyRegions()));
  }
  static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
    return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
  }

  /// Compile-time check that `ConcreteType` adds no data members on top of
  /// `Op<..., Traits...>`.
  static constexpr bool hasNoDataMembers() {
    // Checking that the derived class does not define any member by comparing
    // its size to an ad-hoc EmptyOp.
    class EmptyOp : public Op<EmptyOp, Traits...> {};
    return sizeof(ConcreteType) == sizeof(EmptyOp);
  }

  /// Allow access to internal implementation methods.
  friend RegisteredOperationName;
};
2080
2081/// This class represents the base of an operation interface. See the definition
2082/// of `detail::Interface` for requirements on the `Traits` type.
2083template <typename ConcreteType, typename Traits>
2084class OpInterface
2085 : public detail::Interface<ConcreteType, Operation *, Traits,
2086 Op<ConcreteType>, OpTrait::TraitBase> {
2087public:
2088 using Base = OpInterface<ConcreteType, Traits>;
2089 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
2090 Op<ConcreteType>, OpTrait::TraitBase>;
2091
2092 /// Inherit the base class constructor.
2093 using InterfaceBase::InterfaceBase;
2094
2095protected:
2096 /// Returns the impl interface instance for the given operation.
2097 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
2098 OperationName name = op->getName();
2099
2100#ifndef NDEBUG
2101 // Check that the current interface isn't an unresolved promise for the
2102 // given operation.
2103 if (Dialect *dialect = name.getDialect()) {
2104 dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
2105 dialect&: *dialect, interfaceRequestorID: name.getTypeID(), interfaceID: ConcreteType::getInterfaceID(),
2106 interfaceName: llvm::getTypeName<ConcreteType>());
2107 }
2108#endif
2109
2110 // Access the raw interface from the operation info.
2111 if (std::optional<RegisteredOperationName> rInfo =
2112 name.getRegisteredInfo()) {
2113 if (auto *opIface = rInfo->getInterface<ConcreteType>())
2114 return opIface;
2115 // Fallback to the dialect to provide it with a chance to implement this
2116 // interface for this operation.
2117 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
2118 op->getName());
2119 }
2120 // Fallback to the dialect to provide it with a chance to implement this
2121 // interface for this operation.
2122 if (Dialect *dialect = name.getDialect())
2123 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
2124 return nullptr;
2125 }
2126
2127 /// Allow access to `getInterfaceFor`.
2128 friend InterfaceBase;
2129};
2130
2131} // namespace mlir
2132
2133namespace llvm {
2134
2135template <typename T>
2136struct DenseMapInfo<T,
2137 std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&
2138 !mlir::detail::IsInterface<T>::value>> {
2139 static inline T getEmptyKey() {
2140 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2141 return T::getFromOpaquePointer(pointer);
2142 }
2143 static inline T getTombstoneKey() {
2144 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2145 return T::getFromOpaquePointer(pointer);
2146 }
2147 static unsigned getHashValue(T val) {
2148 return hash_value(val.getAsOpaquePointer());
2149 }
2150 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
2151};
2152} // namespace llvm
2153
2154#endif
2155

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/include/mlir/IR/OpDefinition.h