1 | //===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains utilities for querying the sub elements of an attribute or |
10 | // type. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H |
15 | #define MLIR_IR_ATTRTYPESUBELEMENTS_H |
16 | |
17 | #include "mlir/IR/MLIRContext.h" |
18 | #include "mlir/IR/Visitors.h" |
19 | #include "llvm/ADT/ArrayRef.h" |
20 | #include "llvm/ADT/DenseMap.h" |
21 | #include <optional> |
22 | |
23 | namespace mlir { |
24 | class Attribute; |
25 | class Type; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | /// AttrTypeWalker |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | /// This class provides a utility for walking attributes/types, and their sub |
32 | /// elements. Multiple walk functions may be registered. |
33 | class AttrTypeWalker { |
34 | public: |
35 | //===--------------------------------------------------------------------===// |
36 | // Application |
37 | //===--------------------------------------------------------------------===// |
38 | |
39 | /// Walk the given attribute/type, and recursively walk any sub elements. |
40 | template <WalkOrder Order, typename T> |
41 | WalkResult walk(T element) { |
42 | return walkImpl(element, Order); |
43 | } |
44 | template <typename T> |
45 | WalkResult walk(T element) { |
46 | return walk<WalkOrder::PostOrder, T>(element); |
47 | } |
48 | |
49 | //===--------------------------------------------------------------------===// |
50 | // Registration |
51 | //===--------------------------------------------------------------------===// |
52 | |
53 | template <typename T> |
54 | using WalkFn = std::function<WalkResult(T)>; |
55 | |
56 | /// Register a walk function for a given attribute or type. A walk function |
57 | /// must be convertible to any of the following forms(where `T` is a class |
58 | /// derived from `Type` or `Attribute`: |
59 | /// |
60 | /// * WalkResult(T) |
61 | /// - Returns a walk result, which can be used to control the walk |
62 | /// |
63 | /// * void(T) |
64 | /// - Returns void, i.e. the walk always continues. |
65 | /// |
66 | /// Note: When walking, the mostly recently added walk functions will be |
67 | /// invoked first. |
68 | void addWalk(WalkFn<Attribute> &&fn) { |
69 | attrWalkFns.emplace_back(args: std::move(fn)); |
70 | } |
71 | void addWalk(WalkFn<Type> &&fn) { typeWalkFns.push_back(x: std::move(fn)); } |
72 | |
73 | /// Register a replacement function that doesn't match the default signature, |
74 | /// either because it uses a derived parameter type, or it uses a simplified |
75 | /// result type. |
76 | template <typename FnT, |
77 | typename T = typename llvm::function_traits< |
78 | std::decay_t<FnT>>::template arg_t<0>, |
79 | typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>, |
80 | Attribute, Type>, |
81 | typename ResultT = std::invoke_result_t<FnT, T>> |
82 | std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>> |
83 | addWalk(FnT &&callback) { |
84 | addWalk([callback = std::forward<FnT>(callback)](BaseT base) -> WalkResult { |
85 | if (auto derived = dyn_cast<T>(base)) { |
86 | if constexpr (std::is_convertible_v<ResultT, WalkResult>) |
87 | return callback(derived); |
88 | else |
89 | callback(derived); |
90 | } |
91 | return WalkResult::advance(); |
92 | }); |
93 | } |
94 | |
95 | private: |
96 | WalkResult walkImpl(Attribute attr, WalkOrder order); |
97 | WalkResult walkImpl(Type type, WalkOrder order); |
98 | |
99 | /// Internal implementation of the `walk` methods above. |
100 | template <typename T, typename WalkFns> |
101 | WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order); |
102 | |
103 | /// Walk the sub elements of the given interface. |
104 | template <typename T> |
105 | WalkResult walkSubElements(T interface, WalkOrder order); |
106 | |
107 | /// The set of walk functions that map sub elements. |
108 | std::vector<WalkFn<Attribute>> attrWalkFns; |
109 | std::vector<WalkFn<Type>> typeWalkFns; |
110 | |
111 | /// The set of visited attributes/types. |
112 | DenseMap<std::pair<const void *, int>, WalkResult> visitedAttrTypes; |
113 | }; |
114 | |
115 | //===----------------------------------------------------------------------===// |
116 | /// AttrTypeReplacer |
117 | //===----------------------------------------------------------------------===// |
118 | |
119 | /// This class provides a utility for replacing attributes/types, and their sub |
120 | /// elements. Multiple replacement functions may be registered. |
121 | class AttrTypeReplacer { |
122 | public: |
123 | //===--------------------------------------------------------------------===// |
124 | // Application |
125 | //===--------------------------------------------------------------------===// |
126 | |
127 | /// Replace the elements within the given operation. If `replaceAttrs` is |
128 | /// true, this updates the attribute dictionary of the operation. If |
129 | /// `replaceLocs` is true, this also updates its location, and the locations |
130 | /// of any nested block arguments. If `replaceTypes` is true, this also |
131 | /// updates the result types of the operation, and the types of any nested |
132 | /// block arguments. |
133 | void replaceElementsIn(Operation *op, bool replaceAttrs = true, |
134 | bool replaceLocs = false, bool replaceTypes = false); |
135 | |
136 | /// Replace the elements within the given operation, and all nested |
137 | /// operations. |
138 | void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs = true, |
139 | bool replaceLocs = false, |
140 | bool replaceTypes = false); |
141 | |
142 | /// Replace the given attribute/type, and recursively replace any sub |
143 | /// elements. Returns either the new attribute/type, or nullptr in the case of |
144 | /// failure. |
145 | Attribute replace(Attribute attr); |
146 | Type replace(Type type); |
147 | |
148 | //===--------------------------------------------------------------------===// |
149 | // Registration |
150 | //===--------------------------------------------------------------------===// |
151 | |
152 | /// A replacement mapping function, which returns either std::nullopt (to |
153 | /// signal the element wasn't handled), or a pair of the replacement element |
154 | /// and a WalkResult. |
155 | template <typename T> |
156 | using ReplaceFnResult = std::optional<std::pair<T, WalkResult>>; |
157 | template <typename T> |
158 | using ReplaceFn = std::function<ReplaceFnResult<T>(T)>; |
159 | |
160 | /// Register a replacement function for mapping a given attribute or type. A |
161 | /// replacement function must be convertible to any of the following |
162 | /// forms(where `T` is a class derived from `Type` or `Attribute`, and `BaseT` |
163 | /// is either `Type` or `Attribute` respectively): |
164 | /// |
165 | /// * std::optional<BaseT>(T) |
166 | /// - This either returns a valid Attribute/Type in the case of success, |
167 | /// nullptr in the case of failure, or `std::nullopt` to signify that |
168 | /// additional replacement functions may be applied (i.e. this function |
169 | /// doesn't handle that instance). |
170 | /// |
171 | /// * std::optional<std::pair<BaseT, WalkResult>>(T) |
172 | /// - Similar to the above, but also allows specifying a WalkResult to |
173 | /// control the replacement of sub elements of a given attribute or |
174 | /// type. Returning a `skip` result, for example, will not recursively |
175 | /// process the resultant attribute or type value. |
176 | /// |
177 | /// Note: When replacing, the mostly recently added replacement functions will |
178 | /// be invoked first. |
179 | void addReplacement(ReplaceFn<Attribute> fn); |
180 | void addReplacement(ReplaceFn<Type> fn); |
181 | |
182 | /// Register a replacement function that doesn't match the default signature, |
183 | /// either because it uses a derived parameter type, or it uses a simplified |
184 | /// result type. |
185 | template <typename FnT, |
186 | typename T = typename llvm::function_traits< |
187 | std::decay_t<FnT>>::template arg_t<0>, |
188 | typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>, |
189 | Attribute, Type>, |
190 | typename ResultT = std::invoke_result_t<FnT, T>> |
191 | std::enable_if_t<!std::is_same_v<T, BaseT> || |
192 | !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>> |
193 | addReplacement(FnT &&callback) { |
194 | addReplacement([callback = std::forward<FnT>(callback)]( |
195 | BaseT base) -> ReplaceFnResult<BaseT> { |
196 | if (auto derived = dyn_cast<T>(base)) { |
197 | if constexpr (std::is_convertible_v<ResultT, std::optional<BaseT>>) { |
198 | std::optional<BaseT> result = callback(derived); |
199 | return result ? std::make_pair(*result, WalkResult::advance()) |
200 | : ReplaceFnResult<BaseT>(); |
201 | } else { |
202 | return callback(derived); |
203 | } |
204 | } |
205 | return ReplaceFnResult<BaseT>(); |
206 | }); |
207 | } |
208 | |
209 | private: |
210 | /// Internal implementation of the `replace` methods above. |
211 | template <typename T, typename ReplaceFns> |
212 | T replaceImpl(T element, ReplaceFns &replaceFns); |
213 | |
214 | /// Replace the sub elements of the given interface. |
215 | template <typename T> |
216 | T replaceSubElements(T interface); |
217 | |
218 | /// The set of replacement functions that map sub elements. |
219 | std::vector<ReplaceFn<Attribute>> attrReplacementFns; |
220 | std::vector<ReplaceFn<Type>> typeReplacementFns; |
221 | |
222 | /// The set of cached mappings for attributes/types. |
223 | DenseMap<const void *, const void *> attrTypeMap; |
224 | }; |
225 | |
226 | //===----------------------------------------------------------------------===// |
227 | /// AttrTypeSubElementHandler |
228 | //===----------------------------------------------------------------------===// |
229 | |
230 | /// This class is used by AttrTypeSubElementHandler instances to walking sub |
231 | /// attributes and types. |
232 | class AttrTypeImmediateSubElementWalker { |
233 | public: |
234 | AttrTypeImmediateSubElementWalker(function_ref<void(Attribute)> walkAttrsFn, |
235 | function_ref<void(Type)> walkTypesFn) |
236 | : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {} |
237 | |
238 | /// Walk an attribute. |
239 | void walk(Attribute element); |
240 | /// Walk a type. |
241 | void walk(Type element); |
242 | /// Walk a range of attributes or types. |
243 | template <typename RangeT> |
244 | void walkRange(RangeT &&elements) { |
245 | for (auto element : elements) |
246 | walk(element); |
247 | } |
248 | |
249 | private: |
250 | function_ref<void(Attribute)> walkAttrsFn; |
251 | function_ref<void(Type)> walkTypesFn; |
252 | }; |
253 | |
254 | /// This class is used by AttrTypeSubElementHandler instances to process sub |
255 | /// element replacements. |
256 | template <typename T> |
257 | class AttrTypeSubElementReplacements { |
258 | public: |
259 | AttrTypeSubElementReplacements(ArrayRef<T> repls) : repls(repls) {} |
260 | |
261 | /// Take the first N replacements as an ArrayRef, dropping them from |
262 | /// this replacement list. |
263 | ArrayRef<T> take_front(unsigned n) { |
264 | ArrayRef<T> elements = repls.take_front(n); |
265 | repls = repls.drop_front(n); |
266 | return elements; |
267 | } |
268 | |
269 | private: |
270 | /// The current set of replacements. |
271 | ArrayRef<T> repls; |
272 | }; |
273 | using AttrSubElementReplacements = AttrTypeSubElementReplacements<Attribute>; |
274 | using TypeSubElementReplacements = AttrTypeSubElementReplacements<Type>; |
275 | |
276 | /// This class provides support for interacting with the |
277 | /// SubElementInterfaces for different types of parameters. An |
278 | /// implementation of this class should be provided for any parameter class |
279 | /// that may contain an attribute or type. There are two main methods of |
280 | /// this class that need to be implemented: |
281 | /// |
282 | /// - walk |
283 | /// |
284 | /// This method should traverse into any sub elements of the parameter |
285 | /// using the provided walker, or by invoking handlers for sub-types. |
286 | /// |
287 | /// - replace |
288 | /// |
289 | /// This method should extract any necessary sub elements using the |
290 | /// provided replacer, or by invoking handlers for sub-types. The new |
291 | /// post-replacement parameter value should be returned. |
292 | /// |
293 | template <typename T, typename Enable = void> |
294 | struct AttrTypeSubElementHandler { |
295 | /// Default walk implementation that does nothing. |
296 | static inline void walk(const T ¶m, |
297 | AttrTypeImmediateSubElementWalker &walker) {} |
298 | |
299 | /// Default replace implementation just forwards the parameter. |
300 | template <typename ParamT> |
301 | static inline decltype(auto) replace(ParamT &¶m, |
302 | AttrSubElementReplacements &attrRepls, |
303 | TypeSubElementReplacements &typeRepls) { |
304 | return std::forward<ParamT>(param); |
305 | } |
306 | |
307 | /// Tag indicating that this handler does not support sub-elements. |
308 | using DefaultHandlerTag = void; |
309 | }; |
310 | |
311 | /// Detect if any of the given parameter types has a sub-element handler. |
312 | namespace detail { |
313 | template <typename T> |
314 | using has_default_sub_element_handler_t = decltype(T::DefaultHandlerTag); |
315 | } // namespace detail |
316 | template <typename... Ts> |
317 | inline constexpr bool has_sub_attr_or_type_v = |
318 | (!llvm::is_detected<detail::has_default_sub_element_handler_t, Ts>::value || |
319 | ...); |
320 | |
321 | /// Implementation for derived Attributes and Types. |
322 | template <typename T> |
323 | struct AttrTypeSubElementHandler< |
324 | T, std::enable_if_t<std::is_base_of_v<Attribute, T> || |
325 | std::is_base_of_v<Type, T>>> { |
326 | static void walk(T param, AttrTypeImmediateSubElementWalker &walker) { |
327 | walker.walk(param); |
328 | } |
329 | static T replace(T param, AttrSubElementReplacements &attrRepls, |
330 | TypeSubElementReplacements &typeRepls) { |
331 | if (!param) |
332 | return T(); |
333 | if constexpr (std::is_base_of_v<Attribute, T>) { |
334 | return cast<T>(attrRepls.take_front(n: 1)[0]); |
335 | } else { |
336 | return cast<T>(typeRepls.take_front(n: 1)[0]); |
337 | } |
338 | } |
339 | }; |
340 | /// Implementation for derived ArrayRef. |
341 | template <typename T> |
342 | struct AttrTypeSubElementHandler<ArrayRef<T>, |
343 | std::enable_if_t<has_sub_attr_or_type_v<T>>> { |
344 | using EltHandler = AttrTypeSubElementHandler<T>; |
345 | |
346 | static void walk(ArrayRef<T> param, |
347 | AttrTypeImmediateSubElementWalker &walker) { |
348 | for (const T &subElement : param) |
349 | EltHandler::walk(subElement, walker); |
350 | } |
351 | static auto replace(ArrayRef<T> param, AttrSubElementReplacements &attrRepls, |
352 | TypeSubElementReplacements &typeRepls) { |
353 | // Normal attributes/types can extract using the replacer directly. |
354 | if constexpr (std::is_base_of_v<Attribute, T> && |
355 | sizeof(T) == sizeof(void *)) { |
356 | ArrayRef<Attribute> attrs = attrRepls.take_front(n: param.size()); |
357 | return ArrayRef<T>((const T *)attrs.data(), attrs.size()); |
358 | } else if constexpr (std::is_base_of_v<Type, T> && |
359 | sizeof(T) == sizeof(void *)) { |
360 | ArrayRef<Type> types = typeRepls.take_front(n: param.size()); |
361 | return ArrayRef<T>((const T *)types.data(), types.size()); |
362 | } else { |
363 | // Otherwise, we need to allocate storage for the new elements. |
364 | SmallVector<T> newElements; |
365 | for (const T &element : param) |
366 | newElements.emplace_back( |
367 | EltHandler::replace(element, attrRepls, typeRepls)); |
368 | return newElements; |
369 | } |
370 | } |
371 | }; |
372 | /// Implementation for Tuple. |
373 | template <typename... Ts> |
374 | struct AttrTypeSubElementHandler< |
375 | std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> { |
376 | static void walk(const std::tuple<Ts...> ¶m, |
377 | AttrTypeImmediateSubElementWalker &walker) { |
378 | std::apply( |
379 | [&](const Ts &...params) { |
380 | (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...); |
381 | }, |
382 | param); |
383 | } |
384 | static auto replace(const std::tuple<Ts...> ¶m, |
385 | AttrSubElementReplacements &attrRepls, |
386 | TypeSubElementReplacements &typeRepls) { |
387 | return std::apply( |
388 | [&](const Ts &...params) |
389 | -> std::tuple<decltype(AttrTypeSubElementHandler<Ts>::replace( |
390 | params, attrRepls, typeRepls))...> { |
391 | return {AttrTypeSubElementHandler<Ts>::replace(params, attrRepls, |
392 | typeRepls)...}; |
393 | }, |
394 | param); |
395 | } |
396 | }; |
397 | |
398 | namespace detail { |
/// Trait that is true only for specializations of `std::tuple`.
template <typename T>
struct is_tuple : std::false_type {};
template <typename... ElementTs>
struct is_tuple<std::tuple<ElementTs...>> : std::true_type {};

/// Trait that is true only for specializations of `std::pair`.
template <typename T>
struct is_pair : std::false_type {};
template <typename FirstT, typename SecondT>
struct is_pair<std::pair<FirstT, SecondT>> : std::true_type {};

/// Detection alias: well-formed when `T::get` is callable with values of the
/// given argument types.
template <typename T, typename... ArgTs>
using has_get_method = decltype(T::get(std::declval<ArgTs>()...));
/// Detection alias: well-formed when a `T` value has a nullary `getAsKey()`
/// member. The trailing pack does not take part in the check.
template <typename T, typename... UnusedTs>
using has_get_as_key = decltype(std::declval<T>().getAsKey());
413 | |
414 | /// This function provides the underlying implementation for the |
415 | /// SubElementInterface walk method, using the key type of the derived |
416 | /// attribute/type to interact with the individual parameters. |
417 | template <typename T> |
418 | void walkImmediateSubElementsImpl(T derived, |
419 | function_ref<void(Attribute)> walkAttrsFn, |
420 | function_ref<void(Type)> walkTypesFn) { |
421 | using ImplT = typename T::ImplType; |
422 | (void)derived; |
423 | (void)walkAttrsFn; |
424 | (void)walkTypesFn; |
425 | if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) { |
426 | auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey(); |
427 | |
428 | // If we don't have any sub-elements, there is nothing to do. |
429 | if constexpr (!has_sub_attr_or_type_v<decltype(key)>) |
430 | return; |
431 | AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn); |
432 | AttrTypeSubElementHandler<decltype(key)>::walk(key, walker); |
433 | } |
434 | } |
435 | |
436 | /// This function invokes the proper `get` method for a type `T` with the given |
437 | /// values. |
438 | template <typename T, typename... Ts> |
439 | auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) { |
440 | // Prefer a direct `get` method if one exists. |
441 | if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) { |
442 | (void)ctx; |
443 | return T::get(std::forward<Ts>(params)...); |
444 | } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *, |
445 | Ts...>::value) { |
446 | return T::get(ctx, std::forward<Ts>(params)...); |
447 | } else { |
448 | // Otherwise, pass to the base get. |
449 | return T::Base::get(ctx, std::forward<Ts>(params)...); |
450 | } |
451 | } |
452 | |
453 | /// This function provides the underlying implementation for the |
454 | /// SubElementInterface replace method, using the key type of the derived |
455 | /// attribute/type to interact with the individual parameters. |
456 | template <typename T> |
457 | auto replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs, |
458 | ArrayRef<Type> &replTypes) { |
459 | using ImplT = typename T::ImplType; |
460 | if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) { |
461 | auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey(); |
462 | |
463 | // If we don't have any sub-elements, we can just return the original. |
464 | if constexpr (!has_sub_attr_or_type_v<decltype(key)>) { |
465 | return derived; |
466 | |
467 | // Otherwise, we need to replace any necessary sub-elements. |
468 | } else { |
469 | // Functor used to build the replacement on success. |
470 | auto buildReplacement = [&](auto newKey, MLIRContext *ctx) { |
471 | if constexpr (is_tuple<decltype(key)>::value || |
472 | is_pair<decltype(key)>::value) { |
473 | return std::apply( |
474 | [&](auto &&...params) { |
475 | return constructSubElementReplacement<T>( |
476 | ctx, std::forward<decltype(params)>(params)...); |
477 | }, |
478 | newKey); |
479 | } else { |
480 | return constructSubElementReplacement<T>(ctx, newKey); |
481 | } |
482 | }; |
483 | |
484 | AttrSubElementReplacements attrRepls(replAttrs); |
485 | TypeSubElementReplacements typeRepls(replTypes); |
486 | auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace( |
487 | key, attrRepls, typeRepls); |
488 | MLIRContext *ctx = derived.getContext(); |
489 | if constexpr (std::is_convertible_v<decltype(newKey), LogicalResult>) |
490 | return succeeded(newKey) ? buildReplacement(*newKey, ctx) : nullptr; |
491 | else |
492 | return buildReplacement(newKey, ctx); |
493 | } |
494 | } else { |
495 | return derived; |
496 | } |
497 | } |
498 | } // namespace detail |
499 | } // namespace mlir |
500 | |
501 | #endif // MLIR_IR_ATTRTYPESUBELEMENTS_H |
502 | |