1 | //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file declares various common operation folders. These folders |
10 | // are intended to be used by dialects to support common folding behavior |
11 | // without requiring each dialect to provide its own implementation. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #ifndef MLIR_DIALECT_COMMONFOLDERS_H |
16 | #define MLIR_DIALECT_COMMONFOLDERS_H |
17 | |
18 | #include "mlir/IR/BuiltinAttributes.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "llvm/ADT/ArrayRef.h" |
21 | #include "llvm/ADT/STLExtras.h" |
22 | #include <optional> |
23 | |
24 | namespace mlir { |
// Forward declaration only. The folders below name `ub::PoisonAttr` as a
// default template argument without pulling in the UB dialect headers; each
// folder static_asserts at instantiation that the type is complete (or that
// `void` was passed to opt out of poison handling).
namespace ub {
class PoisonAttr;
} // namespace ub
28 | /// Performs constant folding `calculate` with element-wise behavior on the two |
29 | /// attributes in `operands` and returns the result if possible. |
30 | /// Uses `resultType` for the type of the returned attribute. |
31 | /// Optional PoisonAttr template argument allows to specify 'poison' attribute |
32 | /// which will be directly propagated to result. |
33 | template <class AttrElementT, |
34 | class ElementValueT = typename AttrElementT::ValueType, |
35 | class PoisonAttr = ub::PoisonAttr, |
36 | class CalculationT = function_ref< |
37 | std::optional<ElementValueT>(ElementValueT, ElementValueT)>> |
38 | Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands, |
39 | Type resultType, |
40 | CalculationT &&calculate) { |
41 | assert(operands.size() == 2 && "binary op takes two operands" ); |
42 | static_assert( |
43 | std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>, |
44 | "PoisonAttr is undefined, either add a dependency on UB dialect or pass " |
45 | "void as template argument to opt-out from poison semantics." ); |
46 | if constexpr (!std::is_void_v<PoisonAttr>) { |
47 | if (isa_and_nonnull<PoisonAttr>(operands[0])) |
48 | return operands[0]; |
49 | |
50 | if (isa_and_nonnull<PoisonAttr>(operands[1])) |
51 | return operands[1]; |
52 | } |
53 | |
54 | if (!resultType || !operands[0] || !operands[1]) |
55 | return {}; |
56 | |
57 | if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) { |
58 | auto lhs = cast<AttrElementT>(operands[0]); |
59 | auto rhs = cast<AttrElementT>(operands[1]); |
60 | if (lhs.getType() != rhs.getType()) |
61 | return {}; |
62 | |
63 | auto calRes = calculate(lhs.getValue(), rhs.getValue()); |
64 | |
65 | if (!calRes) |
66 | return {}; |
67 | |
68 | return AttrElementT::get(resultType, *calRes); |
69 | } |
70 | |
71 | if (isa<SplatElementsAttr>(Val: operands[0]) && |
72 | isa<SplatElementsAttr>(Val: operands[1])) { |
73 | // Both operands are splats so we can avoid expanding the values out and |
74 | // just fold based on the splat value. |
75 | auto lhs = cast<SplatElementsAttr>(Val: operands[0]); |
76 | auto rhs = cast<SplatElementsAttr>(Val: operands[1]); |
77 | if (lhs.getType() != rhs.getType()) |
78 | return {}; |
79 | |
80 | auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(), |
81 | rhs.getSplatValue<ElementValueT>()); |
82 | if (!elementResult) |
83 | return {}; |
84 | |
85 | return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult); |
86 | } |
87 | |
88 | if (isa<ElementsAttr>(Val: operands[0]) && isa<ElementsAttr>(Val: operands[1])) { |
89 | // Operands are ElementsAttr-derived; perform an element-wise fold by |
90 | // expanding the values. |
91 | auto lhs = cast<ElementsAttr>(operands[0]); |
92 | auto rhs = cast<ElementsAttr>(operands[1]); |
93 | if (lhs.getType() != rhs.getType()) |
94 | return {}; |
95 | |
96 | auto maybeLhsIt = lhs.try_value_begin<ElementValueT>(); |
97 | auto maybeRhsIt = rhs.try_value_begin<ElementValueT>(); |
98 | if (!maybeLhsIt || !maybeRhsIt) |
99 | return {}; |
100 | auto lhsIt = *maybeLhsIt; |
101 | auto rhsIt = *maybeRhsIt; |
102 | SmallVector<ElementValueT, 4> elementResults; |
103 | elementResults.reserve(lhs.getNumElements()); |
104 | for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) { |
105 | auto elementResult = calculate(*lhsIt, *rhsIt); |
106 | if (!elementResult) |
107 | return {}; |
108 | elementResults.push_back(*elementResult); |
109 | } |
110 | |
111 | return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults); |
112 | } |
113 | return {}; |
114 | } |
115 | |
116 | /// Performs constant folding `calculate` with element-wise behavior on the two |
117 | /// attributes in `operands` and returns the result if possible. |
118 | /// Uses the operand element type for the element type of the returned |
119 | /// attribute. |
120 | /// Optional PoisonAttr template argument allows to specify 'poison' attribute |
121 | /// which will be directly propagated to result. |
122 | template <class AttrElementT, |
123 | class ElementValueT = typename AttrElementT::ValueType, |
124 | class PoisonAttr = ub::PoisonAttr, |
125 | class CalculationT = function_ref< |
126 | std::optional<ElementValueT>(ElementValueT, ElementValueT)>> |
127 | Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands, |
128 | CalculationT &&calculate) { |
129 | assert(operands.size() == 2 && "binary op takes two operands" ); |
130 | static_assert( |
131 | std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>, |
132 | "PoisonAttr is undefined, either add a dependency on UB dialect or pass " |
133 | "void as template argument to opt-out from poison semantics." ); |
134 | if constexpr (!std::is_void_v<PoisonAttr>) { |
135 | if (isa_and_nonnull<PoisonAttr>(operands[0])) |
136 | return operands[0]; |
137 | |
138 | if (isa_and_nonnull<PoisonAttr>(operands[1])) |
139 | return operands[1]; |
140 | } |
141 | |
142 | auto getResultType = [](Attribute attr) -> Type { |
143 | if (auto typed = dyn_cast_or_null<TypedAttr>(attr)) |
144 | return typed.getType(); |
145 | return {}; |
146 | }; |
147 | |
148 | Type lhsType = getResultType(operands[0]); |
149 | Type rhsType = getResultType(operands[1]); |
150 | if (!lhsType || !rhsType) |
151 | return {}; |
152 | if (lhsType != rhsType) |
153 | return {}; |
154 | |
155 | return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr, |
156 | CalculationT>( |
157 | operands, lhsType, std::forward<CalculationT>(calculate)); |
158 | } |
159 | |
160 | template <class AttrElementT, |
161 | class ElementValueT = typename AttrElementT::ValueType, |
162 | class PoisonAttr = void, |
163 | class CalculationT = |
164 | function_ref<ElementValueT(ElementValueT, ElementValueT)>> |
165 | Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType, |
166 | CalculationT &&calculate) { |
167 | return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>( |
168 | operands, resultType, |
169 | [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> { |
170 | return calculate(a, b); |
171 | }); |
172 | } |
173 | |
174 | template <class AttrElementT, |
175 | class ElementValueT = typename AttrElementT::ValueType, |
176 | class PoisonAttr = ub::PoisonAttr, |
177 | class CalculationT = |
178 | function_ref<ElementValueT(ElementValueT, ElementValueT)>> |
179 | Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, |
180 | CalculationT &&calculate) { |
181 | return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>( |
182 | operands, |
183 | [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> { |
184 | return calculate(a, b); |
185 | }); |
186 | } |
187 | |
188 | /// Performs constant folding `calculate` with element-wise behavior on the one |
189 | /// attributes in `operands` and returns the result if possible. |
190 | /// Optional PoisonAttr template argument allows to specify 'poison' attribute |
191 | /// which will be directly propagated to result. |
192 | template <class AttrElementT, |
193 | class ElementValueT = typename AttrElementT::ValueType, |
194 | class PoisonAttr = ub::PoisonAttr, |
195 | class CalculationT = |
196 | function_ref<std::optional<ElementValueT>(ElementValueT)>> |
197 | Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands, |
198 | CalculationT &&calculate) { |
199 | assert(operands.size() == 1 && "unary op takes one operands" ); |
200 | if (!operands[0]) |
201 | return {}; |
202 | |
203 | static_assert( |
204 | std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>, |
205 | "PoisonAttr is undefined, either add a dependency on UB dialect or pass " |
206 | "void as template argument to opt-out from poison semantics." ); |
207 | if constexpr (!std::is_void_v<PoisonAttr>) { |
208 | if (isa<PoisonAttr>(operands[0])) |
209 | return operands[0]; |
210 | } |
211 | |
212 | if (isa<AttrElementT>(operands[0])) { |
213 | auto op = cast<AttrElementT>(operands[0]); |
214 | |
215 | auto res = calculate(op.getValue()); |
216 | if (!res) |
217 | return {}; |
218 | return AttrElementT::get(op.getType(), *res); |
219 | } |
220 | if (isa<SplatElementsAttr>(Val: operands[0])) { |
221 | // Both operands are splats so we can avoid expanding the values out and |
222 | // just fold based on the splat value. |
223 | auto op = cast<SplatElementsAttr>(Val: operands[0]); |
224 | |
225 | auto elementResult = calculate(op.getSplatValue<ElementValueT>()); |
226 | if (!elementResult) |
227 | return {}; |
228 | return DenseElementsAttr::get(op.getType(), *elementResult); |
229 | } else if (isa<ElementsAttr>(Val: operands[0])) { |
230 | // Operands are ElementsAttr-derived; perform an element-wise fold by |
231 | // expanding the values. |
232 | auto op = cast<ElementsAttr>(operands[0]); |
233 | |
234 | auto maybeOpIt = op.try_value_begin<ElementValueT>(); |
235 | if (!maybeOpIt) |
236 | return {}; |
237 | auto opIt = *maybeOpIt; |
238 | SmallVector<ElementValueT> elementResults; |
239 | elementResults.reserve(op.getNumElements()); |
240 | for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) { |
241 | auto elementResult = calculate(*opIt); |
242 | if (!elementResult) |
243 | return {}; |
244 | elementResults.push_back(*elementResult); |
245 | } |
246 | return DenseElementsAttr::get(op.getShapedType(), elementResults); |
247 | } |
248 | return {}; |
249 | } |
250 | |
251 | template <class AttrElementT, |
252 | class ElementValueT = typename AttrElementT::ValueType, |
253 | class PoisonAttr = ub::PoisonAttr, |
254 | class CalculationT = function_ref<ElementValueT(ElementValueT)>> |
255 | Attribute constFoldUnaryOp(ArrayRef<Attribute> operands, |
256 | CalculationT &&calculate) { |
257 | return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>( |
258 | operands, [&](ElementValueT a) -> std::optional<ElementValueT> { |
259 | return calculate(a); |
260 | }); |
261 | } |
262 | |
263 | template < |
264 | class AttrElementT, class TargetAttrElementT, |
265 | class ElementValueT = typename AttrElementT::ValueType, |
266 | class TargetElementValueT = typename TargetAttrElementT::ValueType, |
267 | class PoisonAttr = ub::PoisonAttr, |
268 | class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>> |
269 | Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType, |
270 | CalculationT &&calculate) { |
271 | assert(operands.size() == 1 && "Cast op takes one operand" ); |
272 | if (!operands[0]) |
273 | return {}; |
274 | |
275 | static_assert( |
276 | std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>, |
277 | "PoisonAttr is undefined, either add a dependency on UB dialect or pass " |
278 | "void as template argument to opt-out from poison semantics." ); |
279 | if constexpr (!std::is_void_v<PoisonAttr>) { |
280 | if (isa<PoisonAttr>(operands[0])) |
281 | return operands[0]; |
282 | } |
283 | |
284 | if (isa<AttrElementT>(operands[0])) { |
285 | auto op = cast<AttrElementT>(operands[0]); |
286 | bool castStatus = true; |
287 | auto res = calculate(op.getValue(), castStatus); |
288 | if (!castStatus) |
289 | return {}; |
290 | return TargetAttrElementT::get(resType, res); |
291 | } |
292 | if (isa<SplatElementsAttr>(Val: operands[0])) { |
293 | // The operand is a splat so we can avoid expanding the values out and |
294 | // just fold based on the splat value. |
295 | auto op = cast<SplatElementsAttr>(Val: operands[0]); |
296 | bool castStatus = true; |
297 | auto elementResult = |
298 | calculate(op.getSplatValue<ElementValueT>(), castStatus); |
299 | if (!castStatus) |
300 | return {}; |
301 | return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult); |
302 | } |
303 | if (auto op = dyn_cast<ElementsAttr>(operands[0])) { |
304 | // Operand is ElementsAttr-derived; perform an element-wise fold by |
305 | // expanding the value. |
306 | bool castStatus = true; |
307 | auto maybeOpIt = op.try_value_begin<ElementValueT>(); |
308 | if (!maybeOpIt) |
309 | return {}; |
310 | auto opIt = *maybeOpIt; |
311 | SmallVector<TargetElementValueT> elementResults; |
312 | elementResults.reserve(op.getNumElements()); |
313 | for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) { |
314 | auto elt = calculate(*opIt, castStatus); |
315 | if (!castStatus) |
316 | return {}; |
317 | elementResults.push_back(elt); |
318 | } |
319 | |
320 | return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults); |
321 | } |
322 | return {}; |
323 | } |
324 | } // namespace mlir |
325 | |
326 | #endif // MLIR_DIALECT_COMMONFOLDERS_H |
327 | |