1//===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains utilities for querying the sub elements of an attribute or
10// type.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H
15#define MLIR_IR_ATTRTYPESUBELEMENTS_H
16
17#include "mlir/IR/MLIRContext.h"
18#include "mlir/IR/Visitors.h"
19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/DenseMap.h"
21#include <optional>
22
23namespace mlir {
24class Attribute;
25class Type;
26
27//===----------------------------------------------------------------------===//
28/// AttrTypeWalker
29//===----------------------------------------------------------------------===//
30
31/// This class provides a utility for walking attributes/types, and their sub
32/// elements. Multiple walk functions may be registered.
33class AttrTypeWalker {
34public:
35 //===--------------------------------------------------------------------===//
36 // Application
37 //===--------------------------------------------------------------------===//
38
39 /// Walk the given attribute/type, and recursively walk any sub elements.
40 template <WalkOrder Order, typename T>
41 WalkResult walk(T element) {
42 return walkImpl(element, Order);
43 }
44 template <typename T>
45 WalkResult walk(T element) {
46 return walk<WalkOrder::PostOrder, T>(element);
47 }
48
49 //===--------------------------------------------------------------------===//
50 // Registration
51 //===--------------------------------------------------------------------===//
52
53 template <typename T>
54 using WalkFn = std::function<WalkResult(T)>;
55
56 /// Register a walk function for a given attribute or type. A walk function
57 /// must be convertible to any of the following forms(where `T` is a class
58 /// derived from `Type` or `Attribute`:
59 ///
60 /// * WalkResult(T)
61 /// - Returns a walk result, which can be used to control the walk
62 ///
63 /// * void(T)
64 /// - Returns void, i.e. the walk always continues.
65 ///
66 /// Note: When walking, the mostly recently added walk functions will be
67 /// invoked first.
68 void addWalk(WalkFn<Attribute> &&fn) {
69 attrWalkFns.emplace_back(args: std::move(fn));
70 }
71 void addWalk(WalkFn<Type> &&fn) { typeWalkFns.push_back(x: std::move(fn)); }
72
73 /// Register a replacement function that doesn't match the default signature,
74 /// either because it uses a derived parameter type, or it uses a simplified
75 /// result type.
76 template <typename FnT,
77 typename T = typename llvm::function_traits<
78 std::decay_t<FnT>>::template arg_t<0>,
79 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
80 Attribute, Type>,
81 typename ResultT = std::invoke_result_t<FnT, T>>
82 std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>>
83 addWalk(FnT &&callback) {
84 addWalk([callback = std::forward<FnT>(callback)](BaseT base) -> WalkResult {
85 if (auto derived = dyn_cast<T>(base)) {
86 if constexpr (std::is_convertible_v<ResultT, WalkResult>)
87 return callback(derived);
88 else
89 callback(derived);
90 }
91 return WalkResult::advance();
92 });
93 }
94
95private:
96 WalkResult walkImpl(Attribute attr, WalkOrder order);
97 WalkResult walkImpl(Type type, WalkOrder order);
98
99 /// Internal implementation of the `walk` methods above.
100 template <typename T, typename WalkFns>
101 WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order);
102
103 /// Walk the sub elements of the given interface.
104 template <typename T>
105 WalkResult walkSubElements(T interface, WalkOrder order);
106
107 /// The set of walk functions that map sub elements.
108 std::vector<WalkFn<Attribute>> attrWalkFns;
109 std::vector<WalkFn<Type>> typeWalkFns;
110
111 /// The set of visited attributes/types.
112 DenseMap<std::pair<const void *, int>, WalkResult> visitedAttrTypes;
113};
114
115//===----------------------------------------------------------------------===//
116/// AttrTypeReplacer
117//===----------------------------------------------------------------------===//
118
119/// This class provides a utility for replacing attributes/types, and their sub
120/// elements. Multiple replacement functions may be registered.
121class AttrTypeReplacer {
122public:
123 //===--------------------------------------------------------------------===//
124 // Application
125 //===--------------------------------------------------------------------===//
126
127 /// Replace the elements within the given operation. If `replaceAttrs` is
128 /// true, this updates the attribute dictionary of the operation. If
129 /// `replaceLocs` is true, this also updates its location, and the locations
130 /// of any nested block arguments. If `replaceTypes` is true, this also
131 /// updates the result types of the operation, and the types of any nested
132 /// block arguments.
133 void replaceElementsIn(Operation *op, bool replaceAttrs = true,
134 bool replaceLocs = false, bool replaceTypes = false);
135
136 /// Replace the elements within the given operation, and all nested
137 /// operations.
138 void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs = true,
139 bool replaceLocs = false,
140 bool replaceTypes = false);
141
142 /// Replace the given attribute/type, and recursively replace any sub
143 /// elements. Returns either the new attribute/type, or nullptr in the case of
144 /// failure.
145 Attribute replace(Attribute attr);
146 Type replace(Type type);
147
148 //===--------------------------------------------------------------------===//
149 // Registration
150 //===--------------------------------------------------------------------===//
151
152 /// A replacement mapping function, which returns either std::nullopt (to
153 /// signal the element wasn't handled), or a pair of the replacement element
154 /// and a WalkResult.
155 template <typename T>
156 using ReplaceFnResult = std::optional<std::pair<T, WalkResult>>;
157 template <typename T>
158 using ReplaceFn = std::function<ReplaceFnResult<T>(T)>;
159
160 /// Register a replacement function for mapping a given attribute or type. A
161 /// replacement function must be convertible to any of the following
162 /// forms(where `T` is a class derived from `Type` or `Attribute`, and `BaseT`
163 /// is either `Type` or `Attribute` respectively):
164 ///
165 /// * std::optional<BaseT>(T)
166 /// - This either returns a valid Attribute/Type in the case of success,
167 /// nullptr in the case of failure, or `std::nullopt` to signify that
168 /// additional replacement functions may be applied (i.e. this function
169 /// doesn't handle that instance).
170 ///
171 /// * std::optional<std::pair<BaseT, WalkResult>>(T)
172 /// - Similar to the above, but also allows specifying a WalkResult to
173 /// control the replacement of sub elements of a given attribute or
174 /// type. Returning a `skip` result, for example, will not recursively
175 /// process the resultant attribute or type value.
176 ///
177 /// Note: When replacing, the mostly recently added replacement functions will
178 /// be invoked first.
179 void addReplacement(ReplaceFn<Attribute> fn);
180 void addReplacement(ReplaceFn<Type> fn);
181
182 /// Register a replacement function that doesn't match the default signature,
183 /// either because it uses a derived parameter type, or it uses a simplified
184 /// result type.
185 template <typename FnT,
186 typename T = typename llvm::function_traits<
187 std::decay_t<FnT>>::template arg_t<0>,
188 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
189 Attribute, Type>,
190 typename ResultT = std::invoke_result_t<FnT, T>>
191 std::enable_if_t<!std::is_same_v<T, BaseT> ||
192 !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>>
193 addReplacement(FnT &&callback) {
194 addReplacement([callback = std::forward<FnT>(callback)](
195 BaseT base) -> ReplaceFnResult<BaseT> {
196 if (auto derived = dyn_cast<T>(base)) {
197 if constexpr (std::is_convertible_v<ResultT, std::optional<BaseT>>) {
198 std::optional<BaseT> result = callback(derived);
199 return result ? std::make_pair(*result, WalkResult::advance())
200 : ReplaceFnResult<BaseT>();
201 } else {
202 return callback(derived);
203 }
204 }
205 return ReplaceFnResult<BaseT>();
206 });
207 }
208
209private:
210 /// Internal implementation of the `replace` methods above.
211 template <typename T, typename ReplaceFns>
212 T replaceImpl(T element, ReplaceFns &replaceFns);
213
214 /// Replace the sub elements of the given interface.
215 template <typename T>
216 T replaceSubElements(T interface);
217
218 /// The set of replacement functions that map sub elements.
219 std::vector<ReplaceFn<Attribute>> attrReplacementFns;
220 std::vector<ReplaceFn<Type>> typeReplacementFns;
221
222 /// The set of cached mappings for attributes/types.
223 DenseMap<const void *, const void *> attrTypeMap;
224};
225
226//===----------------------------------------------------------------------===//
227/// AttrTypeSubElementHandler
228//===----------------------------------------------------------------------===//
229
230/// This class is used by AttrTypeSubElementHandler instances to walking sub
231/// attributes and types.
232class AttrTypeImmediateSubElementWalker {
233public:
234 AttrTypeImmediateSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
235 function_ref<void(Type)> walkTypesFn)
236 : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
237
238 /// Walk an attribute.
239 void walk(Attribute element);
240 /// Walk a type.
241 void walk(Type element);
242 /// Walk a range of attributes or types.
243 template <typename RangeT>
244 void walkRange(RangeT &&elements) {
245 for (auto element : elements)
246 walk(element);
247 }
248
249private:
250 function_ref<void(Attribute)> walkAttrsFn;
251 function_ref<void(Type)> walkTypesFn;
252};
253
254/// This class is used by AttrTypeSubElementHandler instances to process sub
255/// element replacements.
256template <typename T>
257class AttrTypeSubElementReplacements {
258public:
259 AttrTypeSubElementReplacements(ArrayRef<T> repls) : repls(repls) {}
260
261 /// Take the first N replacements as an ArrayRef, dropping them from
262 /// this replacement list.
263 ArrayRef<T> take_front(unsigned n) {
264 ArrayRef<T> elements = repls.take_front(n);
265 repls = repls.drop_front(n);
266 return elements;
267 }
268
269private:
270 /// The current set of replacements.
271 ArrayRef<T> repls;
272};
273using AttrSubElementReplacements = AttrTypeSubElementReplacements<Attribute>;
274using TypeSubElementReplacements = AttrTypeSubElementReplacements<Type>;
275
276/// This class provides support for interacting with the
277/// SubElementInterfaces for different types of parameters. An
278/// implementation of this class should be provided for any parameter class
279/// that may contain an attribute or type. There are two main methods of
280/// this class that need to be implemented:
281///
282/// - walk
283///
284/// This method should traverse into any sub elements of the parameter
285/// using the provided walker, or by invoking handlers for sub-types.
286///
287/// - replace
288///
289/// This method should extract any necessary sub elements using the
290/// provided replacer, or by invoking handlers for sub-types. The new
291/// post-replacement parameter value should be returned.
292///
293template <typename T, typename Enable = void>
294struct AttrTypeSubElementHandler {
295 /// Default walk implementation that does nothing.
296 static inline void walk(const T &param,
297 AttrTypeImmediateSubElementWalker &walker) {}
298
299 /// Default replace implementation just forwards the parameter.
300 template <typename ParamT>
301 static inline decltype(auto) replace(ParamT &&param,
302 AttrSubElementReplacements &attrRepls,
303 TypeSubElementReplacements &typeRepls) {
304 return std::forward<ParamT>(param);
305 }
306
307 /// Tag indicating that this handler does not support sub-elements.
308 using DefaultHandlerTag = void;
309};
310
311/// Detect if any of the given parameter types has a sub-element handler.
312namespace detail {
313template <typename T>
314using has_default_sub_element_handler_t = decltype(T::DefaultHandlerTag);
315} // namespace detail
316template <typename... Ts>
317inline constexpr bool has_sub_attr_or_type_v =
318 (!llvm::is_detected<detail::has_default_sub_element_handler_t, Ts>::value ||
319 ...);
320
321/// Implementation for derived Attributes and Types.
322template <typename T>
323struct AttrTypeSubElementHandler<
324 T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
325 std::is_base_of_v<Type, T>>> {
326 static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
327 walker.walk(param);
328 }
329 static T replace(T param, AttrSubElementReplacements &attrRepls,
330 TypeSubElementReplacements &typeRepls) {
331 if (!param)
332 return T();
333 if constexpr (std::is_base_of_v<Attribute, T>) {
334 return cast<T>(attrRepls.take_front(n: 1)[0]);
335 } else {
336 return cast<T>(typeRepls.take_front(n: 1)[0]);
337 }
338 }
339};
340/// Implementation for derived ArrayRef.
341template <typename T>
342struct AttrTypeSubElementHandler<ArrayRef<T>,
343 std::enable_if_t<has_sub_attr_or_type_v<T>>> {
344 using EltHandler = AttrTypeSubElementHandler<T>;
345
346 static void walk(ArrayRef<T> param,
347 AttrTypeImmediateSubElementWalker &walker) {
348 for (const T &subElement : param)
349 EltHandler::walk(subElement, walker);
350 }
351 static auto replace(ArrayRef<T> param, AttrSubElementReplacements &attrRepls,
352 TypeSubElementReplacements &typeRepls) {
353 // Normal attributes/types can extract using the replacer directly.
354 if constexpr (std::is_base_of_v<Attribute, T> &&
355 sizeof(T) == sizeof(void *)) {
356 ArrayRef<Attribute> attrs = attrRepls.take_front(n: param.size());
357 return ArrayRef<T>((const T *)attrs.data(), attrs.size());
358 } else if constexpr (std::is_base_of_v<Type, T> &&
359 sizeof(T) == sizeof(void *)) {
360 ArrayRef<Type> types = typeRepls.take_front(n: param.size());
361 return ArrayRef<T>((const T *)types.data(), types.size());
362 } else {
363 // Otherwise, we need to allocate storage for the new elements.
364 SmallVector<T> newElements;
365 for (const T &element : param)
366 newElements.emplace_back(
367 EltHandler::replace(element, attrRepls, typeRepls));
368 return newElements;
369 }
370 }
371};
372/// Implementation for Tuple.
373template <typename... Ts>
374struct AttrTypeSubElementHandler<
375 std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
376 static void walk(const std::tuple<Ts...> &param,
377 AttrTypeImmediateSubElementWalker &walker) {
378 std::apply(
379 [&](const Ts &...params) {
380 (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...);
381 },
382 param);
383 }
384 static auto replace(const std::tuple<Ts...> &param,
385 AttrSubElementReplacements &attrRepls,
386 TypeSubElementReplacements &typeRepls) {
387 return std::apply(
388 [&](const Ts &...params)
389 -> std::tuple<decltype(AttrTypeSubElementHandler<Ts>::replace(
390 params, attrRepls, typeRepls))...> {
391 return {AttrTypeSubElementHandler<Ts>::replace(params, attrRepls,
392 typeRepls)...};
393 },
394 param);
395 }
396};
397
398namespace detail {
/// Trait that holds only for `std::tuple` specializations.
template <typename T>
struct is_tuple : std::false_type {};
template <typename... Ts>
struct is_tuple<std::tuple<Ts...>> : std::true_type {};

/// Trait that holds only for `std::pair` specializations.
template <typename T>
struct is_pair : std::false_type {};
template <typename First, typename Second>
struct is_pair<std::pair<First, Second>> : std::true_type {};

/// Detection operand: result type of a static call `T::get(Ts...)`; it is
/// ill-formed when no such call is valid.
template <typename T, typename... Ts>
using has_get_method = decltype(T::get(std::declval<Ts>()...));
/// Detection operand: result type of calling `getAsKey()` on a `T` (the `Ts`
/// pack is ignored).
template <typename T, typename... Ts>
using has_get_as_key = decltype(std::declval<T>().getAsKey());
413
414/// This function provides the underlying implementation for the
415/// SubElementInterface walk method, using the key type of the derived
416/// attribute/type to interact with the individual parameters.
417template <typename T>
418void walkImmediateSubElementsImpl(T derived,
419 function_ref<void(Attribute)> walkAttrsFn,
420 function_ref<void(Type)> walkTypesFn) {
421 using ImplT = typename T::ImplType;
422 (void)derived;
423 (void)walkAttrsFn;
424 (void)walkTypesFn;
425 if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
426 auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
427
428 // If we don't have any sub-elements, there is nothing to do.
429 if constexpr (!has_sub_attr_or_type_v<decltype(key)>)
430 return;
431 AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn);
432 AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
433 }
434}
435
436/// This function invokes the proper `get` method for a type `T` with the given
437/// values.
438template <typename T, typename... Ts>
439auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
440 // Prefer a direct `get` method if one exists.
441 if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
442 (void)ctx;
443 return T::get(std::forward<Ts>(params)...);
444 } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
445 Ts...>::value) {
446 return T::get(ctx, std::forward<Ts>(params)...);
447 } else {
448 // Otherwise, pass to the base get.
449 return T::Base::get(ctx, std::forward<Ts>(params)...);
450 }
451}
452
/// This function provides the underlying implementation for the
/// SubElementInterface replace method, using the key type of the derived
/// attribute/type to interact with the individual parameters.
///
/// `replAttrs` and `replTypes` hold the replacement sub-elements. The key's
/// AttrTypeSubElementHandler consumes them in order. They are copied into
/// local replacement cursors, so this function does not advance the caller's
/// ArrayRefs.
///
/// Returns `derived` unchanged in two cases: the storage class has no
/// `getAsKey()`, or its key contains no attributes/types. Otherwise it
/// returns the attribute/type rebuilt (via constructSubElementReplacement)
/// from the replaced key. If the handler's result converts to `LogicalResult`
/// (presumably a FailureOr-like wrapper, given the `*newKey` below), a failed
/// replacement yields `nullptr`.
template <typename T>
auto replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
                                     ArrayRef<Type> &replTypes) {
  using ImplT = typename T::ImplType;
  // Only storage classes exposing `getAsKey()` can be decomposed into their
  // parameters; anything else is returned as-is (final `else` below).
  if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
    auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();

    // If we don't have any sub-elements, we can just return the original.
    if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
      return derived;

      // Otherwise, we need to replace any necessary sub-elements.
    } else {
      // Functor used to build the replacement on success. Tuple and pair keys
      // are unpacked into separate `get` arguments; any other key is passed
      // as a single argument.
      auto buildReplacement = [&](auto newKey, MLIRContext *ctx) {
        if constexpr (is_tuple<decltype(key)>::value ||
                      is_pair<decltype(key)>::value) {
          return std::apply(
              [&](auto &&...params) {
                return constructSubElementReplacement<T>(
                    ctx, std::forward<decltype(params)>(params)...);
              },
              newKey);
        } else {
          return constructSubElementReplacement<T>(ctx, newKey);
        }
      };

      // Cursors over the replacement lists; each handler takes the entries
      // it needs from the front.
      AttrSubElementReplacements attrRepls(replAttrs);
      TypeSubElementReplacements typeRepls(replTypes);
      auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
          key, attrRepls, typeRepls);
      MLIRContext *ctx = derived.getContext();
      // A handler may report failure through a LogicalResult-convertible
      // result: on success it is dereferenced, otherwise nullptr is returned.
      if constexpr (std::is_convertible_v<decltype(newKey), LogicalResult>)
        return succeeded(newKey) ? buildReplacement(*newKey, ctx) : nullptr;
      else
        return buildReplacement(newKey, ctx);
    }
  } else {
    return derived;
  }
}
498} // namespace detail
499} // namespace mlir
500
501#endif // MLIR_IR_ATTRTYPESUBELEMENTS_H
502

source code of mlir/include/mlir/IR/AttrTypeSubElements.h