1 | //===- HomomorphismSimplification.h -----------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ |
10 | #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ |
11 | |
12 | #include "mlir/IR/IRMapping.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | #include "mlir/IR/Value.h" |
15 | #include "mlir/Support/LLVM.h" |
16 | #include "mlir/Support/LogicalResult.h" |
17 | #include "llvm/ADT/SmallVector.h" |
18 | #include "llvm/Support/Casting.h" |
19 | #include <iterator> |
20 | #include <optional> |
21 | #include <type_traits> |
22 | #include <utility> |
23 | |
24 | namespace mlir { |
25 | |
26 | // If `h` is an homomorphism with respect to the source algebraic structure |
27 | // induced by function `s` and the target algebraic structure induced by |
28 | // function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into |
29 | // `h(t(x1, x2, ..., xn))`. |
30 | // |
31 | // Functors: |
32 | // --------- |
33 | // `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*` |
34 | // Returns the operand relevant to the homomorphism. |
35 | // There may be other operands that are not relevant. |
36 | // |
37 | // `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult` |
38 | // Returns the result relevant to the homomorphism. |
39 | // |
40 | // `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> |
41 | // void` Populates into the vector the operands relevant to the homomorphism. |
42 | // |
43 | // `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult` |
44 | // Return the result of the source algebraic operation relevant to the |
45 | // homomorphism. |
46 | // |
47 | // `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult` |
48 | // Return the result of the target algebraic operation relevant to the |
49 | // homomorphism. |
50 | // |
51 | // `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool` |
52 | // Check if the operation is an homomorphism of the required type. |
53 | // Additionally if the optional is present checks if the operations are |
54 | // compatible homomorphisms. |
55 | // |
56 | // `IsSourceAlgebraicOpFn`: `(Operation*) -> bool` |
57 | // Check if the operation is an operation of the algebraic structure. |
58 | // |
59 | // `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping, |
60 | // PatternRewriter &rewriter) -> Operation*` |
61 | template <typename GetHomomorphismOpOperandFn, |
62 | typename GetHomomorphismOpResultFn, |
63 | typename GetSourceAlgebraicOpOperandsFn, |
64 | typename GetSourceAlgebraicOpResultFn, |
65 | typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn, |
66 | typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn> |
67 | struct HomomorphismSimplification : public RewritePattern { |
68 | template <typename GetHomomorphismOpOperandFnArg, |
69 | typename GetHomomorphismOpResultFnArg, |
70 | typename GetSourceAlgebraicOpOperandsFnArg, |
71 | typename GetSourceAlgebraicOpResultFnArg, |
72 | typename GetTargetAlgebraicOpResultFnArg, |
73 | typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg, |
74 | typename CreateTargetAlgebraicOpFnArg, |
75 | typename... RewritePatternArgs> |
76 | HomomorphismSimplification( |
77 | GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, |
78 | GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, |
79 | GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, |
80 | GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, |
81 | GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, |
82 | IsHomomorphismOpFnArg &&isHomomorphismOp, |
83 | IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, |
84 | CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, |
85 | RewritePatternArgs &&...args) |
86 | : RewritePattern(std::forward<RewritePatternArgs>(args)...), |
87 | getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>( |
88 | getHomomorphismOpOperand)), |
89 | getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>( |
90 | getHomomorphismOpResult)), |
91 | getSourceAlgebraicOpOperands( |
92 | std::forward<GetSourceAlgebraicOpOperandsFnArg>( |
93 | getSourceAlgebraicOpOperands)), |
94 | getSourceAlgebraicOpResult( |
95 | std::forward<GetSourceAlgebraicOpResultFnArg>( |
96 | getSourceAlgebraicOpResult)), |
97 | getTargetAlgebraicOpResult( |
98 | std::forward<GetTargetAlgebraicOpResultFnArg>( |
99 | getTargetAlgebraicOpResult)), |
100 | isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)), |
101 | isSourceAlgebraicOp( |
102 | std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)), |
103 | createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>( |
104 | createTargetAlgebraicOpFn)) {} |
105 | |
106 | LogicalResult matchAndRewrite(Operation *op, |
107 | PatternRewriter &rewriter) const override { |
108 | SmallVector<OpOperand *> algebraicOpOperands; |
109 | if (failed(matchOp(sourceAlgebraicOp: op, sourceAlgebraicOpOperands&: algebraicOpOperands))) { |
110 | return failure(); |
111 | } |
112 | return rewriteOp(sourceAlgebraicOp: op, sourceAlgebraicOpOperands: algebraicOpOperands, rewriter); |
113 | } |
114 | |
115 | private: |
116 | LogicalResult |
117 | matchOp(Operation *sourceAlgebraicOp, |
118 | SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const { |
119 | if (!isSourceAlgebraicOp(sourceAlgebraicOp)) { |
120 | return failure(); |
121 | } |
122 | sourceAlgebraicOpOperands.clear(); |
123 | getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands); |
124 | if (sourceAlgebraicOpOperands.empty()) { |
125 | return failure(); |
126 | } |
127 | |
128 | Operation *firstHomomorphismOp = |
129 | sourceAlgebraicOpOperands.front()->get().getDefiningOp(); |
130 | if (!firstHomomorphismOp || |
131 | !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) { |
132 | return failure(); |
133 | } |
134 | OpResult firstHomomorphismOpResult = |
135 | getHomomorphismOpResult(firstHomomorphismOp); |
136 | if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) { |
137 | return failure(); |
138 | } |
139 | |
140 | for (auto operand : sourceAlgebraicOpOperands) { |
141 | Operation *homomorphismOp = operand->get().getDefiningOp(); |
142 | if (!homomorphismOp || |
143 | !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) { |
144 | return failure(); |
145 | } |
146 | } |
147 | return success(); |
148 | } |
149 | |
150 | LogicalResult |
151 | rewriteOp(Operation *sourceAlgebraicOp, |
152 | const SmallVector<OpOperand *> &sourceAlgebraicOpOperands, |
153 | PatternRewriter &rewriter) const { |
154 | IRMapping irMapping; |
155 | for (auto operand : sourceAlgebraicOpOperands) { |
156 | Operation *homomorphismOp = operand->get().getDefiningOp(); |
157 | irMapping.map(operand->get(), |
158 | getHomomorphismOpOperand(homomorphismOp)->get()); |
159 | } |
160 | Operation *targetAlgebraicOp = |
161 | createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter); |
162 | |
163 | irMapping.clear(); |
164 | assert(!sourceAlgebraicOpOperands.empty()); |
165 | Operation *firstHomomorphismOp = |
166 | sourceAlgebraicOpOperands[0]->get().getDefiningOp(); |
167 | irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(), |
168 | getTargetAlgebraicOpResult(targetAlgebraicOp)); |
169 | Operation *newHomomorphismOp = |
170 | rewriter.clone(op&: *firstHomomorphismOp, mapper&: irMapping); |
171 | rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp), |
172 | getHomomorphismOpResult(newHomomorphismOp)); |
173 | return success(); |
174 | } |
175 | |
176 | GetHomomorphismOpOperandFn getHomomorphismOpOperand; |
177 | GetHomomorphismOpResultFn getHomomorphismOpResult; |
178 | GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands; |
179 | GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult; |
180 | GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult; |
181 | IsHomomorphismOpFn isHomomorphismOp; |
182 | IsSourceAlgebraicOpFn isSourceAlgebraicOp; |
183 | CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn; |
184 | }; |
185 | |
186 | } // namespace mlir |
187 | |
188 | #endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ |
189 | |