| 1 | //===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "LoopAnnotationTranslation.h" |
| 10 | #include "llvm/IR/DebugInfoMetadata.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | using namespace mlir::LLVM; |
| 14 | using namespace mlir::LLVM::detail; |
| 15 | |
| 16 | namespace { |
| 17 | /// Helper class that keeps the state of one attribute to metadata conversion. |
| 18 | struct LoopAnnotationConversion { |
| 19 | LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op, |
| 20 | LoopAnnotationTranslation &loopAnnotationTranslation, |
| 21 | llvm::LLVMContext &ctx) |
| 22 | : attr(attr), op(op), |
| 23 | loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {} |
| 24 | |
| 25 | /// Converts this struct's loop annotation into a corresponding LLVMIR |
| 26 | /// metadata representation. |
| 27 | llvm::MDNode *convert(); |
| 28 | |
| 29 | /// Conversion functions for different payload attribute kinds. |
| 30 | void addUnitNode(StringRef name); |
| 31 | void addUnitNode(StringRef name, BoolAttr attr); |
| 32 | void addI32NodeWithVal(StringRef name, uint32_t val); |
| 33 | void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false); |
| 34 | void convertI32Node(StringRef name, IntegerAttr attr); |
| 35 | void convertFollowupNode(StringRef name, LoopAnnotationAttr attr); |
| 36 | void convertLocation(FusedLoc attr); |
| 37 | |
| 38 | /// Conversion functions for each for each loop annotation sub-attribute. |
| 39 | void convertLoopOptions(LoopVectorizeAttr options); |
| 40 | void convertLoopOptions(LoopInterleaveAttr options); |
| 41 | void convertLoopOptions(LoopUnrollAttr options); |
| 42 | void convertLoopOptions(LoopUnrollAndJamAttr options); |
| 43 | void convertLoopOptions(LoopLICMAttr options); |
| 44 | void convertLoopOptions(LoopDistributeAttr options); |
| 45 | void convertLoopOptions(LoopPipelineAttr options); |
| 46 | void convertLoopOptions(LoopPeeledAttr options); |
| 47 | void convertLoopOptions(LoopUnswitchAttr options); |
| 48 | |
| 49 | LoopAnnotationAttr attr; |
| 50 | Operation *op; |
| 51 | LoopAnnotationTranslation &loopAnnotationTranslation; |
| 52 | llvm::LLVMContext &ctx; |
| 53 | llvm::SmallVector<llvm::Metadata *> metadataNodes; |
| 54 | }; |
| 55 | } // namespace |
| 56 | |
| 57 | void LoopAnnotationConversion::addUnitNode(StringRef name) { |
| 58 | metadataNodes.push_back( |
| 59 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name)})); |
| 60 | } |
| 61 | |
| 62 | void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) { |
| 63 | if (attr && attr.getValue()) |
| 64 | addUnitNode(name); |
| 65 | } |
| 66 | |
| 67 | void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) { |
| 68 | llvm::Constant *cstValue = llvm::ConstantInt::get( |
| 69 | Ty: llvm::IntegerType::get(C&: ctx, /*NumBits=*/32), V: val, /*isSigned=*/IsSigned: false); |
| 70 | metadataNodes.push_back( |
| 71 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), |
| 72 | llvm::ConstantAsMetadata::get(C: cstValue)})); |
| 73 | } |
| 74 | |
| 75 | void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr, |
| 76 | bool negated) { |
| 77 | if (!attr) |
| 78 | return; |
| 79 | bool val = negated ^ attr.getValue(); |
| 80 | llvm::Constant *cstValue = llvm::ConstantInt::getBool(Context&: ctx, V: val); |
| 81 | metadataNodes.push_back( |
| 82 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), |
| 83 | llvm::ConstantAsMetadata::get(C: cstValue)})); |
| 84 | } |
| 85 | |
| 86 | void LoopAnnotationConversion::convertI32Node(StringRef name, |
| 87 | IntegerAttr attr) { |
| 88 | if (!attr) |
| 89 | return; |
| 90 | addI32NodeWithVal(name, val: attr.getInt()); |
| 91 | } |
| 92 | |
| 93 | void LoopAnnotationConversion::convertFollowupNode(StringRef name, |
| 94 | LoopAnnotationAttr attr) { |
| 95 | if (!attr) |
| 96 | return; |
| 97 | |
| 98 | llvm::MDNode *node = |
| 99 | loopAnnotationTranslation.translateLoopAnnotation(attr, op); |
| 100 | |
| 101 | metadataNodes.push_back( |
| 102 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), node})); |
| 103 | } |
| 104 | |
| 105 | void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) { |
| 106 | convertBoolNode(name: "llvm.loop.vectorize.enable" , attr: options.getDisable(), negated: true); |
| 107 | convertBoolNode(name: "llvm.loop.vectorize.predicate.enable" , |
| 108 | attr: options.getPredicateEnable()); |
| 109 | convertBoolNode(name: "llvm.loop.vectorize.scalable.enable" , |
| 110 | attr: options.getScalableEnable()); |
| 111 | convertI32Node(name: "llvm.loop.vectorize.width" , attr: options.getWidth()); |
| 112 | convertFollowupNode("llvm.loop.vectorize.followup_vectorized" , |
| 113 | options.getFollowupVectorized()); |
| 114 | convertFollowupNode("llvm.loop.vectorize.followup_epilogue" , |
| 115 | options.getFollowupEpilogue()); |
| 116 | convertFollowupNode("llvm.loop.vectorize.followup_all" , |
| 117 | options.getFollowupAll()); |
| 118 | } |
| 119 | |
| 120 | void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) { |
| 121 | convertI32Node(name: "llvm.loop.interleave.count" , attr: options.getCount()); |
| 122 | } |
| 123 | |
| 124 | void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) { |
| 125 | if (auto disable = options.getDisable()) |
| 126 | addUnitNode(disable.getValue() ? "llvm.loop.unroll.disable" |
| 127 | : "llvm.loop.unroll.enable" ); |
| 128 | convertI32Node(name: "llvm.loop.unroll.count" , attr: options.getCount()); |
| 129 | convertBoolNode(name: "llvm.loop.unroll.runtime.disable" , |
| 130 | attr: options.getRuntimeDisable()); |
| 131 | addUnitNode("llvm.loop.unroll.full" , options.getFull()); |
| 132 | convertFollowupNode("llvm.loop.unroll.followup_unrolled" , |
| 133 | options.getFollowupUnrolled()); |
| 134 | convertFollowupNode("llvm.loop.unroll.followup_remainder" , |
| 135 | options.getFollowupRemainder()); |
| 136 | convertFollowupNode("llvm.loop.unroll.followup_all" , |
| 137 | options.getFollowupAll()); |
| 138 | } |
| 139 | |
| 140 | void LoopAnnotationConversion::convertLoopOptions( |
| 141 | LoopUnrollAndJamAttr options) { |
| 142 | if (auto disable = options.getDisable()) |
| 143 | addUnitNode(disable.getValue() ? "llvm.loop.unroll_and_jam.disable" |
| 144 | : "llvm.loop.unroll_and_jam.enable" ); |
| 145 | convertI32Node(name: "llvm.loop.unroll_and_jam.count" , attr: options.getCount()); |
| 146 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer" , |
| 147 | options.getFollowupOuter()); |
| 148 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner" , |
| 149 | options.getFollowupInner()); |
| 150 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer" , |
| 151 | options.getFollowupRemainderOuter()); |
| 152 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner" , |
| 153 | options.getFollowupRemainderInner()); |
| 154 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_all" , |
| 155 | options.getFollowupAll()); |
| 156 | } |
| 157 | |
| 158 | void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) { |
| 159 | addUnitNode("llvm.licm.disable" , options.getDisable()); |
| 160 | addUnitNode("llvm.loop.licm_versioning.disable" , |
| 161 | options.getVersioningDisable()); |
| 162 | } |
| 163 | |
| 164 | void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) { |
| 165 | convertBoolNode(name: "llvm.loop.distribute.enable" , attr: options.getDisable(), negated: true); |
| 166 | convertFollowupNode("llvm.loop.distribute.followup_coincident" , |
| 167 | options.getFollowupCoincident()); |
| 168 | convertFollowupNode("llvm.loop.distribute.followup_sequential" , |
| 169 | options.getFollowupSequential()); |
| 170 | convertFollowupNode("llvm.loop.distribute.followup_fallback" , |
| 171 | options.getFollowupFallback()); |
| 172 | convertFollowupNode("llvm.loop.distribute.followup_all" , |
| 173 | options.getFollowupAll()); |
| 174 | } |
| 175 | |
| 176 | void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) { |
| 177 | convertBoolNode(name: "llvm.loop.pipeline.disable" , attr: options.getDisable()); |
| 178 | convertI32Node(name: "llvm.loop.pipeline.initiationinterval" , |
| 179 | attr: options.getInitiationinterval()); |
| 180 | } |
| 181 | |
| 182 | void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) { |
| 183 | convertI32Node(name: "llvm.loop.peeled.count" , attr: options.getCount()); |
| 184 | } |
| 185 | |
| 186 | void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) { |
| 187 | addUnitNode("llvm.loop.unswitch.partial.disable" , |
| 188 | options.getPartialDisable()); |
| 189 | } |
| 190 | |
| 191 | void LoopAnnotationConversion::convertLocation(FusedLoc location) { |
| 192 | auto localScopeAttr = |
| 193 | dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata()); |
| 194 | if (!localScopeAttr) |
| 195 | return; |
| 196 | auto *localScope = dyn_cast<llvm::DILocalScope>( |
| 197 | loopAnnotationTranslation.moduleTranslation.translateDebugInfo( |
| 198 | attr: localScopeAttr)); |
| 199 | if (!localScope) |
| 200 | return; |
| 201 | llvm::Metadata *loc = |
| 202 | loopAnnotationTranslation.moduleTranslation.translateLoc(loc: location, |
| 203 | scope: localScope); |
| 204 | metadataNodes.push_back(Elt: loc); |
| 205 | } |
| 206 | |
| 207 | llvm::MDNode *LoopAnnotationConversion::convert() { |
| 208 | // Reserve operand 0 for loop id self reference. |
| 209 | auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: std::nullopt); |
| 210 | metadataNodes.push_back(Elt: dummy.get()); |
| 211 | |
| 212 | if (FusedLoc startLoc = attr.getStartLoc()) |
| 213 | convertLocation(startLoc); |
| 214 | |
| 215 | if (FusedLoc endLoc = attr.getEndLoc()) |
| 216 | convertLocation(endLoc); |
| 217 | |
| 218 | addUnitNode("llvm.loop.disable_nonforced" , attr.getDisableNonforced()); |
| 219 | addUnitNode("llvm.loop.mustprogress" , attr.getMustProgress()); |
| 220 | // "isvectorized" is encoded as an i32 value. |
| 221 | if (BoolAttr isVectorized = attr.getIsVectorized()) |
| 222 | addI32NodeWithVal(name: "llvm.loop.isvectorized" , val: isVectorized.getValue()); |
| 223 | |
| 224 | if (auto options = attr.getVectorize()) |
| 225 | convertLoopOptions(options); |
| 226 | if (auto options = attr.getInterleave()) |
| 227 | convertLoopOptions(options); |
| 228 | if (auto options = attr.getUnroll()) |
| 229 | convertLoopOptions(options); |
| 230 | if (auto options = attr.getUnrollAndJam()) |
| 231 | convertLoopOptions(options); |
| 232 | if (auto options = attr.getLicm()) |
| 233 | convertLoopOptions(options); |
| 234 | if (auto options = attr.getDistribute()) |
| 235 | convertLoopOptions(options); |
| 236 | if (auto options = attr.getPipeline()) |
| 237 | convertLoopOptions(options); |
| 238 | if (auto options = attr.getPeeled()) |
| 239 | convertLoopOptions(options); |
| 240 | if (auto options = attr.getUnswitch()) |
| 241 | convertLoopOptions(options); |
| 242 | |
| 243 | ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses(); |
| 244 | if (!parallelAccessGroups.empty()) { |
| 245 | SmallVector<llvm::Metadata *> parallelAccess; |
| 246 | parallelAccess.push_back( |
| 247 | Elt: llvm::MDString::get(Context&: ctx, Str: "llvm.loop.parallel_accesses" )); |
| 248 | for (AccessGroupAttr accessGroupAttr : parallelAccessGroups) |
| 249 | parallelAccess.push_back( |
| 250 | loopAnnotationTranslation.getAccessGroup(accessGroupAttr)); |
| 251 | metadataNodes.push_back(Elt: llvm::MDNode::get(Context&: ctx, MDs: parallelAccess)); |
| 252 | } |
| 253 | |
| 254 | // Create loop options and set the first operand to itself. |
| 255 | llvm::MDNode *loopMD = llvm::MDNode::get(Context&: ctx, MDs: metadataNodes); |
| 256 | loopMD->replaceOperandWith(I: 0, New: loopMD); |
| 257 | |
| 258 | return loopMD; |
| 259 | } |
| 260 | |
| 261 | llvm::MDNode * |
| 262 | LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr, |
| 263 | Operation *op) { |
| 264 | if (!attr) |
| 265 | return nullptr; |
| 266 | |
| 267 | llvm::MDNode *loopMD = lookupLoopMetadata(options: attr); |
| 268 | if (loopMD) |
| 269 | return loopMD; |
| 270 | |
| 271 | loopMD = |
| 272 | LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext()) |
| 273 | .convert(); |
| 274 | // Store a map from this Attribute to the LLVM metadata in case we |
| 275 | // encounter it again. |
| 276 | mapLoopMetadata(options: attr, metadata: loopMD); |
| 277 | return loopMD; |
| 278 | } |
| 279 | |
| 280 | llvm::MDNode * |
| 281 | LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) { |
| 282 | auto [result, inserted] = |
| 283 | accessGroupMetadataMapping.try_emplace(accessGroupAttr); |
| 284 | if (inserted) |
| 285 | result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {}); |
| 286 | return result->second; |
| 287 | } |
| 288 | |
| 289 | llvm::MDNode * |
| 290 | LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) { |
| 291 | ArrayAttr accessGroups = op.getAccessGroupsOrNull(); |
| 292 | if (!accessGroups || accessGroups.empty()) |
| 293 | return nullptr; |
| 294 | |
| 295 | SmallVector<llvm::Metadata *> groupMDs; |
| 296 | for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>()) |
| 297 | groupMDs.push_back(getAccessGroup(group)); |
| 298 | if (groupMDs.size() == 1) |
| 299 | return llvm::cast<llvm::MDNode>(Val: groupMDs.front()); |
| 300 | return llvm::MDNode::get(Context&: llvmModule.getContext(), MDs: groupMDs); |
| 301 | } |
| 302 | |