| 1 | //===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "LoopAnnotationImporter.h" |
| 10 | #include "llvm/IR/Constants.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | using namespace mlir::LLVM; |
| 14 | using namespace mlir::LLVM::detail; |
| 15 | |
| 16 | namespace { |
| 17 | /// Helper class that keeps the state of one metadata to attribute conversion. |
| 18 | struct LoopMetadataConversion { |
| 19 | LoopMetadataConversion(const llvm::MDNode *node, Location loc, |
| 20 | LoopAnnotationImporter &loopAnnotationImporter) |
| 21 | : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter), |
| 22 | ctx(loc->getContext()){}; |
| 23 | /// Converts this structs loop metadata node into a LoopAnnotationAttr. |
| 24 | LoopAnnotationAttr convert(); |
| 25 | |
| 26 | /// Initializes the shared state for the conversion member functions. |
| 27 | LogicalResult initConversionState(); |
| 28 | |
| 29 | /// Helper function to get and erase a property. |
| 30 | const llvm::MDNode *lookupAndEraseProperty(StringRef name); |
| 31 | |
| 32 | /// Helper functions to lookup and convert MDNodes into a specifc attribute |
| 33 | /// kind. These functions return null-attributes if there is no node with the |
| 34 | /// specified name, or failure, if the node is ill-formatted. |
| 35 | FailureOr<BoolAttr> lookupUnitNode(StringRef name); |
| 36 | FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false); |
| 37 | FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name); |
| 38 | FailureOr<IntegerAttr> lookupIntNode(StringRef name); |
| 39 | FailureOr<llvm::MDNode *> lookupMDNode(StringRef name); |
| 40 | FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name); |
| 41 | FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name); |
| 42 | FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName, |
| 43 | StringRef disableName, |
| 44 | bool negated = false); |
| 45 | |
| 46 | /// Conversion functions for sub-attributes. |
| 47 | FailureOr<LoopVectorizeAttr> convertVectorizeAttr(); |
| 48 | FailureOr<LoopInterleaveAttr> convertInterleaveAttr(); |
| 49 | FailureOr<LoopUnrollAttr> convertUnrollAttr(); |
| 50 | FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr(); |
| 51 | FailureOr<LoopLICMAttr> convertLICMAttr(); |
| 52 | FailureOr<LoopDistributeAttr> convertDistributeAttr(); |
| 53 | FailureOr<LoopPipelineAttr> convertPipelineAttr(); |
| 54 | FailureOr<LoopPeeledAttr> convertPeeledAttr(); |
| 55 | FailureOr<LoopUnswitchAttr> convertUnswitchAttr(); |
| 56 | FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses(); |
| 57 | FusedLoc convertStartLoc(); |
| 58 | FailureOr<FusedLoc> convertEndLoc(); |
| 59 | |
| 60 | llvm::SmallVector<llvm::DILocation *, 2> locations; |
| 61 | llvm::StringMap<const llvm::MDNode *> propertyMap; |
| 62 | const llvm::MDNode *node; |
| 63 | Location loc; |
| 64 | LoopAnnotationImporter &loopAnnotationImporter; |
| 65 | MLIRContext *ctx; |
| 66 | }; |
| 67 | } // namespace |
| 68 | |
| 69 | LogicalResult LoopMetadataConversion::initConversionState() { |
| 70 | // Check if it's a valid node. |
| 71 | if (node->getNumOperands() == 0 || |
| 72 | dyn_cast<llvm::MDNode>(Val: node->getOperand(I: 0)) != node) |
| 73 | return emitWarning(loc) << "invalid loop node" ; |
| 74 | |
| 75 | for (const llvm::MDOperand &operand : llvm::drop_begin(RangeOrContainer: node->operands())) { |
| 76 | if (auto *diLoc = dyn_cast<llvm::DILocation>(Val: operand)) { |
| 77 | locations.push_back(Elt: diLoc); |
| 78 | continue; |
| 79 | } |
| 80 | |
| 81 | auto *property = dyn_cast<llvm::MDNode>(Val: operand); |
| 82 | if (!property) |
| 83 | return emitWarning(loc) << "expected all loop properties to be either " |
| 84 | "debug locations or metadata nodes" ; |
| 85 | |
| 86 | if (property->getNumOperands() == 0) |
| 87 | return emitWarning(loc) << "cannot import empty loop property" ; |
| 88 | |
| 89 | auto *nameNode = dyn_cast<llvm::MDString>(Val: property->getOperand(I: 0)); |
| 90 | if (!nameNode) |
| 91 | return emitWarning(loc) << "cannot import loop property without a name" ; |
| 92 | StringRef name = nameNode->getString(); |
| 93 | |
| 94 | bool succ = propertyMap.try_emplace(Key: name, Args&: property).second; |
| 95 | if (!succ) |
| 96 | return emitWarning(loc) |
| 97 | << "cannot import loop properties with duplicated names " << name; |
| 98 | } |
| 99 | |
| 100 | return success(); |
| 101 | } |
| 102 | |
| 103 | const llvm::MDNode * |
| 104 | LoopMetadataConversion::lookupAndEraseProperty(StringRef name) { |
| 105 | auto it = propertyMap.find(Key: name); |
| 106 | if (it == propertyMap.end()) |
| 107 | return nullptr; |
| 108 | const llvm::MDNode *property = it->getValue(); |
| 109 | propertyMap.erase(I: it); |
| 110 | return property; |
| 111 | } |
| 112 | |
| 113 | FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) { |
| 114 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
| 115 | if (!property) |
| 116 | return BoolAttr(nullptr); |
| 117 | |
| 118 | if (property->getNumOperands() != 1) |
| 119 | return emitWarning(loc) |
| 120 | << "expected metadata node " << name << " to hold no value" ; |
| 121 | |
| 122 | return BoolAttr::get(context: ctx, value: true); |
| 123 | } |
| 124 | |
| 125 | FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode( |
| 126 | StringRef enableName, StringRef disableName, bool negated) { |
| 127 | auto enable = lookupUnitNode(name: enableName); |
| 128 | auto disable = lookupUnitNode(name: disableName); |
| 129 | if (failed(Result: enable) || failed(Result: disable)) |
| 130 | return failure(); |
| 131 | |
| 132 | if (*enable && *disable) |
| 133 | return emitWarning(loc) |
| 134 | << "expected metadata nodes " << enableName << " and " << disableName |
| 135 | << " to be mutually exclusive." ; |
| 136 | |
| 137 | if (*enable) |
| 138 | return BoolAttr::get(context: ctx, value: !negated); |
| 139 | |
| 140 | if (*disable) |
| 141 | return BoolAttr::get(context: ctx, value: negated); |
| 142 | return BoolAttr(nullptr); |
| 143 | } |
| 144 | |
| 145 | FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name, |
| 146 | bool negated) { |
| 147 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
| 148 | if (!property) |
| 149 | return BoolAttr(nullptr); |
| 150 | |
| 151 | auto emitNodeWarning = [&]() { |
| 152 | return emitWarning(loc) |
| 153 | << "expected metadata node " << name << " to hold a boolean value" ; |
| 154 | }; |
| 155 | |
| 156 | if (property->getNumOperands() != 2) |
| 157 | return emitNodeWarning(); |
| 158 | llvm::ConstantInt *val = |
| 159 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1)); |
| 160 | if (!val || val->getBitWidth() != 1) |
| 161 | return emitNodeWarning(); |
| 162 | |
| 163 | return BoolAttr::get(context: ctx, value: val->getValue().getLimitedValue(Limit: 1) ^ negated); |
| 164 | } |
| 165 | |
| 166 | FailureOr<BoolAttr> |
| 167 | LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) { |
| 168 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
| 169 | if (!property) |
| 170 | return BoolAttr(nullptr); |
| 171 | |
| 172 | auto emitNodeWarning = [&]() { |
| 173 | return emitWarning(loc) |
| 174 | << "expected metadata node " << name << " to hold an integer value" ; |
| 175 | }; |
| 176 | |
| 177 | if (property->getNumOperands() != 2) |
| 178 | return emitNodeWarning(); |
| 179 | llvm::ConstantInt *val = |
| 180 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1)); |
| 181 | if (!val || val->getBitWidth() != 32) |
| 182 | return emitNodeWarning(); |
| 183 | |
| 184 | return BoolAttr::get(context: ctx, value: val->getValue().getLimitedValue(Limit: 1)); |
| 185 | } |
| 186 | |
| 187 | FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) { |
| 188 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
| 189 | if (!property) |
| 190 | return IntegerAttr(nullptr); |
| 191 | |
| 192 | auto emitNodeWarning = [&]() { |
| 193 | return emitWarning(loc) |
| 194 | << "expected metadata node " << name << " to hold an i32 value" ; |
| 195 | }; |
| 196 | |
| 197 | if (property->getNumOperands() != 2) |
| 198 | return emitNodeWarning(); |
| 199 | |
| 200 | llvm::ConstantInt *val = |
| 201 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1)); |
| 202 | if (!val || val->getBitWidth() != 32) |
| 203 | return emitNodeWarning(); |
| 204 | |
| 205 | return IntegerAttr::get(IntegerType::get(ctx, 32), |
| 206 | val->getValue().getLimitedValue()); |
| 207 | } |
| 208 | |
| 209 | FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) { |
| 210 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
| 211 | if (!property) |
| 212 | return nullptr; |
| 213 | |
| 214 | auto emitNodeWarning = [&]() { |
| 215 | return emitWarning(loc) |
| 216 | << "expected metadata node " << name << " to hold an MDNode" ; |
| 217 | }; |
| 218 | |
| 219 | if (property->getNumOperands() != 2) |
| 220 | return emitNodeWarning(); |
| 221 | |
| 222 | auto *node = dyn_cast<llvm::MDNode>(Val: property->getOperand(I: 1)); |
| 223 | if (!node) |
| 224 | return emitNodeWarning(); |
| 225 | |
| 226 | return node; |
| 227 | } |
| 228 | |
| 229 | FailureOr<SmallVector<llvm::MDNode *>> |
| 230 | LoopMetadataConversion::lookupMDNodes(StringRef name) { |
| 231 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
| 232 | SmallVector<llvm::MDNode *> res; |
| 233 | if (!property) |
| 234 | return res; |
| 235 | |
| 236 | auto emitNodeWarning = [&]() { |
| 237 | return emitWarning(loc) << "expected metadata node " << name |
| 238 | << " to hold one or multiple MDNodes" ; |
| 239 | }; |
| 240 | |
| 241 | if (property->getNumOperands() < 2) |
| 242 | return emitNodeWarning(); |
| 243 | |
| 244 | for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) { |
| 245 | auto *node = dyn_cast<llvm::MDNode>(Val: property->getOperand(I: i)); |
| 246 | if (!node) |
| 247 | return emitNodeWarning(); |
| 248 | res.push_back(Elt: node); |
| 249 | } |
| 250 | |
| 251 | return res; |
| 252 | } |
| 253 | |
| 254 | FailureOr<LoopAnnotationAttr> |
| 255 | LoopMetadataConversion::lookupFollowupNode(StringRef name) { |
| 256 | auto node = lookupMDNode(name); |
| 257 | if (failed(Result: node)) |
| 258 | return failure(); |
| 259 | if (*node == nullptr) |
| 260 | return LoopAnnotationAttr(nullptr); |
| 261 | |
| 262 | return loopAnnotationImporter.translateLoopAnnotation(*node, loc); |
| 263 | } |
| 264 | |
| 265 | static bool isEmptyOrNull(const Attribute attr) { return !attr; } |
| 266 | |
| 267 | template <typename T> |
| 268 | static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) { |
| 269 | return vec.empty(); |
| 270 | } |
| 271 | |
| 272 | /// Helper function that only creates and attribute of type T if all argument |
| 273 | /// conversion were successfull and at least one of them holds a non-null value. |
| 274 | template <typename T, typename... P> |
| 275 | static T createIfNonNull(MLIRContext *ctx, const P &...args) { |
| 276 | bool anyFailed = (failed(args) || ...); |
| 277 | if (anyFailed) |
| 278 | return {}; |
| 279 | |
| 280 | bool allEmpty = (isEmptyOrNull(*args) && ...); |
| 281 | if (allEmpty) |
| 282 | return {}; |
| 283 | |
| 284 | return T::get(ctx, *args...); |
| 285 | } |
| 286 | |
| 287 | FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() { |
| 288 | FailureOr<BoolAttr> enable = |
| 289 | lookupBoolNode(name: "llvm.loop.vectorize.enable" , negated: true); |
| 290 | FailureOr<BoolAttr> predicateEnable = |
| 291 | lookupBoolNode(name: "llvm.loop.vectorize.predicate.enable" ); |
| 292 | FailureOr<BoolAttr> scalableEnable = |
| 293 | lookupBoolNode(name: "llvm.loop.vectorize.scalable.enable" ); |
| 294 | FailureOr<IntegerAttr> width = lookupIntNode("llvm.loop.vectorize.width" ); |
| 295 | FailureOr<LoopAnnotationAttr> followupVec = |
| 296 | lookupFollowupNode("llvm.loop.vectorize.followup_vectorized" ); |
| 297 | FailureOr<LoopAnnotationAttr> followupEpi = |
| 298 | lookupFollowupNode("llvm.loop.vectorize.followup_epilogue" ); |
| 299 | FailureOr<LoopAnnotationAttr> followupAll = |
| 300 | lookupFollowupNode("llvm.loop.vectorize.followup_all" ); |
| 301 | |
| 302 | return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable, |
| 303 | scalableEnable, width, followupVec, |
| 304 | followupEpi, followupAll); |
| 305 | } |
| 306 | |
| 307 | FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() { |
| 308 | FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.interleave.count" ); |
| 309 | return createIfNonNull<LoopInterleaveAttr>(ctx, count); |
| 310 | } |
| 311 | |
| 312 | FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() { |
| 313 | FailureOr<BoolAttr> disable = lookupBooleanUnitNode( |
| 314 | enableName: "llvm.loop.unroll.enable" , disableName: "llvm.loop.unroll.disable" , /*negated=*/true); |
| 315 | FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.unroll.count" ); |
| 316 | FailureOr<BoolAttr> runtimeDisable = |
| 317 | lookupUnitNode(name: "llvm.loop.unroll.runtime.disable" ); |
| 318 | FailureOr<BoolAttr> full = lookupUnitNode(name: "llvm.loop.unroll.full" ); |
| 319 | FailureOr<LoopAnnotationAttr> followupUnrolled = |
| 320 | lookupFollowupNode("llvm.loop.unroll.followup_unrolled" ); |
| 321 | FailureOr<LoopAnnotationAttr> followupRemainder = |
| 322 | lookupFollowupNode("llvm.loop.unroll.followup_remainder" ); |
| 323 | FailureOr<LoopAnnotationAttr> followupAll = |
| 324 | lookupFollowupNode("llvm.loop.unroll.followup_all" ); |
| 325 | |
| 326 | return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable, |
| 327 | full, followupUnrolled, |
| 328 | followupRemainder, followupAll); |
| 329 | } |
| 330 | |
| 331 | FailureOr<LoopUnrollAndJamAttr> |
| 332 | LoopMetadataConversion::convertUnrollAndJamAttr() { |
| 333 | FailureOr<BoolAttr> disable = lookupBooleanUnitNode( |
| 334 | enableName: "llvm.loop.unroll_and_jam.enable" , disableName: "llvm.loop.unroll_and_jam.disable" , |
| 335 | /*negated=*/true); |
| 336 | FailureOr<IntegerAttr> count = |
| 337 | lookupIntNode("llvm.loop.unroll_and_jam.count" ); |
| 338 | FailureOr<LoopAnnotationAttr> followupOuter = |
| 339 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer" ); |
| 340 | FailureOr<LoopAnnotationAttr> followupInner = |
| 341 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner" ); |
| 342 | FailureOr<LoopAnnotationAttr> followupRemainderOuter = |
| 343 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer" ); |
| 344 | FailureOr<LoopAnnotationAttr> followupRemainderInner = |
| 345 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner" ); |
| 346 | FailureOr<LoopAnnotationAttr> followupAll = |
| 347 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all" ); |
| 348 | return createIfNonNull<LoopUnrollAndJamAttr>( |
| 349 | ctx, disable, count, followupOuter, followupInner, followupRemainderOuter, |
| 350 | followupRemainderInner, followupAll); |
| 351 | } |
| 352 | |
| 353 | FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() { |
| 354 | FailureOr<BoolAttr> disable = lookupUnitNode(name: "llvm.licm.disable" ); |
| 355 | FailureOr<BoolAttr> versioningDisable = |
| 356 | lookupUnitNode(name: "llvm.loop.licm_versioning.disable" ); |
| 357 | return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable); |
| 358 | } |
| 359 | |
| 360 | FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() { |
| 361 | FailureOr<BoolAttr> disable = |
| 362 | lookupBoolNode(name: "llvm.loop.distribute.enable" , negated: true); |
| 363 | FailureOr<LoopAnnotationAttr> followupCoincident = |
| 364 | lookupFollowupNode("llvm.loop.distribute.followup_coincident" ); |
| 365 | FailureOr<LoopAnnotationAttr> followupSequential = |
| 366 | lookupFollowupNode("llvm.loop.distribute.followup_sequential" ); |
| 367 | FailureOr<LoopAnnotationAttr> followupFallback = |
| 368 | lookupFollowupNode("llvm.loop.distribute.followup_fallback" ); |
| 369 | FailureOr<LoopAnnotationAttr> followupAll = |
| 370 | lookupFollowupNode("llvm.loop.distribute.followup_all" ); |
| 371 | return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident, |
| 372 | followupSequential, |
| 373 | followupFallback, followupAll); |
| 374 | } |
| 375 | |
| 376 | FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() { |
| 377 | FailureOr<BoolAttr> disable = lookupBoolNode(name: "llvm.loop.pipeline.disable" ); |
| 378 | FailureOr<IntegerAttr> initiationinterval = |
| 379 | lookupIntNode("llvm.loop.pipeline.initiationinterval" ); |
| 380 | return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval); |
| 381 | } |
| 382 | |
| 383 | FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() { |
| 384 | FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.peeled.count" ); |
| 385 | return createIfNonNull<LoopPeeledAttr>(ctx, count); |
| 386 | } |
| 387 | |
| 388 | FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() { |
| 389 | FailureOr<BoolAttr> partialDisable = |
| 390 | lookupUnitNode(name: "llvm.loop.unswitch.partial.disable" ); |
| 391 | return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable); |
| 392 | } |
| 393 | |
| 394 | FailureOr<SmallVector<AccessGroupAttr>> |
| 395 | LoopMetadataConversion::convertParallelAccesses() { |
| 396 | FailureOr<SmallVector<llvm::MDNode *>> nodes = |
| 397 | lookupMDNodes(name: "llvm.loop.parallel_accesses" ); |
| 398 | if (failed(Result: nodes)) |
| 399 | return failure(); |
| 400 | SmallVector<AccessGroupAttr> refs; |
| 401 | for (llvm::MDNode *node : *nodes) { |
| 402 | FailureOr<SmallVector<AccessGroupAttr>> accessGroups = |
| 403 | loopAnnotationImporter.lookupAccessGroupAttrs(node); |
| 404 | if (failed(accessGroups)) { |
| 405 | emitWarning(loc) << "could not lookup access group" ; |
| 406 | continue; |
| 407 | } |
| 408 | llvm::append_range(refs, *accessGroups); |
| 409 | } |
| 410 | return refs; |
| 411 | } |
| 412 | |
| 413 | FusedLoc LoopMetadataConversion::convertStartLoc() { |
| 414 | if (locations.empty()) |
| 415 | return {}; |
| 416 | return dyn_cast<FusedLoc>( |
| 417 | loopAnnotationImporter.moduleImport.translateLoc(locations[0])); |
| 418 | } |
| 419 | |
| 420 | FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() { |
| 421 | if (locations.size() < 2) |
| 422 | return FusedLoc(); |
| 423 | if (locations.size() > 2) |
| 424 | return emitError(loc) |
| 425 | << "expected loop metadata to have at most two DILocations" ; |
| 426 | return dyn_cast<FusedLoc>( |
| 427 | loopAnnotationImporter.moduleImport.translateLoc(locations[1])); |
| 428 | } |
| 429 | |
| 430 | LoopAnnotationAttr LoopMetadataConversion::convert() { |
| 431 | if (failed(Result: initConversionState())) |
| 432 | return {}; |
| 433 | |
| 434 | FailureOr<BoolAttr> disableNonForced = |
| 435 | lookupUnitNode(name: "llvm.loop.disable_nonforced" ); |
| 436 | FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr(); |
| 437 | FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr(); |
| 438 | FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr(); |
| 439 | FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr(); |
| 440 | FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr(); |
| 441 | FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr(); |
| 442 | FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr(); |
| 443 | FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr(); |
| 444 | FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr(); |
| 445 | FailureOr<BoolAttr> mustProgress = lookupUnitNode(name: "llvm.loop.mustprogress" ); |
| 446 | FailureOr<BoolAttr> isVectorized = |
| 447 | lookupIntNodeAsBoolAttr(name: "llvm.loop.isvectorized" ); |
| 448 | FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses = |
| 449 | convertParallelAccesses(); |
| 450 | |
| 451 | // Drop the metadata if there are parts that cannot be imported. |
| 452 | if (!propertyMap.empty()) { |
| 453 | for (auto name : propertyMap.keys()) |
| 454 | emitWarning(loc) << "unknown loop annotation " << name; |
| 455 | return {}; |
| 456 | } |
| 457 | |
| 458 | FailureOr<FusedLoc> startLoc = convertStartLoc(); |
| 459 | FailureOr<FusedLoc> endLoc = convertEndLoc(); |
| 460 | |
| 461 | return createIfNonNull<LoopAnnotationAttr>( |
| 462 | ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr, |
| 463 | unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr, |
| 464 | unswitchAttr, mustProgress, isVectorized, startLoc, endLoc, |
| 465 | parallelAccesses); |
| 466 | } |
| 467 | |
| 468 | LoopAnnotationAttr |
| 469 | LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node, |
| 470 | Location loc) { |
| 471 | if (!node) |
| 472 | return {}; |
| 473 | |
| 474 | // Note: This check is necessary to distinguish between failed translations |
| 475 | // and not yet attempted translations. |
| 476 | auto it = loopMetadataMapping.find(node); |
| 477 | if (it != loopMetadataMapping.end()) |
| 478 | return it->getSecond(); |
| 479 | |
| 480 | LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert(); |
| 481 | |
| 482 | mapLoopMetadata(node, attr); |
| 483 | return attr; |
| 484 | } |
| 485 | |
| 486 | LogicalResult |
| 487 | LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node, |
| 488 | Location loc) { |
| 489 | SmallVector<const llvm::MDNode *> accessGroups; |
| 490 | if (!node->getNumOperands()) |
| 491 | accessGroups.push_back(Elt: node); |
| 492 | for (const llvm::MDOperand &operand : node->operands()) { |
| 493 | auto *childNode = dyn_cast<llvm::MDNode>(Val: operand); |
| 494 | if (!childNode) |
| 495 | return failure(); |
| 496 | accessGroups.push_back(Elt: cast<llvm::MDNode>(Val: operand.get())); |
| 497 | } |
| 498 | |
| 499 | // Convert all entries of the access group list to access group operations. |
| 500 | for (const llvm::MDNode *accessGroup : accessGroups) { |
| 501 | if (accessGroupMapping.count(accessGroup)) |
| 502 | continue; |
| 503 | // Verify the access group node is distinct and empty. |
| 504 | if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) |
| 505 | return emitWarning(loc) |
| 506 | << "expected an access group node to be empty and distinct" ; |
| 507 | |
| 508 | // Add a mapping from the access group node to the newly created attribute. |
| 509 | accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>(); |
| 510 | } |
| 511 | return success(); |
| 512 | } |
| 513 | |
| 514 | FailureOr<SmallVector<AccessGroupAttr>> |
| 515 | LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { |
| 516 | // An access group node is either a single access group or an access group |
| 517 | // list. |
| 518 | SmallVector<AccessGroupAttr> accessGroups; |
| 519 | if (!node->getNumOperands()) |
| 520 | accessGroups.push_back(accessGroupMapping.lookup(node)); |
| 521 | for (const llvm::MDOperand &operand : node->operands()) { |
| 522 | auto *node = cast<llvm::MDNode>(Val: operand.get()); |
| 523 | accessGroups.push_back(accessGroupMapping.lookup(node)); |
| 524 | } |
| 525 | // Exit if one of the access group node lookups failed. |
| 526 | if (llvm::is_contained(accessGroups, nullptr)) |
| 527 | return failure(); |
| 528 | return accessGroups; |
| 529 | } |
| 530 | |