1 | //===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "LoopAnnotationTranslation.h" |
10 | #include "llvm/IR/DebugInfoMetadata.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::LLVM; |
14 | using namespace mlir::LLVM::detail; |
15 | |
16 | namespace { |
17 | /// Helper class that keeps the state of one attribute to metadata conversion. |
18 | struct LoopAnnotationConversion { |
19 | LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op, |
20 | LoopAnnotationTranslation &loopAnnotationTranslation, |
21 | llvm::LLVMContext &ctx) |
22 | : attr(attr), op(op), |
23 | loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {} |
24 | |
25 | /// Converts this struct's loop annotation into a corresponding LLVMIR |
26 | /// metadata representation. |
27 | llvm::MDNode *convert(); |
28 | |
29 | /// Conversion functions for different payload attribute kinds. |
30 | void addUnitNode(StringRef name); |
31 | void addUnitNode(StringRef name, BoolAttr attr); |
32 | void addI32NodeWithVal(StringRef name, uint32_t val); |
33 | void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false); |
34 | void convertI32Node(StringRef name, IntegerAttr attr); |
35 | void convertFollowupNode(StringRef name, LoopAnnotationAttr attr); |
36 | void convertLocation(FusedLoc attr); |
37 | |
38 | /// Conversion functions for each for each loop annotation sub-attribute. |
39 | void convertLoopOptions(LoopVectorizeAttr options); |
40 | void convertLoopOptions(LoopInterleaveAttr options); |
41 | void convertLoopOptions(LoopUnrollAttr options); |
42 | void convertLoopOptions(LoopUnrollAndJamAttr options); |
43 | void convertLoopOptions(LoopLICMAttr options); |
44 | void convertLoopOptions(LoopDistributeAttr options); |
45 | void convertLoopOptions(LoopPipelineAttr options); |
46 | void convertLoopOptions(LoopPeeledAttr options); |
47 | void convertLoopOptions(LoopUnswitchAttr options); |
48 | |
49 | LoopAnnotationAttr attr; |
50 | Operation *op; |
51 | LoopAnnotationTranslation &loopAnnotationTranslation; |
52 | llvm::LLVMContext &ctx; |
53 | llvm::SmallVector<llvm::Metadata *> metadataNodes; |
54 | }; |
55 | } // namespace |
56 | |
57 | void LoopAnnotationConversion::addUnitNode(StringRef name) { |
58 | metadataNodes.push_back( |
59 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name)})); |
60 | } |
61 | |
62 | void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) { |
63 | if (attr && attr.getValue()) |
64 | addUnitNode(name); |
65 | } |
66 | |
67 | void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) { |
68 | llvm::Constant *cstValue = llvm::ConstantInt::get( |
69 | Ty: llvm::IntegerType::get(C&: ctx, /*NumBits=*/32), V: val, /*isSigned=*/IsSigned: false); |
70 | metadataNodes.push_back( |
71 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), |
72 | llvm::ConstantAsMetadata::get(C: cstValue)})); |
73 | } |
74 | |
75 | void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr, |
76 | bool negated) { |
77 | if (!attr) |
78 | return; |
79 | bool val = negated ^ attr.getValue(); |
80 | llvm::Constant *cstValue = llvm::ConstantInt::getBool(Context&: ctx, V: val); |
81 | metadataNodes.push_back( |
82 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), |
83 | llvm::ConstantAsMetadata::get(C: cstValue)})); |
84 | } |
85 | |
86 | void LoopAnnotationConversion::convertI32Node(StringRef name, |
87 | IntegerAttr attr) { |
88 | if (!attr) |
89 | return; |
90 | addI32NodeWithVal(name, val: attr.getInt()); |
91 | } |
92 | |
93 | void LoopAnnotationConversion::convertFollowupNode(StringRef name, |
94 | LoopAnnotationAttr attr) { |
95 | if (!attr) |
96 | return; |
97 | |
98 | llvm::MDNode *node = |
99 | loopAnnotationTranslation.translateLoopAnnotation(attr, op); |
100 | |
101 | metadataNodes.push_back( |
102 | Elt: llvm::MDNode::get(Context&: ctx, MDs: {llvm::MDString::get(Context&: ctx, Str: name), node})); |
103 | } |
104 | |
105 | void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) { |
106 | convertBoolNode(name: "llvm.loop.vectorize.enable" , attr: options.getDisable(), negated: true); |
107 | convertBoolNode(name: "llvm.loop.vectorize.predicate.enable" , |
108 | attr: options.getPredicateEnable()); |
109 | convertBoolNode(name: "llvm.loop.vectorize.scalable.enable" , |
110 | attr: options.getScalableEnable()); |
111 | convertI32Node(name: "llvm.loop.vectorize.width" , attr: options.getWidth()); |
112 | convertFollowupNode("llvm.loop.vectorize.followup_vectorized" , |
113 | options.getFollowupVectorized()); |
114 | convertFollowupNode("llvm.loop.vectorize.followup_epilogue" , |
115 | options.getFollowupEpilogue()); |
116 | convertFollowupNode("llvm.loop.vectorize.followup_all" , |
117 | options.getFollowupAll()); |
118 | } |
119 | |
120 | void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) { |
121 | convertI32Node(name: "llvm.loop.interleave.count" , attr: options.getCount()); |
122 | } |
123 | |
124 | void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) { |
125 | if (auto disable = options.getDisable()) |
126 | addUnitNode(disable.getValue() ? "llvm.loop.unroll.disable" |
127 | : "llvm.loop.unroll.enable" ); |
128 | convertI32Node(name: "llvm.loop.unroll.count" , attr: options.getCount()); |
129 | convertBoolNode(name: "llvm.loop.unroll.runtime.disable" , |
130 | attr: options.getRuntimeDisable()); |
131 | addUnitNode("llvm.loop.unroll.full" , options.getFull()); |
132 | convertFollowupNode("llvm.loop.unroll.followup_unrolled" , |
133 | options.getFollowupUnrolled()); |
134 | convertFollowupNode("llvm.loop.unroll.followup_remainder" , |
135 | options.getFollowupRemainder()); |
136 | convertFollowupNode("llvm.loop.unroll.followup_all" , |
137 | options.getFollowupAll()); |
138 | } |
139 | |
140 | void LoopAnnotationConversion::convertLoopOptions( |
141 | LoopUnrollAndJamAttr options) { |
142 | if (auto disable = options.getDisable()) |
143 | addUnitNode(disable.getValue() ? "llvm.loop.unroll_and_jam.disable" |
144 | : "llvm.loop.unroll_and_jam.enable" ); |
145 | convertI32Node(name: "llvm.loop.unroll_and_jam.count" , attr: options.getCount()); |
146 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer" , |
147 | options.getFollowupOuter()); |
148 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner" , |
149 | options.getFollowupInner()); |
150 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer" , |
151 | options.getFollowupRemainderOuter()); |
152 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner" , |
153 | options.getFollowupRemainderInner()); |
154 | convertFollowupNode("llvm.loop.unroll_and_jam.followup_all" , |
155 | options.getFollowupAll()); |
156 | } |
157 | |
158 | void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) { |
159 | addUnitNode("llvm.licm.disable" , options.getDisable()); |
160 | addUnitNode("llvm.loop.licm_versioning.disable" , |
161 | options.getVersioningDisable()); |
162 | } |
163 | |
164 | void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) { |
165 | convertBoolNode(name: "llvm.loop.distribute.enable" , attr: options.getDisable(), negated: true); |
166 | convertFollowupNode("llvm.loop.distribute.followup_coincident" , |
167 | options.getFollowupCoincident()); |
168 | convertFollowupNode("llvm.loop.distribute.followup_sequential" , |
169 | options.getFollowupSequential()); |
170 | convertFollowupNode("llvm.loop.distribute.followup_fallback" , |
171 | options.getFollowupFallback()); |
172 | convertFollowupNode("llvm.loop.distribute.followup_all" , |
173 | options.getFollowupAll()); |
174 | } |
175 | |
176 | void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) { |
177 | convertBoolNode(name: "llvm.loop.pipeline.disable" , attr: options.getDisable()); |
178 | convertI32Node(name: "llvm.loop.pipeline.initiationinterval" , |
179 | attr: options.getInitiationinterval()); |
180 | } |
181 | |
182 | void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) { |
183 | convertI32Node(name: "llvm.loop.peeled.count" , attr: options.getCount()); |
184 | } |
185 | |
186 | void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) { |
187 | addUnitNode("llvm.loop.unswitch.partial.disable" , |
188 | options.getPartialDisable()); |
189 | } |
190 | |
191 | void LoopAnnotationConversion::convertLocation(FusedLoc location) { |
192 | auto localScopeAttr = |
193 | dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata()); |
194 | if (!localScopeAttr) |
195 | return; |
196 | auto *localScope = dyn_cast<llvm::DILocalScope>( |
197 | loopAnnotationTranslation.moduleTranslation.translateDebugInfo( |
198 | attr: localScopeAttr)); |
199 | if (!localScope) |
200 | return; |
201 | llvm::Metadata *loc = |
202 | loopAnnotationTranslation.moduleTranslation.translateLoc(loc: location, |
203 | scope: localScope); |
204 | metadataNodes.push_back(Elt: loc); |
205 | } |
206 | |
207 | llvm::MDNode *LoopAnnotationConversion::convert() { |
208 | // Reserve operand 0 for loop id self reference. |
209 | auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: std::nullopt); |
210 | metadataNodes.push_back(Elt: dummy.get()); |
211 | |
212 | if (FusedLoc startLoc = attr.getStartLoc()) |
213 | convertLocation(startLoc); |
214 | |
215 | if (FusedLoc endLoc = attr.getEndLoc()) |
216 | convertLocation(endLoc); |
217 | |
218 | addUnitNode("llvm.loop.disable_nonforced" , attr.getDisableNonforced()); |
219 | addUnitNode("llvm.loop.mustprogress" , attr.getMustProgress()); |
220 | // "isvectorized" is encoded as an i32 value. |
221 | if (BoolAttr isVectorized = attr.getIsVectorized()) |
222 | addI32NodeWithVal(name: "llvm.loop.isvectorized" , val: isVectorized.getValue()); |
223 | |
224 | if (auto options = attr.getVectorize()) |
225 | convertLoopOptions(options); |
226 | if (auto options = attr.getInterleave()) |
227 | convertLoopOptions(options); |
228 | if (auto options = attr.getUnroll()) |
229 | convertLoopOptions(options); |
230 | if (auto options = attr.getUnrollAndJam()) |
231 | convertLoopOptions(options); |
232 | if (auto options = attr.getLicm()) |
233 | convertLoopOptions(options); |
234 | if (auto options = attr.getDistribute()) |
235 | convertLoopOptions(options); |
236 | if (auto options = attr.getPipeline()) |
237 | convertLoopOptions(options); |
238 | if (auto options = attr.getPeeled()) |
239 | convertLoopOptions(options); |
240 | if (auto options = attr.getUnswitch()) |
241 | convertLoopOptions(options); |
242 | |
243 | ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses(); |
244 | if (!parallelAccessGroups.empty()) { |
245 | SmallVector<llvm::Metadata *> parallelAccess; |
246 | parallelAccess.push_back( |
247 | Elt: llvm::MDString::get(Context&: ctx, Str: "llvm.loop.parallel_accesses" )); |
248 | for (AccessGroupAttr accessGroupAttr : parallelAccessGroups) |
249 | parallelAccess.push_back( |
250 | loopAnnotationTranslation.getAccessGroup(accessGroupAttr)); |
251 | metadataNodes.push_back(Elt: llvm::MDNode::get(Context&: ctx, MDs: parallelAccess)); |
252 | } |
253 | |
254 | // Create loop options and set the first operand to itself. |
255 | llvm::MDNode *loopMD = llvm::MDNode::get(Context&: ctx, MDs: metadataNodes); |
256 | loopMD->replaceOperandWith(I: 0, New: loopMD); |
257 | |
258 | return loopMD; |
259 | } |
260 | |
261 | llvm::MDNode * |
262 | LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr, |
263 | Operation *op) { |
264 | if (!attr) |
265 | return nullptr; |
266 | |
267 | llvm::MDNode *loopMD = lookupLoopMetadata(options: attr); |
268 | if (loopMD) |
269 | return loopMD; |
270 | |
271 | loopMD = |
272 | LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext()) |
273 | .convert(); |
274 | // Store a map from this Attribute to the LLVM metadata in case we |
275 | // encounter it again. |
276 | mapLoopMetadata(options: attr, metadata: loopMD); |
277 | return loopMD; |
278 | } |
279 | |
280 | llvm::MDNode * |
281 | LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) { |
282 | auto [result, inserted] = |
283 | accessGroupMetadataMapping.insert({accessGroupAttr, nullptr}); |
284 | if (inserted) |
285 | result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {}); |
286 | return result->second; |
287 | } |
288 | |
289 | llvm::MDNode * |
290 | LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) { |
291 | ArrayAttr accessGroups = op.getAccessGroupsOrNull(); |
292 | if (!accessGroups || accessGroups.empty()) |
293 | return nullptr; |
294 | |
295 | SmallVector<llvm::Metadata *> groupMDs; |
296 | for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>()) |
297 | groupMDs.push_back(getAccessGroup(group)); |
298 | if (groupMDs.size() == 1) |
299 | return llvm::cast<llvm::MDNode>(Val: groupMDs.front()); |
300 | return llvm::MDNode::get(Context&: llvmModule.getContext(), MDs: groupMDs); |
301 | } |
302 | |