1 | //===- EndomorphismSimplification.h -----------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_ |
10 | #define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_ |
11 | |
12 | #include "mlir/Transforms/HomomorphismSimplification.h" |
13 | |
14 | namespace mlir { |
15 | |
16 | namespace detail { |
17 | struct CreateAlgebraicOpForEndomorphismSimplification { |
18 | Operation *operator()(Operation *op, IRMapping &operandsRemapping, |
19 | PatternRewriter &rewriter) const { |
20 | return rewriter.clone(op&: *op, mapper&: operandsRemapping); |
21 | } |
22 | }; |
23 | } // namespace detail |
24 | |
25 | // If `f` is an endomorphism with respect to the algebraic structure induced by |
26 | // function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into |
27 | // `f(g(x1, x2, ..., xn))`. |
28 | // `g` is the algebraic operation and `f` is the endomorphism. |
29 | // |
30 | // Functors: |
31 | // --------- |
32 | // `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*` |
33 | // Returns the operand relevant to the endomorphism. |
34 | // There may be other operands that are not relevant. |
35 | // |
36 | // `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult` |
37 | // Returns the result relevant to the endomorphism. |
38 | // |
39 | // `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void` |
40 | // Populates into the vector the operands relevant to the endomorphism. |
41 | // |
42 | // `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult` |
43 | // Return the result relevant to the endomorphism. |
44 | // |
45 | // `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool` |
46 | // Check if the operation is an endomorphism of the required type. |
47 | // Additionally if the optional is present checks if the operations are |
48 | // compatible endomorphisms. |
49 | // |
50 | // `IsAlgebraicOpFn`: `(Operation*) -> bool` |
51 | // Check if the operation is an operation of the algebraic structure. |
52 | template <typename GetEndomorphismOpOperandFn, |
53 | typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn, |
54 | typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn, |
55 | typename IsAlgebraicOpFn> |
56 | struct EndomorphismSimplification |
57 | : HomomorphismSimplification< |
58 | GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn, |
59 | GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn, |
60 | GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn, |
61 | detail::CreateAlgebraicOpForEndomorphismSimplification> { |
62 | template <typename GetEndomorphismOpOperandFnArg, |
63 | typename GetEndomorphismOpResultFnArg, |
64 | typename GetAlgebraicOpOperandsFnArg, |
65 | typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg, |
66 | typename IsAlgebraicOpFnArg, typename... RewritePatternArgs> |
67 | EndomorphismSimplification( |
68 | GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand, |
69 | GetEndomorphismOpResultFnArg &&getEndomorphismOpResult, |
70 | GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands, |
71 | GetAlgebraicOpResultFnArg &&getAlgebraicOpResult, |
72 | IsEndomorphismOpFnArg &&isEndomorphismOp, |
73 | IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args) |
74 | : HomomorphismSimplification< |
75 | GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn, |
76 | GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn, |
77 | GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn, |
78 | detail::CreateAlgebraicOpForEndomorphismSimplification>( |
79 | std::forward<GetEndomorphismOpOperandFnArg>( |
80 | getEndomorphismOpOperand), |
81 | std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult), |
82 | std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands), |
83 | std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult), |
84 | std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult), |
85 | std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp), |
86 | std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp), |
87 | detail::CreateAlgebraicOpForEndomorphismSimplification(), |
88 | std::forward<RewritePatternArgs>(args)...) {} |
89 | }; |
90 | |
91 | } // namespace mlir |
92 | |
93 | #endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_ |
94 | |