1 | //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Operation.h" |
10 | #include <optional> |
11 | |
12 | using namespace mlir; |
13 | |
14 | //===----------------------------------------------------------------------===// |
15 | // AttrTypeWalker |
16 | //===----------------------------------------------------------------------===// |
17 | |
18 | WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) { |
19 | return walkImpl(element: attr, walkFns&: attrWalkFns, order); |
20 | } |
21 | WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) { |
22 | return walkImpl(element: type, walkFns&: typeWalkFns, order); |
23 | } |
24 | |
25 | template <typename T, typename WalkFns> |
26 | WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns, |
27 | WalkOrder order) { |
28 | // Check if we've already walk this element before. |
29 | auto key = std::make_pair(element.getAsOpaquePointer(), (int)order); |
30 | auto [it, inserted] = |
31 | visitedAttrTypes.try_emplace(key, WalkResult::advance()); |
32 | if (!inserted) |
33 | return it->second; |
34 | |
35 | // If we are walking in post order, walk the sub elements first. |
36 | if (order == WalkOrder::PostOrder) { |
37 | if (walkSubElements(element, order).wasInterrupted()) |
38 | return visitedAttrTypes[key] = WalkResult::interrupt(); |
39 | } |
40 | |
41 | // Walk this element, bailing if skipped or interrupted. |
42 | for (auto &walkFn : llvm::reverse(walkFns)) { |
43 | WalkResult walkResult = walkFn(element); |
44 | if (walkResult.wasInterrupted()) |
45 | return visitedAttrTypes[key] = WalkResult::interrupt(); |
46 | if (walkResult.wasSkipped()) |
47 | return WalkResult::advance(); |
48 | } |
49 | |
50 | // If we are walking in pre-order, walk the sub elements last. |
51 | if (order == WalkOrder::PreOrder) { |
52 | if (walkSubElements(element, order).wasInterrupted()) |
53 | return WalkResult::interrupt(); |
54 | } |
55 | return WalkResult::advance(); |
56 | } |
57 | |
58 | template <typename T> |
59 | WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { |
60 | WalkResult result = WalkResult::advance(); |
61 | auto walkFn = [&](auto element) { |
62 | if (element && !result.wasInterrupted()) |
63 | result = walkImpl(element, order); |
64 | }; |
65 | interface.walkImmediateSubElements(walkFn, walkFn); |
66 | return result.wasInterrupted() ? result : WalkResult::advance(); |
67 | } |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | /// AttrTypeReplacerBase |
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | template <typename Concrete> |
74 | void detail::AttrTypeReplacerBase<Concrete>::addReplacement( |
75 | ReplaceFn<Attribute> fn) { |
76 | attrReplacementFns.emplace_back(args: std::move(fn)); |
77 | } |
78 | |
79 | template <typename Concrete> |
80 | void detail::AttrTypeReplacerBase<Concrete>::addReplacement( |
81 | ReplaceFn<Type> fn) { |
82 | typeReplacementFns.push_back(x: std::move(fn)); |
83 | } |
84 | |
85 | template <typename Concrete> |
86 | void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn( |
87 | Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { |
88 | // Functor that replaces the given element if the new value is different, |
89 | // otherwise returns nullptr. |
90 | auto replaceIfDifferent = [&](auto element) { |
91 | auto replacement = static_cast<Concrete *>(this)->replace(element); |
92 | return (replacement && replacement != element) ? replacement : nullptr; |
93 | }; |
94 | |
95 | // Update the attribute dictionary. |
96 | if (replaceAttrs) { |
97 | if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) |
98 | op->setAttrs(cast<DictionaryAttr>(newAttrs)); |
99 | } |
100 | |
101 | // If we aren't updating locations or types, we're done. |
102 | if (!replaceTypes && !replaceLocs) |
103 | return; |
104 | |
105 | // Update the location. |
106 | if (replaceLocs) { |
107 | if (Attribute newLoc = replaceIfDifferent(op->getLoc())) |
108 | op->setLoc(cast<LocationAttr>(Val&: newLoc)); |
109 | } |
110 | |
111 | // Update the result types. |
112 | if (replaceTypes) { |
113 | for (OpResult result : op->getResults()) |
114 | if (Type newType = replaceIfDifferent(result.getType())) |
115 | result.setType(newType); |
116 | } |
117 | |
118 | // Update any nested block arguments. |
119 | for (Region ®ion : op->getRegions()) { |
120 | for (Block &block : region) { |
121 | for (BlockArgument &arg : block.getArguments()) { |
122 | if (replaceLocs) { |
123 | if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) |
124 | arg.setLoc(cast<LocationAttr>(Val&: newLoc)); |
125 | } |
126 | |
127 | if (replaceTypes) { |
128 | if (Type newType = replaceIfDifferent(arg.getType())) |
129 | arg.setType(newType); |
130 | } |
131 | } |
132 | } |
133 | } |
134 | } |
135 | |
136 | template <typename Concrete> |
137 | void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn( |
138 | Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { |
139 | op->walk([&](Operation *nestedOp) { |
140 | replaceElementsIn(op: nestedOp, replaceAttrs, replaceLocs, replaceTypes); |
141 | }); |
142 | } |
143 | |
144 | template <typename T, typename Replacer> |
145 | static void updateSubElementImpl(T element, Replacer &replacer, |
146 | SmallVectorImpl<T> &newElements, |
147 | FailureOr<bool> &changed) { |
148 | // Bail early if we failed at any point. |
149 | if (failed(Result: changed)) |
150 | return; |
151 | |
152 | // Guard against potentially null inputs. We always map null to null. |
153 | if (!element) { |
154 | newElements.push_back(nullptr); |
155 | return; |
156 | } |
157 | |
158 | // Replace the element. |
159 | if (T result = replacer.replace(element)) { |
160 | newElements.push_back(result); |
161 | if (result != element) |
162 | changed = true; |
163 | } else { |
164 | changed = failure(); |
165 | } |
166 | } |
167 | |
168 | template <typename T, typename Replacer> |
169 | static T replaceSubElements(T interface, Replacer &replacer) { |
170 | // Walk the current sub-elements, replacing them as necessary. |
171 | SmallVector<Attribute, 16> newAttrs; |
172 | SmallVector<Type, 16> newTypes; |
173 | FailureOr<bool> changed = false; |
174 | interface.walkImmediateSubElements( |
175 | [&](Attribute element) { |
176 | updateSubElementImpl(element, replacer, newAttrs, changed); |
177 | }, |
178 | [&](Type element) { |
179 | updateSubElementImpl(element, replacer, newTypes, changed); |
180 | }); |
181 | if (failed(Result: changed)) |
182 | return nullptr; |
183 | |
184 | // If any sub-elements changed, use the new elements during the replacement. |
185 | T result = interface; |
186 | if (*changed) |
187 | result = interface.replaceImmediateSubElements(newAttrs, newTypes); |
188 | return result; |
189 | } |
190 | |
191 | /// Shared implementation of replacing a given attribute or type element. |
192 | template <typename T, typename ReplaceFns, typename Replacer> |
193 | static T replaceElementImpl(T element, ReplaceFns &replaceFns, |
194 | Replacer &replacer) { |
195 | T result = element; |
196 | WalkResult walkResult = WalkResult::advance(); |
197 | for (auto &replaceFn : llvm::reverse(replaceFns)) { |
198 | if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) { |
199 | std::tie(result, walkResult) = *newRes; |
200 | break; |
201 | } |
202 | } |
203 | |
204 | // If an error occurred, return nullptr to indicate failure. |
205 | if (walkResult.wasInterrupted() || !result) { |
206 | return nullptr; |
207 | } |
208 | |
209 | // Handle replacing sub-elements if this element is also a container. |
210 | if (!walkResult.wasSkipped()) { |
211 | // Replace the sub elements of this element, bailing if we fail. |
212 | if (!(result = replaceSubElements(result, replacer))) { |
213 | return nullptr; |
214 | } |
215 | } |
216 | |
217 | return result; |
218 | } |
219 | |
220 | template <typename Concrete> |
221 | Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) { |
222 | return replaceElementImpl(attr, attrReplacementFns, |
223 | *static_cast<Concrete *>(this)); |
224 | } |
225 | |
226 | template <typename Concrete> |
227 | Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) { |
228 | return replaceElementImpl(type, typeReplacementFns, |
229 | *static_cast<Concrete *>(this)); |
230 | } |
231 | |
232 | //===----------------------------------------------------------------------===// |
233 | /// AttrTypeReplacer |
234 | //===----------------------------------------------------------------------===// |
235 | |
236 | template class detail::AttrTypeReplacerBase<AttrTypeReplacer>; |
237 | |
238 | template <typename T> |
239 | T AttrTypeReplacer::cachedReplaceImpl(T element) { |
240 | const void *opaqueElement = element.getAsOpaquePointer(); |
241 | auto [it, inserted] = cache.try_emplace(Key: opaqueElement, Args&: opaqueElement); |
242 | if (!inserted) |
243 | return T::getFromOpaquePointer(it->second); |
244 | |
245 | T result = replaceBase(element); |
246 | |
247 | cache[opaqueElement] = result.getAsOpaquePointer(); |
248 | return result; |
249 | } |
250 | |
251 | Attribute AttrTypeReplacer::replace(Attribute attr) { |
252 | return cachedReplaceImpl(element: attr); |
253 | } |
254 | |
255 | Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(element: type); } |
256 | |
257 | //===----------------------------------------------------------------------===// |
258 | /// CyclicAttrTypeReplacer |
259 | //===----------------------------------------------------------------------===// |
260 | |
261 | template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>; |
262 | |
263 | CyclicAttrTypeReplacer::CyclicAttrTypeReplacer() |
264 | : cache([&](void *attr) { return breakCycleImpl(element: attr); }) {} |
265 | |
266 | void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) { |
267 | attrCycleBreakerFns.emplace_back(args: std::move(fn)); |
268 | } |
269 | |
270 | void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) { |
271 | typeCycleBreakerFns.emplace_back(args: std::move(fn)); |
272 | } |
273 | |
274 | template <typename T> |
275 | T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) { |
276 | void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue(); |
277 | CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry = |
278 | cache.lookupOrInit(element: opaqueTaggedElement); |
279 | if (auto resultOpt = cacheEntry.get()) |
280 | return T::getFromOpaquePointer(*resultOpt); |
281 | |
282 | T result = replaceBase(element); |
283 | |
284 | cacheEntry.resolve(result: result.getAsOpaquePointer()); |
285 | return result; |
286 | } |
287 | |
288 | Attribute CyclicAttrTypeReplacer::replace(Attribute attr) { |
289 | return cachedReplaceImpl(element: attr); |
290 | } |
291 | |
292 | Type CyclicAttrTypeReplacer::replace(Type type) { |
293 | return cachedReplaceImpl(element: type); |
294 | } |
295 | |
296 | std::optional<const void *> |
297 | CyclicAttrTypeReplacer::breakCycleImpl(void *element) { |
298 | AttrOrType attrType = AttrOrType::getFromOpaqueValue(VP: element); |
299 | if (auto attr = dyn_cast<Attribute>(Val&: attrType)) { |
300 | for (auto &cyclicReplaceFn : llvm::reverse(C&: attrCycleBreakerFns)) { |
301 | if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) { |
302 | return newRes->getAsOpaquePointer(); |
303 | } |
304 | } |
305 | } else { |
306 | auto type = dyn_cast<Type>(Val&: attrType); |
307 | for (auto &cyclicReplaceFn : llvm::reverse(C&: typeCycleBreakerFns)) { |
308 | if (std::optional<Type> newRes = cyclicReplaceFn(type)) { |
309 | return newRes->getAsOpaquePointer(); |
310 | } |
311 | } |
312 | } |
313 | return std::nullopt; |
314 | } |
315 | |
316 | //===----------------------------------------------------------------------===// |
317 | // AttrTypeImmediateSubElementWalker |
318 | //===----------------------------------------------------------------------===// |
319 | |
320 | void AttrTypeImmediateSubElementWalker::walk(Attribute element) { |
321 | if (element) |
322 | walkAttrsFn(element); |
323 | } |
324 | |
325 | void AttrTypeImmediateSubElementWalker::walk(Type element) { |
326 | if (element) |
327 | walkTypesFn(element); |
328 | } |
329 | |