1//===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "LoopAnnotationTranslation.h"
10#include "llvm/IR/DebugInfoMetadata.h"
11
12using namespace mlir;
13using namespace mlir::LLVM;
14using namespace mlir::LLVM::detail;
15
16namespace {
17/// Helper class that keeps the state of one attribute to metadata conversion.
18struct LoopAnnotationConversion {
19 LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
20 LoopAnnotationTranslation &loopAnnotationTranslation,
21 llvm::LLVMContext &ctx)
22 : attr(attr), op(op),
23 loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
24
25 /// Converts this struct's loop annotation into a corresponding LLVMIR
26 /// metadata representation.
27 llvm::MDNode *convert();
28
29 /// Conversion functions for different payload attribute kinds.
30 void addUnitNode(StringRef name);
31 void addUnitNode(StringRef name, BoolAttr attr);
32 void addI32NodeWithVal(StringRef name, uint32_t val);
33 void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false);
34 void convertI32Node(StringRef name, IntegerAttr attr);
35 void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
36 void convertLocation(FusedLoc attr);
37
38 /// Conversion functions for each for each loop annotation sub-attribute.
39 void convertLoopOptions(LoopVectorizeAttr options);
40 void convertLoopOptions(LoopInterleaveAttr options);
41 void convertLoopOptions(LoopUnrollAttr options);
42 void convertLoopOptions(LoopUnrollAndJamAttr options);
43 void convertLoopOptions(LoopLICMAttr options);
44 void convertLoopOptions(LoopDistributeAttr options);
45 void convertLoopOptions(LoopPipelineAttr options);
46 void convertLoopOptions(LoopPeeledAttr options);
47 void convertLoopOptions(LoopUnswitchAttr options);
48
49 LoopAnnotationAttr attr;
50 Operation *op;
51 LoopAnnotationTranslation &loopAnnotationTranslation;
52 llvm::LLVMContext &ctx;
53 llvm::SmallVector<llvm::Metadata *> metadataNodes;
54};
55} // namespace
56
57void LoopAnnotationConversion::addUnitNode(StringRef name) {
58 metadataNodes.push_back(
59 Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name)}));
60}
61
62void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
63 if (attr && attr.getValue())
64 addUnitNode(name);
65}
66
67void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
68 llvm::Constant *cstValue = llvm::ConstantInt::get(
69 Ty: llvm::IntegerType::get(C&: ctx, /*NumBits=*/32), V: val, /*isSigned=*/IsSigned: false);
70 metadataNodes.push_back(
71 Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name),
72 llvm::ConstantAsMetadata::get(C: cstValue)}));
73}
74
75void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
76 bool negated) {
77 if (!attr)
78 return;
79 bool val = negated ^ attr.getValue();
80 llvm::Constant *cstValue = llvm::ConstantInt::getBool(Context&: ctx, V: val);
81 metadataNodes.push_back(
82 Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name),
83 llvm::ConstantAsMetadata::get(C: cstValue)}));
84}
85
86void LoopAnnotationConversion::convertI32Node(StringRef name,
87 IntegerAttr attr) {
88 if (!attr)
89 return;
90 addI32NodeWithVal(name, val: attr.getInt());
91}
92
93void LoopAnnotationConversion::convertFollowupNode(StringRef name,
94 LoopAnnotationAttr attr) {
95 if (!attr)
96 return;
97
98 llvm::MDNode *node =
99 loopAnnotationTranslation.translateLoopAnnotation(attr, op);
100
101 metadataNodes.push_back(
102 Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), node}));
103}
104
105void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) {
106 convertBoolNode(name: "llvm.loop.vectorize.enable", attr: options.getDisable(), negated: true);
107 convertBoolNode(name: "llvm.loop.vectorize.predicate.enable",
108 attr: options.getPredicateEnable());
109 convertBoolNode(name: "llvm.loop.vectorize.scalable.enable",
110 attr: options.getScalableEnable());
111 convertI32Node(name: "llvm.loop.vectorize.width", attr: options.getWidth());
112 convertFollowupNode("llvm.loop.vectorize.followup_vectorized",
113 options.getFollowupVectorized());
114 convertFollowupNode("llvm.loop.vectorize.followup_epilogue",
115 options.getFollowupEpilogue());
116 convertFollowupNode("llvm.loop.vectorize.followup_all",
117 options.getFollowupAll());
118}
119
120void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) {
121 convertI32Node(name: "llvm.loop.interleave.count", attr: options.getCount());
122}
123
124void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) {
125 if (auto disable = options.getDisable())
126 addUnitNode(disable.getValue() ? "llvm.loop.unroll.disable"
127 : "llvm.loop.unroll.enable");
128 convertI32Node(name: "llvm.loop.unroll.count", attr: options.getCount());
129 convertBoolNode(name: "llvm.loop.unroll.runtime.disable",
130 attr: options.getRuntimeDisable());
131 addUnitNode("llvm.loop.unroll.full", options.getFull());
132 convertFollowupNode("llvm.loop.unroll.followup_unrolled",
133 options.getFollowupUnrolled());
134 convertFollowupNode("llvm.loop.unroll.followup_remainder",
135 options.getFollowupRemainder());
136 convertFollowupNode("llvm.loop.unroll.followup_all",
137 options.getFollowupAll());
138}
139
140void LoopAnnotationConversion::convertLoopOptions(
141 LoopUnrollAndJamAttr options) {
142 if (auto disable = options.getDisable())
143 addUnitNode(disable.getValue() ? "llvm.loop.unroll_and_jam.disable"
144 : "llvm.loop.unroll_and_jam.enable");
145 convertI32Node(name: "llvm.loop.unroll_and_jam.count", attr: options.getCount());
146 convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer",
147 options.getFollowupOuter());
148 convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner",
149 options.getFollowupInner());
150 convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer",
151 options.getFollowupRemainderOuter());
152 convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner",
153 options.getFollowupRemainderInner());
154 convertFollowupNode("llvm.loop.unroll_and_jam.followup_all",
155 options.getFollowupAll());
156}
157
158void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) {
159 addUnitNode("llvm.licm.disable", options.getDisable());
160 addUnitNode("llvm.loop.licm_versioning.disable",
161 options.getVersioningDisable());
162}
163
164void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) {
165 convertBoolNode(name: "llvm.loop.distribute.enable", attr: options.getDisable(), negated: true);
166 convertFollowupNode("llvm.loop.distribute.followup_coincident",
167 options.getFollowupCoincident());
168 convertFollowupNode("llvm.loop.distribute.followup_sequential",
169 options.getFollowupSequential());
170 convertFollowupNode("llvm.loop.distribute.followup_fallback",
171 options.getFollowupFallback());
172 convertFollowupNode("llvm.loop.distribute.followup_all",
173 options.getFollowupAll());
174}
175
176void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) {
177 convertBoolNode(name: "llvm.loop.pipeline.disable", attr: options.getDisable());
178 convertI32Node(name: "llvm.loop.pipeline.initiationinterval",
179 attr: options.getInitiationinterval());
180}
181
182void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) {
183 convertI32Node(name: "llvm.loop.peeled.count", attr: options.getCount());
184}
185
186void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) {
187 addUnitNode("llvm.loop.unswitch.partial.disable",
188 options.getPartialDisable());
189}
190
191void LoopAnnotationConversion::convertLocation(FusedLoc location) {
192 auto localScopeAttr =
193 dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
194 if (!localScopeAttr)
195 return;
196 auto *localScope = dyn_cast<llvm::DILocalScope>(
197 loopAnnotationTranslation.moduleTranslation.translateDebugInfo(
198 attr: localScopeAttr));
199 if (!localScope)
200 return;
201 llvm::Metadata *loc =
202 loopAnnotationTranslation.moduleTranslation.translateLoc(loc: location,
203 scope: localScope);
204 metadataNodes.push_back(Elt: loc);
205}
206
207llvm::MDNode *LoopAnnotationConversion::convert() {
208 // Reserve operand 0 for loop id self reference.
209 auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: std::nullopt);
210 metadataNodes.push_back(Elt: dummy.get());
211
212 if (FusedLoc startLoc = attr.getStartLoc())
213 convertLocation(startLoc);
214
215 if (FusedLoc endLoc = attr.getEndLoc())
216 convertLocation(endLoc);
217
218 addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced());
219 addUnitNode("llvm.loop.mustprogress", attr.getMustProgress());
220 // "isvectorized" is encoded as an i32 value.
221 if (BoolAttr isVectorized = attr.getIsVectorized())
222 addI32NodeWithVal(name: "llvm.loop.isvectorized", val: isVectorized.getValue());
223
224 if (auto options = attr.getVectorize())
225 convertLoopOptions(options);
226 if (auto options = attr.getInterleave())
227 convertLoopOptions(options);
228 if (auto options = attr.getUnroll())
229 convertLoopOptions(options);
230 if (auto options = attr.getUnrollAndJam())
231 convertLoopOptions(options);
232 if (auto options = attr.getLicm())
233 convertLoopOptions(options);
234 if (auto options = attr.getDistribute())
235 convertLoopOptions(options);
236 if (auto options = attr.getPipeline())
237 convertLoopOptions(options);
238 if (auto options = attr.getPeeled())
239 convertLoopOptions(options);
240 if (auto options = attr.getUnswitch())
241 convertLoopOptions(options);
242
243 ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses();
244 if (!parallelAccessGroups.empty()) {
245 SmallVector<llvm::Metadata *> parallelAccess;
246 parallelAccess.push_back(
247 Elt: llvm::MDString::get(Context&: ctx, Str: "llvm.loop.parallel_accesses"));
248 for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
249 parallelAccess.push_back(
250 loopAnnotationTranslation.getAccessGroup(accessGroupAttr));
251 metadataNodes.push_back(Elt: llvm::MDNode::get(Context&: ctx, MDs: parallelAccess));
252 }
253
254 // Create loop options and set the first operand to itself.
255 llvm::MDNode *loopMD = llvm::MDNode::get(Context&: ctx, MDs: metadataNodes);
256 loopMD->replaceOperandWith(I: 0, New: loopMD);
257
258 return loopMD;
259}
260
261llvm::MDNode *
262LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
263 Operation *op) {
264 if (!attr)
265 return nullptr;
266
267 llvm::MDNode *loopMD = lookupLoopMetadata(options: attr);
268 if (loopMD)
269 return loopMD;
270
271 loopMD =
272 LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
273 .convert();
274 // Store a map from this Attribute to the LLVM metadata in case we
275 // encounter it again.
276 mapLoopMetadata(options: attr, metadata: loopMD);
277 return loopMD;
278}
279
280llvm::MDNode *
281LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) {
282 auto [result, inserted] =
283 accessGroupMetadataMapping.insert({accessGroupAttr, nullptr});
284 if (inserted)
285 result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
286 return result->second;
287}
288
289llvm::MDNode *
290LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) {
291 ArrayAttr accessGroups = op.getAccessGroupsOrNull();
292 if (!accessGroups || accessGroups.empty())
293 return nullptr;
294
295 SmallVector<llvm::Metadata *> groupMDs;
296 for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>())
297 groupMDs.push_back(getAccessGroup(group));
298 if (groupMDs.size() == 1)
299 return llvm::cast<llvm::MDNode>(Val: groupMDs.front());
300 return llvm::MDNode::get(Context&: llvmModule.getContext(), MDs: groupMDs);
301}
302

// Source code of mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp