1 | //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Operation.h" |
10 | #include <optional> |
11 | |
12 | using namespace mlir; |
13 | |
14 | //===----------------------------------------------------------------------===// |
15 | // AttrTypeWalker |
16 | //===----------------------------------------------------------------------===// |
17 | |
18 | WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) { |
19 | return walkImpl(element: attr, walkFns&: attrWalkFns, order); |
20 | } |
21 | WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) { |
22 | return walkImpl(element: type, walkFns&: typeWalkFns, order); |
23 | } |
24 | |
25 | template <typename T, typename WalkFns> |
26 | WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns, |
27 | WalkOrder order) { |
28 | // Check if we've already walk this element before. |
29 | auto key = std::make_pair(element.getAsOpaquePointer(), (int)order); |
30 | auto it = visitedAttrTypes.find(key); |
31 | if (it != visitedAttrTypes.end()) |
32 | return it->second; |
33 | visitedAttrTypes.try_emplace(key, WalkResult::advance()); |
34 | |
35 | // If we are walking in post order, walk the sub elements first. |
36 | if (order == WalkOrder::PostOrder) { |
37 | if (walkSubElements(element, order).wasInterrupted()) |
38 | return visitedAttrTypes[key] = WalkResult::interrupt(); |
39 | } |
40 | |
41 | // Walk this element, bailing if skipped or interrupted. |
42 | for (auto &walkFn : llvm::reverse(walkFns)) { |
43 | WalkResult walkResult = walkFn(element); |
44 | if (walkResult.wasInterrupted()) |
45 | return visitedAttrTypes[key] = WalkResult::interrupt(); |
46 | if (walkResult.wasSkipped()) |
47 | return WalkResult::advance(); |
48 | } |
49 | |
50 | // If we are walking in pre-order, walk the sub elements last. |
51 | if (order == WalkOrder::PreOrder) { |
52 | if (walkSubElements(element, order).wasInterrupted()) |
53 | return WalkResult::interrupt(); |
54 | } |
55 | return WalkResult::advance(); |
56 | } |
57 | |
58 | template <typename T> |
59 | WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { |
60 | WalkResult result = WalkResult::advance(); |
61 | auto walkFn = [&](auto element) { |
62 | if (element && !result.wasInterrupted()) |
63 | result = walkImpl(element, order); |
64 | }; |
65 | interface.walkImmediateSubElements(walkFn, walkFn); |
66 | return result.wasInterrupted() ? result : WalkResult::advance(); |
67 | } |
68 | |
69 | //===----------------------------------------------------------------------===// |
// AttrTypeReplacer
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) { |
74 | attrReplacementFns.emplace_back(args: std::move(fn)); |
75 | } |
76 | void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) { |
77 | typeReplacementFns.push_back(x: std::move(fn)); |
78 | } |
79 | |
80 | void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, |
81 | bool replaceLocs, bool replaceTypes) { |
82 | // Functor that replaces the given element if the new value is different, |
83 | // otherwise returns nullptr. |
84 | auto replaceIfDifferent = [&](auto element) { |
85 | auto replacement = replace(element); |
86 | return (replacement && replacement != element) ? replacement : nullptr; |
87 | }; |
88 | |
89 | // Update the attribute dictionary. |
90 | if (replaceAttrs) { |
91 | if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) |
92 | op->setAttrs(cast<DictionaryAttr>(newAttrs)); |
93 | } |
94 | |
95 | // If we aren't updating locations or types, we're done. |
96 | if (!replaceTypes && !replaceLocs) |
97 | return; |
98 | |
99 | // Update the location. |
100 | if (replaceLocs) { |
101 | if (Attribute newLoc = replaceIfDifferent(op->getLoc())) |
102 | op->setLoc(cast<LocationAttr>(Val&: newLoc)); |
103 | } |
104 | |
105 | // Update the result types. |
106 | if (replaceTypes) { |
107 | for (OpResult result : op->getResults()) |
108 | if (Type newType = replaceIfDifferent(result.getType())) |
109 | result.setType(newType); |
110 | } |
111 | |
112 | // Update any nested block arguments. |
113 | for (Region ®ion : op->getRegions()) { |
114 | for (Block &block : region) { |
115 | for (BlockArgument &arg : block.getArguments()) { |
116 | if (replaceLocs) { |
117 | if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) |
118 | arg.setLoc(cast<LocationAttr>(Val&: newLoc)); |
119 | } |
120 | |
121 | if (replaceTypes) { |
122 | if (Type newType = replaceIfDifferent(arg.getType())) |
123 | arg.setType(newType); |
124 | } |
125 | } |
126 | } |
127 | } |
128 | } |
129 | |
130 | void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op, |
131 | bool replaceAttrs, |
132 | bool replaceLocs, |
133 | bool replaceTypes) { |
134 | op->walk(callback: [&](Operation *nestedOp) { |
135 | replaceElementsIn(op: nestedOp, replaceAttrs, replaceLocs, replaceTypes); |
136 | }); |
137 | } |
138 | |
139 | template <typename T> |
140 | static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, |
141 | SmallVectorImpl<T> &newElements, |
142 | FailureOr<bool> &changed) { |
143 | // Bail early if we failed at any point. |
144 | if (failed(result: changed)) |
145 | return; |
146 | |
147 | // Guard against potentially null inputs. We always map null to null. |
148 | if (!element) { |
149 | newElements.push_back(nullptr); |
150 | return; |
151 | } |
152 | |
153 | // Replace the element. |
154 | if (T result = replacer.replace(element)) { |
155 | newElements.push_back(result); |
156 | if (result != element) |
157 | changed = true; |
158 | } else { |
159 | changed = failure(); |
160 | } |
161 | } |
162 | |
163 | template <typename T> |
164 | T AttrTypeReplacer::replaceSubElements(T interface) { |
165 | // Walk the current sub-elements, replacing them as necessary. |
166 | SmallVector<Attribute, 16> newAttrs; |
167 | SmallVector<Type, 16> newTypes; |
168 | FailureOr<bool> changed = false; |
169 | interface.walkImmediateSubElements( |
170 | [&](Attribute element) { |
171 | updateSubElementImpl(element, replacer&: *this, newElements&: newAttrs, changed); |
172 | }, |
173 | [&](Type element) { |
174 | updateSubElementImpl(element, replacer&: *this, newElements&: newTypes, changed); |
175 | }); |
176 | if (failed(result: changed)) |
177 | return nullptr; |
178 | |
179 | // If any sub-elements changed, use the new elements during the replacement. |
180 | T result = interface; |
181 | if (*changed) |
182 | result = interface.replaceImmediateSubElements(newAttrs, newTypes); |
183 | return result; |
184 | } |
185 | |
186 | /// Shared implementation of replacing a given attribute or type element. |
187 | template <typename T, typename ReplaceFns> |
188 | T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { |
189 | const void *opaqueElement = element.getAsOpaquePointer(); |
190 | auto [it, inserted] = attrTypeMap.try_emplace(Key: opaqueElement, Args&: opaqueElement); |
191 | if (!inserted) |
192 | return T::getFromOpaquePointer(it->second); |
193 | |
194 | T result = element; |
195 | WalkResult walkResult = WalkResult::advance(); |
196 | for (auto &replaceFn : llvm::reverse(replaceFns)) { |
197 | if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) { |
198 | std::tie(result, walkResult) = *newRes; |
199 | break; |
200 | } |
201 | } |
202 | |
203 | // If an error occurred, return nullptr to indicate failure. |
204 | if (walkResult.wasInterrupted() || !result) { |
205 | attrTypeMap[opaqueElement] = nullptr; |
206 | return nullptr; |
207 | } |
208 | |
209 | // Handle replacing sub-elements if this element is also a container. |
210 | if (!walkResult.wasSkipped()) { |
211 | // Replace the sub elements of this element, bailing if we fail. |
212 | if (!(result = replaceSubElements(result))) { |
213 | attrTypeMap[opaqueElement] = nullptr; |
214 | return nullptr; |
215 | } |
216 | } |
217 | |
218 | attrTypeMap[opaqueElement] = result.getAsOpaquePointer(); |
219 | return result; |
220 | } |
221 | |
222 | Attribute AttrTypeReplacer::replace(Attribute attr) { |
223 | return replaceImpl(element: attr, replaceFns&: attrReplacementFns); |
224 | } |
225 | |
226 | Type AttrTypeReplacer::replace(Type type) { |
227 | return replaceImpl(element: type, replaceFns&: typeReplacementFns); |
228 | } |
229 | |
230 | //===----------------------------------------------------------------------===// |
231 | // AttrTypeImmediateSubElementWalker |
232 | //===----------------------------------------------------------------------===// |
233 | |
234 | void AttrTypeImmediateSubElementWalker::walk(Attribute element) { |
235 | if (element) |
236 | walkAttrsFn(element); |
237 | } |
238 | |
239 | void AttrTypeImmediateSubElementWalker::walk(Type element) { |
240 | if (element) |
241 | walkTypesFn(element); |
242 | } |
243 | |