1//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
11
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "mlir/Support/LogicalResult.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/Support/Casting.h"
19#include <iterator>
20#include <optional>
21#include <type_traits>
22#include <utility>
23
24namespace mlir {
25
// If `h` is a homomorphism with respect to the source algebraic structure
// induced by function `s` and the target algebraic structure induced by
// function `t`, transforms `s(h(x1), h(x2), ..., h(xn))` into
// `h(t(x1, x2, ..., xn))`.
30//
31// Functors:
32// ---------
33// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
34// Returns the operand relevant to the homomorphism.
35// There may be other operands that are not relevant.
36//
37// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
38// Returns the result relevant to the homomorphism.
39//
// `GetSourceAlgebraicOpOperandsFn`:
// `(Operation*, SmallVector<OpOperand*>&) -> void`
// Populates the vector with the operands relevant to the homomorphism.
42//
43// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
44// Return the result of the source algebraic operation relevant to the
45// homomorphism.
46//
47// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
48// Return the result of the target algebraic operation relevant to the
49// homomorphism.
50//
// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
// Check if the operation is a homomorphism of the required type.
// Additionally, if the optional is present, check whether the two operations
// are compatible homomorphisms.
55//
// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
// Check if the operation is an operation of the source algebraic structure.
58//
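// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
// PatternRewriter &rewriter) -> Operation*`
// Create the target algebraic operation corresponding to the given source
// algebraic operation. `operandsRemapping` maps each relevant source operand
// value (a homomorphism result) to the corresponding homomorphism input.
// Returns the created operation.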
61template <typename GetHomomorphismOpOperandFn,
62 typename GetHomomorphismOpResultFn,
63 typename GetSourceAlgebraicOpOperandsFn,
64 typename GetSourceAlgebraicOpResultFn,
65 typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
66 typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
67struct HomomorphismSimplification : public RewritePattern {
68 template <typename GetHomomorphismOpOperandFnArg,
69 typename GetHomomorphismOpResultFnArg,
70 typename GetSourceAlgebraicOpOperandsFnArg,
71 typename GetSourceAlgebraicOpResultFnArg,
72 typename GetTargetAlgebraicOpResultFnArg,
73 typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
74 typename CreateTargetAlgebraicOpFnArg,
75 typename... RewritePatternArgs>
76 HomomorphismSimplification(
77 GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
78 GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
79 GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
80 GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
81 GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
82 IsHomomorphismOpFnArg &&isHomomorphismOp,
83 IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
84 CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
85 RewritePatternArgs &&...args)
86 : RewritePattern(std::forward<RewritePatternArgs>(args)...),
87 getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
88 getHomomorphismOpOperand)),
89 getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
90 getHomomorphismOpResult)),
91 getSourceAlgebraicOpOperands(
92 std::forward<GetSourceAlgebraicOpOperandsFnArg>(
93 getSourceAlgebraicOpOperands)),
94 getSourceAlgebraicOpResult(
95 std::forward<GetSourceAlgebraicOpResultFnArg>(
96 getSourceAlgebraicOpResult)),
97 getTargetAlgebraicOpResult(
98 std::forward<GetTargetAlgebraicOpResultFnArg>(
99 getTargetAlgebraicOpResult)),
100 isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
101 isSourceAlgebraicOp(
102 std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
103 createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
104 createTargetAlgebraicOpFn)) {}
105
106 LogicalResult matchAndRewrite(Operation *op,
107 PatternRewriter &rewriter) const override {
108 SmallVector<OpOperand *> algebraicOpOperands;
109 if (failed(matchOp(sourceAlgebraicOp: op, sourceAlgebraicOpOperands&: algebraicOpOperands))) {
110 return failure();
111 }
112 return rewriteOp(sourceAlgebraicOp: op, sourceAlgebraicOpOperands: algebraicOpOperands, rewriter);
113 }
114
115private:
116 LogicalResult
117 matchOp(Operation *sourceAlgebraicOp,
118 SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
119 if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
120 return failure();
121 }
122 sourceAlgebraicOpOperands.clear();
123 getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
124 if (sourceAlgebraicOpOperands.empty()) {
125 return failure();
126 }
127
128 Operation *firstHomomorphismOp =
129 sourceAlgebraicOpOperands.front()->get().getDefiningOp();
130 if (!firstHomomorphismOp ||
131 !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
132 return failure();
133 }
134 OpResult firstHomomorphismOpResult =
135 getHomomorphismOpResult(firstHomomorphismOp);
136 if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
137 return failure();
138 }
139
140 for (auto operand : sourceAlgebraicOpOperands) {
141 Operation *homomorphismOp = operand->get().getDefiningOp();
142 if (!homomorphismOp ||
143 !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
144 return failure();
145 }
146 }
147 return success();
148 }
149
150 LogicalResult
151 rewriteOp(Operation *sourceAlgebraicOp,
152 const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
153 PatternRewriter &rewriter) const {
154 IRMapping irMapping;
155 for (auto operand : sourceAlgebraicOpOperands) {
156 Operation *homomorphismOp = operand->get().getDefiningOp();
157 irMapping.map(operand->get(),
158 getHomomorphismOpOperand(homomorphismOp)->get());
159 }
160 Operation *targetAlgebraicOp =
161 createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
162
163 irMapping.clear();
164 assert(!sourceAlgebraicOpOperands.empty());
165 Operation *firstHomomorphismOp =
166 sourceAlgebraicOpOperands[0]->get().getDefiningOp();
167 irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
168 getTargetAlgebraicOpResult(targetAlgebraicOp));
169 Operation *newHomomorphismOp =
170 rewriter.clone(op&: *firstHomomorphismOp, mapper&: irMapping);
171 rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
172 getHomomorphismOpResult(newHomomorphismOp));
173 return success();
174 }
175
176 GetHomomorphismOpOperandFn getHomomorphismOpOperand;
177 GetHomomorphismOpResultFn getHomomorphismOpResult;
178 GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
179 GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
180 GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
181 IsHomomorphismOpFn isHomomorphismOp;
182 IsSourceAlgebraicOpFn isSourceAlgebraicOp;
183 CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
184};
185
186} // namespace mlir
187
188#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
189

source code of mlir/include/mlir/Transforms/HomomorphismSimplification.h