1//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file declares various common operation folders. These folders
10// are intended to be used by dialects to support common folding behavior
11// without requiring each dialect to provide its own implementation.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_DIALECT_COMMONFOLDERS_H
16#define MLIR_DIALECT_COMMONFOLDERS_H
17
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/STLExtras.h"
22#include <optional>
23
24namespace mlir {
// Forward declaration only: it lets the folders below name `ub::PoisonAttr` as
// a default template argument without including the UB dialect headers. Each
// folder static_asserts that the type is complete (or that the caller opted
// out of poison semantics by passing `void`) before relying on it.
namespace ub {
class PoisonAttr;
} // namespace ub
28/// Performs constant folding `calculate` with element-wise behavior on the two
29/// attributes in `operands` and returns the result if possible.
30/// Uses `resultType` for the type of the returned attribute.
31/// Optional PoisonAttr template argument allows to specify 'poison' attribute
32/// which will be directly propagated to result.
33template <class AttrElementT,
34 class ElementValueT = typename AttrElementT::ValueType,
35 class PoisonAttr = ub::PoisonAttr,
36 class CalculationT = function_ref<
37 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
38Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
39 Type resultType,
40 CalculationT &&calculate) {
41 assert(operands.size() == 2 && "binary op takes two operands");
42 static_assert(
43 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45 "void as template argument to opt-out from poison semantics.");
46 if constexpr (!std::is_void_v<PoisonAttr>) {
47 if (isa_and_nonnull<PoisonAttr>(operands[0]))
48 return operands[0];
49
50 if (isa_and_nonnull<PoisonAttr>(operands[1]))
51 return operands[1];
52 }
53
54 if (!resultType || !operands[0] || !operands[1])
55 return {};
56
57 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58 auto lhs = cast<AttrElementT>(operands[0]);
59 auto rhs = cast<AttrElementT>(operands[1]);
60 if (lhs.getType() != rhs.getType())
61 return {};
62
63 auto calRes = calculate(lhs.getValue(), rhs.getValue());
64
65 if (!calRes)
66 return {};
67
68 return AttrElementT::get(resultType, *calRes);
69 }
70
71 if (isa<SplatElementsAttr>(Val: operands[0]) &&
72 isa<SplatElementsAttr>(Val: operands[1])) {
73 // Both operands are splats so we can avoid expanding the values out and
74 // just fold based on the splat value.
75 auto lhs = cast<SplatElementsAttr>(Val: operands[0]);
76 auto rhs = cast<SplatElementsAttr>(Val: operands[1]);
77 if (lhs.getType() != rhs.getType())
78 return {};
79
80 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81 rhs.getSplatValue<ElementValueT>());
82 if (!elementResult)
83 return {};
84
85 return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
86 }
87
88 if (isa<ElementsAttr>(Val: operands[0]) && isa<ElementsAttr>(Val: operands[1])) {
89 // Operands are ElementsAttr-derived; perform an element-wise fold by
90 // expanding the values.
91 auto lhs = cast<ElementsAttr>(operands[0]);
92 auto rhs = cast<ElementsAttr>(operands[1]);
93 if (lhs.getType() != rhs.getType())
94 return {};
95
96 auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97 auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98 if (!maybeLhsIt || !maybeRhsIt)
99 return {};
100 auto lhsIt = *maybeLhsIt;
101 auto rhsIt = *maybeRhsIt;
102 SmallVector<ElementValueT, 4> elementResults;
103 elementResults.reserve(lhs.getNumElements());
104 for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105 auto elementResult = calculate(*lhsIt, *rhsIt);
106 if (!elementResult)
107 return {};
108 elementResults.push_back(*elementResult);
109 }
110
111 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
112 }
113 return {};
114}
115
116/// Performs constant folding `calculate` with element-wise behavior on the two
117/// attributes in `operands` and returns the result if possible.
118/// Uses the operand element type for the element type of the returned
119/// attribute.
120/// Optional PoisonAttr template argument allows to specify 'poison' attribute
121/// which will be directly propagated to result.
122template <class AttrElementT,
123 class ElementValueT = typename AttrElementT::ValueType,
124 class PoisonAttr = ub::PoisonAttr,
125 class CalculationT = function_ref<
126 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
127Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
128 CalculationT &&calculate) {
129 assert(operands.size() == 2 && "binary op takes two operands");
130 static_assert(
131 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133 "void as template argument to opt-out from poison semantics.");
134 if constexpr (!std::is_void_v<PoisonAttr>) {
135 if (isa_and_nonnull<PoisonAttr>(operands[0]))
136 return operands[0];
137
138 if (isa_and_nonnull<PoisonAttr>(operands[1]))
139 return operands[1];
140 }
141
142 auto getResultType = [](Attribute attr) -> Type {
143 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144 return typed.getType();
145 return {};
146 };
147
148 Type lhsType = getResultType(operands[0]);
149 Type rhsType = getResultType(operands[1]);
150 if (!lhsType || !rhsType)
151 return {};
152 if (lhsType != rhsType)
153 return {};
154
155 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
156 CalculationT>(
157 operands, lhsType, std::forward<CalculationT>(calculate));
158}
159
160template <class AttrElementT,
161 class ElementValueT = typename AttrElementT::ValueType,
162 class PoisonAttr = void,
163 class CalculationT =
164 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
165Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
166 CalculationT &&calculate) {
167 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168 operands, resultType,
169 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170 return calculate(a, b);
171 });
172}
173
174template <class AttrElementT,
175 class ElementValueT = typename AttrElementT::ValueType,
176 class PoisonAttr = ub::PoisonAttr,
177 class CalculationT =
178 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
179Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
180 CalculationT &&calculate) {
181 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
182 operands,
183 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184 return calculate(a, b);
185 });
186}
187
188/// Performs constant folding `calculate` with element-wise behavior on the one
189/// attributes in `operands` and returns the result if possible.
190/// Optional PoisonAttr template argument allows to specify 'poison' attribute
191/// which will be directly propagated to result.
192template <class AttrElementT,
193 class ElementValueT = typename AttrElementT::ValueType,
194 class PoisonAttr = ub::PoisonAttr,
195 class CalculationT =
196 function_ref<std::optional<ElementValueT>(ElementValueT)>>
197Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
198 CalculationT &&calculate) {
199 assert(operands.size() == 1 && "unary op takes one operands");
200 if (!operands[0])
201 return {};
202
203 static_assert(
204 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
205 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
206 "void as template argument to opt-out from poison semantics.");
207 if constexpr (!std::is_void_v<PoisonAttr>) {
208 if (isa<PoisonAttr>(operands[0]))
209 return operands[0];
210 }
211
212 if (isa<AttrElementT>(operands[0])) {
213 auto op = cast<AttrElementT>(operands[0]);
214
215 auto res = calculate(op.getValue());
216 if (!res)
217 return {};
218 return AttrElementT::get(op.getType(), *res);
219 }
220 if (isa<SplatElementsAttr>(Val: operands[0])) {
221 // Both operands are splats so we can avoid expanding the values out and
222 // just fold based on the splat value.
223 auto op = cast<SplatElementsAttr>(Val: operands[0]);
224
225 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
226 if (!elementResult)
227 return {};
228 return DenseElementsAttr::get(op.getType(), *elementResult);
229 } else if (isa<ElementsAttr>(Val: operands[0])) {
230 // Operands are ElementsAttr-derived; perform an element-wise fold by
231 // expanding the values.
232 auto op = cast<ElementsAttr>(operands[0]);
233
234 auto maybeOpIt = op.try_value_begin<ElementValueT>();
235 if (!maybeOpIt)
236 return {};
237 auto opIt = *maybeOpIt;
238 SmallVector<ElementValueT> elementResults;
239 elementResults.reserve(op.getNumElements());
240 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
241 auto elementResult = calculate(*opIt);
242 if (!elementResult)
243 return {};
244 elementResults.push_back(*elementResult);
245 }
246 return DenseElementsAttr::get(op.getShapedType(), elementResults);
247 }
248 return {};
249}
250
251template <class AttrElementT,
252 class ElementValueT = typename AttrElementT::ValueType,
253 class PoisonAttr = ub::PoisonAttr,
254 class CalculationT = function_ref<ElementValueT(ElementValueT)>>
255Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
256 CalculationT &&calculate) {
257 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
258 operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
259 return calculate(a);
260 });
261}
262
263template <
264 class AttrElementT, class TargetAttrElementT,
265 class ElementValueT = typename AttrElementT::ValueType,
266 class TargetElementValueT = typename TargetAttrElementT::ValueType,
267 class PoisonAttr = ub::PoisonAttr,
268 class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
269Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
270 CalculationT &&calculate) {
271 assert(operands.size() == 1 && "Cast op takes one operand");
272 if (!operands[0])
273 return {};
274
275 static_assert(
276 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
277 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
278 "void as template argument to opt-out from poison semantics.");
279 if constexpr (!std::is_void_v<PoisonAttr>) {
280 if (isa<PoisonAttr>(operands[0]))
281 return operands[0];
282 }
283
284 if (isa<AttrElementT>(operands[0])) {
285 auto op = cast<AttrElementT>(operands[0]);
286 bool castStatus = true;
287 auto res = calculate(op.getValue(), castStatus);
288 if (!castStatus)
289 return {};
290 return TargetAttrElementT::get(resType, res);
291 }
292 if (isa<SplatElementsAttr>(Val: operands[0])) {
293 // The operand is a splat so we can avoid expanding the values out and
294 // just fold based on the splat value.
295 auto op = cast<SplatElementsAttr>(Val: operands[0]);
296 bool castStatus = true;
297 auto elementResult =
298 calculate(op.getSplatValue<ElementValueT>(), castStatus);
299 if (!castStatus)
300 return {};
301 return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
302 }
303 if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
304 // Operand is ElementsAttr-derived; perform an element-wise fold by
305 // expanding the value.
306 bool castStatus = true;
307 auto maybeOpIt = op.try_value_begin<ElementValueT>();
308 if (!maybeOpIt)
309 return {};
310 auto opIt = *maybeOpIt;
311 SmallVector<TargetElementValueT> elementResults;
312 elementResults.reserve(op.getNumElements());
313 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
314 auto elt = calculate(*opIt, castStatus);
315 if (!castStatus)
316 return {};
317 elementResults.push_back(elt);
318 }
319
320 return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
321 }
322 return {};
323}
324} // namespace mlir
325
326#endif // MLIR_DIALECT_COMMONFOLDERS_H
327

source code of mlir/include/mlir/Dialect/CommonFolders.h