| 1 | //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/IR/Operation.h" |
| 10 | #include <optional> |
| 11 | |
| 12 | using namespace mlir; |
| 13 | |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | // AttrTypeWalker |
| 16 | //===----------------------------------------------------------------------===// |
| 17 | |
| 18 | WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) { |
| 19 | return walkImpl(element: attr, walkFns&: attrWalkFns, order); |
| 20 | } |
| 21 | WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) { |
| 22 | return walkImpl(element: type, walkFns&: typeWalkFns, order); |
| 23 | } |
| 24 | |
| 25 | template <typename T, typename WalkFns> |
| 26 | WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns, |
| 27 | WalkOrder order) { |
| 28 | // Check if we've already walk this element before. |
| 29 | auto key = std::make_pair(element.getAsOpaquePointer(), (int)order); |
| 30 | auto [it, inserted] = |
| 31 | visitedAttrTypes.try_emplace(key, WalkResult::advance()); |
| 32 | if (!inserted) |
| 33 | return it->second; |
| 34 | |
| 35 | // If we are walking in post order, walk the sub elements first. |
| 36 | if (order == WalkOrder::PostOrder) { |
| 37 | if (walkSubElements(element, order).wasInterrupted()) |
| 38 | return visitedAttrTypes[key] = WalkResult::interrupt(); |
| 39 | } |
| 40 | |
| 41 | // Walk this element, bailing if skipped or interrupted. |
| 42 | for (auto &walkFn : llvm::reverse(walkFns)) { |
| 43 | WalkResult walkResult = walkFn(element); |
| 44 | if (walkResult.wasInterrupted()) |
| 45 | return visitedAttrTypes[key] = WalkResult::interrupt(); |
| 46 | if (walkResult.wasSkipped()) |
| 47 | return WalkResult::advance(); |
| 48 | } |
| 49 | |
| 50 | // If we are walking in pre-order, walk the sub elements last. |
| 51 | if (order == WalkOrder::PreOrder) { |
| 52 | if (walkSubElements(element, order).wasInterrupted()) |
| 53 | return WalkResult::interrupt(); |
| 54 | } |
| 55 | return WalkResult::advance(); |
| 56 | } |
| 57 | |
| 58 | template <typename T> |
| 59 | WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { |
| 60 | WalkResult result = WalkResult::advance(); |
| 61 | auto walkFn = [&](auto element) { |
| 62 | if (element && !result.wasInterrupted()) |
| 63 | result = walkImpl(element, order); |
| 64 | }; |
| 65 | interface.walkImmediateSubElements(walkFn, walkFn); |
| 66 | return result.wasInterrupted() ? result : WalkResult::advance(); |
| 67 | } |
| 68 | |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | /// AttrTypeReplacerBase |
| 71 | //===----------------------------------------------------------------------===// |
| 72 | |
| 73 | template <typename Concrete> |
| 74 | void detail::AttrTypeReplacerBase<Concrete>::addReplacement( |
| 75 | ReplaceFn<Attribute> fn) { |
| 76 | attrReplacementFns.emplace_back(args: std::move(fn)); |
| 77 | } |
| 78 | |
| 79 | template <typename Concrete> |
| 80 | void detail::AttrTypeReplacerBase<Concrete>::addReplacement( |
| 81 | ReplaceFn<Type> fn) { |
| 82 | typeReplacementFns.push_back(x: std::move(fn)); |
| 83 | } |
| 84 | |
| 85 | template <typename Concrete> |
| 86 | void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn( |
| 87 | Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { |
| 88 | // Functor that replaces the given element if the new value is different, |
| 89 | // otherwise returns nullptr. |
| 90 | auto replaceIfDifferent = [&](auto element) { |
| 91 | auto replacement = static_cast<Concrete *>(this)->replace(element); |
| 92 | return (replacement && replacement != element) ? replacement : nullptr; |
| 93 | }; |
| 94 | |
| 95 | // Update the attribute dictionary. |
| 96 | if (replaceAttrs) { |
| 97 | if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) |
| 98 | op->setAttrs(cast<DictionaryAttr>(newAttrs)); |
| 99 | } |
| 100 | |
| 101 | // If we aren't updating locations or types, we're done. |
| 102 | if (!replaceTypes && !replaceLocs) |
| 103 | return; |
| 104 | |
| 105 | // Update the location. |
| 106 | if (replaceLocs) { |
| 107 | if (Attribute newLoc = replaceIfDifferent(op->getLoc())) |
| 108 | op->setLoc(cast<LocationAttr>(Val&: newLoc)); |
| 109 | } |
| 110 | |
| 111 | // Update the result types. |
| 112 | if (replaceTypes) { |
| 113 | for (OpResult result : op->getResults()) |
| 114 | if (Type newType = replaceIfDifferent(result.getType())) |
| 115 | result.setType(newType); |
| 116 | } |
| 117 | |
| 118 | // Update any nested block arguments. |
| 119 | for (Region ®ion : op->getRegions()) { |
| 120 | for (Block &block : region) { |
| 121 | for (BlockArgument &arg : block.getArguments()) { |
| 122 | if (replaceLocs) { |
| 123 | if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) |
| 124 | arg.setLoc(cast<LocationAttr>(Val&: newLoc)); |
| 125 | } |
| 126 | |
| 127 | if (replaceTypes) { |
| 128 | if (Type newType = replaceIfDifferent(arg.getType())) |
| 129 | arg.setType(newType); |
| 130 | } |
| 131 | } |
| 132 | } |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | template <typename Concrete> |
| 137 | void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn( |
| 138 | Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { |
| 139 | op->walk([&](Operation *nestedOp) { |
| 140 | replaceElementsIn(op: nestedOp, replaceAttrs, replaceLocs, replaceTypes); |
| 141 | }); |
| 142 | } |
| 143 | |
| 144 | template <typename T, typename Replacer> |
| 145 | static void updateSubElementImpl(T element, Replacer &replacer, |
| 146 | SmallVectorImpl<T> &newElements, |
| 147 | FailureOr<bool> &changed) { |
| 148 | // Bail early if we failed at any point. |
| 149 | if (failed(Result: changed)) |
| 150 | return; |
| 151 | |
| 152 | // Guard against potentially null inputs. We always map null to null. |
| 153 | if (!element) { |
| 154 | newElements.push_back(nullptr); |
| 155 | return; |
| 156 | } |
| 157 | |
| 158 | // Replace the element. |
| 159 | if (T result = replacer.replace(element)) { |
| 160 | newElements.push_back(result); |
| 161 | if (result != element) |
| 162 | changed = true; |
| 163 | } else { |
| 164 | changed = failure(); |
| 165 | } |
| 166 | } |
| 167 | |
| 168 | template <typename T, typename Replacer> |
| 169 | static T replaceSubElements(T interface, Replacer &replacer) { |
| 170 | // Walk the current sub-elements, replacing them as necessary. |
| 171 | SmallVector<Attribute, 16> newAttrs; |
| 172 | SmallVector<Type, 16> newTypes; |
| 173 | FailureOr<bool> changed = false; |
| 174 | interface.walkImmediateSubElements( |
| 175 | [&](Attribute element) { |
| 176 | updateSubElementImpl(element, replacer, newAttrs, changed); |
| 177 | }, |
| 178 | [&](Type element) { |
| 179 | updateSubElementImpl(element, replacer, newTypes, changed); |
| 180 | }); |
| 181 | if (failed(Result: changed)) |
| 182 | return nullptr; |
| 183 | |
| 184 | // If any sub-elements changed, use the new elements during the replacement. |
| 185 | T result = interface; |
| 186 | if (*changed) |
| 187 | result = interface.replaceImmediateSubElements(newAttrs, newTypes); |
| 188 | return result; |
| 189 | } |
| 190 | |
| 191 | /// Shared implementation of replacing a given attribute or type element. |
| 192 | template <typename T, typename ReplaceFns, typename Replacer> |
| 193 | static T replaceElementImpl(T element, ReplaceFns &replaceFns, |
| 194 | Replacer &replacer) { |
| 195 | T result = element; |
| 196 | WalkResult walkResult = WalkResult::advance(); |
| 197 | for (auto &replaceFn : llvm::reverse(replaceFns)) { |
| 198 | if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) { |
| 199 | std::tie(result, walkResult) = *newRes; |
| 200 | break; |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | // If an error occurred, return nullptr to indicate failure. |
| 205 | if (walkResult.wasInterrupted() || !result) { |
| 206 | return nullptr; |
| 207 | } |
| 208 | |
| 209 | // Handle replacing sub-elements if this element is also a container. |
| 210 | if (!walkResult.wasSkipped()) { |
| 211 | // Replace the sub elements of this element, bailing if we fail. |
| 212 | if (!(result = replaceSubElements(result, replacer))) { |
| 213 | return nullptr; |
| 214 | } |
| 215 | } |
| 216 | |
| 217 | return result; |
| 218 | } |
| 219 | |
| 220 | template <typename Concrete> |
| 221 | Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) { |
| 222 | return replaceElementImpl(attr, attrReplacementFns, |
| 223 | *static_cast<Concrete *>(this)); |
| 224 | } |
| 225 | |
| 226 | template <typename Concrete> |
| 227 | Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) { |
| 228 | return replaceElementImpl(type, typeReplacementFns, |
| 229 | *static_cast<Concrete *>(this)); |
| 230 | } |
| 231 | |
| 232 | //===----------------------------------------------------------------------===// |
| 233 | /// AttrTypeReplacer |
| 234 | //===----------------------------------------------------------------------===// |
| 235 | |
| 236 | template class detail::AttrTypeReplacerBase<AttrTypeReplacer>; |
| 237 | |
| 238 | template <typename T> |
| 239 | T AttrTypeReplacer::cachedReplaceImpl(T element) { |
| 240 | const void *opaqueElement = element.getAsOpaquePointer(); |
| 241 | auto [it, inserted] = cache.try_emplace(Key: opaqueElement, Args&: opaqueElement); |
| 242 | if (!inserted) |
| 243 | return T::getFromOpaquePointer(it->second); |
| 244 | |
| 245 | T result = replaceBase(element); |
| 246 | |
| 247 | cache[opaqueElement] = result.getAsOpaquePointer(); |
| 248 | return result; |
| 249 | } |
| 250 | |
| 251 | Attribute AttrTypeReplacer::replace(Attribute attr) { |
| 252 | return cachedReplaceImpl(element: attr); |
| 253 | } |
| 254 | |
| 255 | Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(element: type); } |
| 256 | |
| 257 | //===----------------------------------------------------------------------===// |
| 258 | /// CyclicAttrTypeReplacer |
| 259 | //===----------------------------------------------------------------------===// |
| 260 | |
| 261 | template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>; |
| 262 | |
| 263 | CyclicAttrTypeReplacer::CyclicAttrTypeReplacer() |
| 264 | : cache([&](void *attr) { return breakCycleImpl(element: attr); }) {} |
| 265 | |
| 266 | void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) { |
| 267 | attrCycleBreakerFns.emplace_back(args: std::move(fn)); |
| 268 | } |
| 269 | |
| 270 | void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) { |
| 271 | typeCycleBreakerFns.emplace_back(args: std::move(fn)); |
| 272 | } |
| 273 | |
| 274 | template <typename T> |
| 275 | T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) { |
| 276 | void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue(); |
| 277 | CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry = |
| 278 | cache.lookupOrInit(element: opaqueTaggedElement); |
| 279 | if (auto resultOpt = cacheEntry.get()) |
| 280 | return T::getFromOpaquePointer(*resultOpt); |
| 281 | |
| 282 | T result = replaceBase(element); |
| 283 | |
| 284 | cacheEntry.resolve(result: result.getAsOpaquePointer()); |
| 285 | return result; |
| 286 | } |
| 287 | |
| 288 | Attribute CyclicAttrTypeReplacer::replace(Attribute attr) { |
| 289 | return cachedReplaceImpl(element: attr); |
| 290 | } |
| 291 | |
| 292 | Type CyclicAttrTypeReplacer::replace(Type type) { |
| 293 | return cachedReplaceImpl(element: type); |
| 294 | } |
| 295 | |
| 296 | std::optional<const void *> |
| 297 | CyclicAttrTypeReplacer::breakCycleImpl(void *element) { |
| 298 | AttrOrType attrType = AttrOrType::getFromOpaqueValue(VP: element); |
| 299 | if (auto attr = dyn_cast<Attribute>(Val&: attrType)) { |
| 300 | for (auto &cyclicReplaceFn : llvm::reverse(C&: attrCycleBreakerFns)) { |
| 301 | if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) { |
| 302 | return newRes->getAsOpaquePointer(); |
| 303 | } |
| 304 | } |
| 305 | } else { |
| 306 | auto type = dyn_cast<Type>(Val&: attrType); |
| 307 | for (auto &cyclicReplaceFn : llvm::reverse(C&: typeCycleBreakerFns)) { |
| 308 | if (std::optional<Type> newRes = cyclicReplaceFn(type)) { |
| 309 | return newRes->getAsOpaquePointer(); |
| 310 | } |
| 311 | } |
| 312 | } |
| 313 | return std::nullopt; |
| 314 | } |
| 315 | |
| 316 | //===----------------------------------------------------------------------===// |
| 317 | // AttrTypeImmediateSubElementWalker |
| 318 | //===----------------------------------------------------------------------===// |
| 319 | |
| 320 | void AttrTypeImmediateSubElementWalker::walk(Attribute element) { |
| 321 | if (element) |
| 322 | walkAttrsFn(element); |
| 323 | } |
| 324 | |
| 325 | void AttrTypeImmediateSubElementWalker::walk(Type element) { |
| 326 | if (element) |
| 327 | walkTypesFn(element); |
| 328 | } |
| 329 | |