1//===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Operation.h"
10#include <optional>
11
12using namespace mlir;
13
14//===----------------------------------------------------------------------===//
15// AttrTypeWalker
16//===----------------------------------------------------------------------===//
17
18WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
19 return walkImpl(element: attr, walkFns&: attrWalkFns, order);
20}
21WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
22 return walkImpl(element: type, walkFns&: typeWalkFns, order);
23}
24
25template <typename T, typename WalkFns>
26WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
27 WalkOrder order) {
28 // Check if we've already walk this element before.
29 auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
30 auto it = visitedAttrTypes.find(key);
31 if (it != visitedAttrTypes.end())
32 return it->second;
33 visitedAttrTypes.try_emplace(key, WalkResult::advance());
34
35 // If we are walking in post order, walk the sub elements first.
36 if (order == WalkOrder::PostOrder) {
37 if (walkSubElements(element, order).wasInterrupted())
38 return visitedAttrTypes[key] = WalkResult::interrupt();
39 }
40
41 // Walk this element, bailing if skipped or interrupted.
42 for (auto &walkFn : llvm::reverse(walkFns)) {
43 WalkResult walkResult = walkFn(element);
44 if (walkResult.wasInterrupted())
45 return visitedAttrTypes[key] = WalkResult::interrupt();
46 if (walkResult.wasSkipped())
47 return WalkResult::advance();
48 }
49
50 // If we are walking in pre-order, walk the sub elements last.
51 if (order == WalkOrder::PreOrder) {
52 if (walkSubElements(element, order).wasInterrupted())
53 return WalkResult::interrupt();
54 }
55 return WalkResult::advance();
56}
57
58template <typename T>
59WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
60 WalkResult result = WalkResult::advance();
61 auto walkFn = [&](auto element) {
62 if (element && !result.wasInterrupted())
63 result = walkImpl(element, order);
64 };
65 interface.walkImmediateSubElements(walkFn, walkFn);
66 return result.wasInterrupted() ? result : WalkResult::advance();
67}
68
69//===----------------------------------------------------------------------===//
// AttrTypeReplacer
71//===----------------------------------------------------------------------===//
72
73void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
74 attrReplacementFns.emplace_back(args: std::move(fn));
75}
76void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
77 typeReplacementFns.push_back(x: std::move(fn));
78}
79
80void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
81 bool replaceLocs, bool replaceTypes) {
82 // Functor that replaces the given element if the new value is different,
83 // otherwise returns nullptr.
84 auto replaceIfDifferent = [&](auto element) {
85 auto replacement = replace(element);
86 return (replacement && replacement != element) ? replacement : nullptr;
87 };
88
89 // Update the attribute dictionary.
90 if (replaceAttrs) {
91 if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
92 op->setAttrs(cast<DictionaryAttr>(newAttrs));
93 }
94
95 // If we aren't updating locations or types, we're done.
96 if (!replaceTypes && !replaceLocs)
97 return;
98
99 // Update the location.
100 if (replaceLocs) {
101 if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
102 op->setLoc(cast<LocationAttr>(Val&: newLoc));
103 }
104
105 // Update the result types.
106 if (replaceTypes) {
107 for (OpResult result : op->getResults())
108 if (Type newType = replaceIfDifferent(result.getType()))
109 result.setType(newType);
110 }
111
112 // Update any nested block arguments.
113 for (Region &region : op->getRegions()) {
114 for (Block &block : region) {
115 for (BlockArgument &arg : block.getArguments()) {
116 if (replaceLocs) {
117 if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
118 arg.setLoc(cast<LocationAttr>(Val&: newLoc));
119 }
120
121 if (replaceTypes) {
122 if (Type newType = replaceIfDifferent(arg.getType()))
123 arg.setType(newType);
124 }
125 }
126 }
127 }
128}
129
130void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
131 bool replaceAttrs,
132 bool replaceLocs,
133 bool replaceTypes) {
134 op->walk(callback: [&](Operation *nestedOp) {
135 replaceElementsIn(op: nestedOp, replaceAttrs, replaceLocs, replaceTypes);
136 });
137}
138
139template <typename T>
140static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
141 SmallVectorImpl<T> &newElements,
142 FailureOr<bool> &changed) {
143 // Bail early if we failed at any point.
144 if (failed(result: changed))
145 return;
146
147 // Guard against potentially null inputs. We always map null to null.
148 if (!element) {
149 newElements.push_back(nullptr);
150 return;
151 }
152
153 // Replace the element.
154 if (T result = replacer.replace(element)) {
155 newElements.push_back(result);
156 if (result != element)
157 changed = true;
158 } else {
159 changed = failure();
160 }
161}
162
163template <typename T>
164T AttrTypeReplacer::replaceSubElements(T interface) {
165 // Walk the current sub-elements, replacing them as necessary.
166 SmallVector<Attribute, 16> newAttrs;
167 SmallVector<Type, 16> newTypes;
168 FailureOr<bool> changed = false;
169 interface.walkImmediateSubElements(
170 [&](Attribute element) {
171 updateSubElementImpl(element, replacer&: *this, newElements&: newAttrs, changed);
172 },
173 [&](Type element) {
174 updateSubElementImpl(element, replacer&: *this, newElements&: newTypes, changed);
175 });
176 if (failed(result: changed))
177 return nullptr;
178
179 // If any sub-elements changed, use the new elements during the replacement.
180 T result = interface;
181 if (*changed)
182 result = interface.replaceImmediateSubElements(newAttrs, newTypes);
183 return result;
184}
185
186/// Shared implementation of replacing a given attribute or type element.
187template <typename T, typename ReplaceFns>
188T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
189 const void *opaqueElement = element.getAsOpaquePointer();
190 auto [it, inserted] = attrTypeMap.try_emplace(Key: opaqueElement, Args&: opaqueElement);
191 if (!inserted)
192 return T::getFromOpaquePointer(it->second);
193
194 T result = element;
195 WalkResult walkResult = WalkResult::advance();
196 for (auto &replaceFn : llvm::reverse(replaceFns)) {
197 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
198 std::tie(result, walkResult) = *newRes;
199 break;
200 }
201 }
202
203 // If an error occurred, return nullptr to indicate failure.
204 if (walkResult.wasInterrupted() || !result) {
205 attrTypeMap[opaqueElement] = nullptr;
206 return nullptr;
207 }
208
209 // Handle replacing sub-elements if this element is also a container.
210 if (!walkResult.wasSkipped()) {
211 // Replace the sub elements of this element, bailing if we fail.
212 if (!(result = replaceSubElements(result))) {
213 attrTypeMap[opaqueElement] = nullptr;
214 return nullptr;
215 }
216 }
217
218 attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
219 return result;
220}
221
222Attribute AttrTypeReplacer::replace(Attribute attr) {
223 return replaceImpl(element: attr, replaceFns&: attrReplacementFns);
224}
225
226Type AttrTypeReplacer::replace(Type type) {
227 return replaceImpl(element: type, replaceFns&: typeReplacementFns);
228}
229
230//===----------------------------------------------------------------------===//
231// AttrTypeImmediateSubElementWalker
232//===----------------------------------------------------------------------===//
233
234void AttrTypeImmediateSubElementWalker::walk(Attribute element) {
235 if (element)
236 walkAttrsFn(element);
237}
238
239void AttrTypeImmediateSubElementWalker::walk(Type element) {
240 if (element)
241 walkTypesFn(element);
242}
243

source code of mlir/lib/IR/AttrTypeSubElements.cpp