1 | //===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "LoopAnnotationImporter.h" |
10 | #include "llvm/IR/Constants.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::LLVM; |
14 | using namespace mlir::LLVM::detail; |
15 | |
16 | namespace { |
17 | /// Helper class that keeps the state of one metadata to attribute conversion. |
18 | struct LoopMetadataConversion { |
19 | LoopMetadataConversion(const llvm::MDNode *node, Location loc, |
20 | LoopAnnotationImporter &loopAnnotationImporter) |
21 | : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter), |
22 | ctx(loc->getContext()){}; |
23 | /// Converts this structs loop metadata node into a LoopAnnotationAttr. |
24 | LoopAnnotationAttr convert(); |
25 | |
26 | /// Initializes the shared state for the conversion member functions. |
27 | LogicalResult initConversionState(); |
28 | |
29 | /// Helper function to get and erase a property. |
30 | const llvm::MDNode *lookupAndEraseProperty(StringRef name); |
31 | |
32 | /// Helper functions to lookup and convert MDNodes into a specifc attribute |
33 | /// kind. These functions return null-attributes if there is no node with the |
34 | /// specified name, or failure, if the node is ill-formatted. |
35 | FailureOr<BoolAttr> lookupUnitNode(StringRef name); |
36 | FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false); |
37 | FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name); |
38 | FailureOr<IntegerAttr> lookupIntNode(StringRef name); |
39 | FailureOr<llvm::MDNode *> lookupMDNode(StringRef name); |
40 | FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name); |
41 | FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name); |
42 | FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName, |
43 | StringRef disableName, |
44 | bool negated = false); |
45 | |
46 | /// Conversion functions for sub-attributes. |
47 | FailureOr<LoopVectorizeAttr> convertVectorizeAttr(); |
48 | FailureOr<LoopInterleaveAttr> convertInterleaveAttr(); |
49 | FailureOr<LoopUnrollAttr> convertUnrollAttr(); |
50 | FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr(); |
51 | FailureOr<LoopLICMAttr> convertLICMAttr(); |
52 | FailureOr<LoopDistributeAttr> convertDistributeAttr(); |
53 | FailureOr<LoopPipelineAttr> convertPipelineAttr(); |
54 | FailureOr<LoopPeeledAttr> convertPeeledAttr(); |
55 | FailureOr<LoopUnswitchAttr> convertUnswitchAttr(); |
56 | FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses(); |
57 | FusedLoc convertStartLoc(); |
58 | FailureOr<FusedLoc> convertEndLoc(); |
59 | |
60 | llvm::SmallVector<llvm::DILocation *, 2> locations; |
61 | llvm::StringMap<const llvm::MDNode *> propertyMap; |
62 | const llvm::MDNode *node; |
63 | Location loc; |
64 | LoopAnnotationImporter &loopAnnotationImporter; |
65 | MLIRContext *ctx; |
66 | }; |
67 | } // namespace |
68 | |
69 | LogicalResult LoopMetadataConversion::initConversionState() { |
70 | // Check if it's a valid node. |
71 | if (node->getNumOperands() == 0 || |
72 | dyn_cast<llvm::MDNode>(Val: node->getOperand(I: 0)) != node) |
73 | return emitWarning(loc) << "invalid loop node" ; |
74 | |
75 | for (const llvm::MDOperand &operand : llvm::drop_begin(RangeOrContainer: node->operands())) { |
76 | if (auto *diLoc = dyn_cast<llvm::DILocation>(Val: operand)) { |
77 | locations.push_back(Elt: diLoc); |
78 | continue; |
79 | } |
80 | |
81 | auto *property = dyn_cast<llvm::MDNode>(Val: operand); |
82 | if (!property) |
83 | return emitWarning(loc) << "expected all loop properties to be either " |
84 | "debug locations or metadata nodes" ; |
85 | |
86 | if (property->getNumOperands() == 0) |
87 | return emitWarning(loc) << "cannot import empty loop property" ; |
88 | |
89 | auto *nameNode = dyn_cast<llvm::MDString>(Val: property->getOperand(I: 0)); |
90 | if (!nameNode) |
91 | return emitWarning(loc) << "cannot import loop property without a name" ; |
92 | StringRef name = nameNode->getString(); |
93 | |
94 | bool succ = propertyMap.try_emplace(Key: name, Args&: property).second; |
95 | if (!succ) |
96 | return emitWarning(loc) |
97 | << "cannot import loop properties with duplicated names " << name; |
98 | } |
99 | |
100 | return success(); |
101 | } |
102 | |
103 | const llvm::MDNode * |
104 | LoopMetadataConversion::lookupAndEraseProperty(StringRef name) { |
105 | auto it = propertyMap.find(Key: name); |
106 | if (it == propertyMap.end()) |
107 | return nullptr; |
108 | const llvm::MDNode *property = it->getValue(); |
109 | propertyMap.erase(I: it); |
110 | return property; |
111 | } |
112 | |
113 | FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) { |
114 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
115 | if (!property) |
116 | return BoolAttr(nullptr); |
117 | |
118 | if (property->getNumOperands() != 1) |
119 | return emitWarning(loc) |
120 | << "expected metadata node " << name << " to hold no value" ; |
121 | |
122 | return BoolAttr::get(context: ctx, value: true); |
123 | } |
124 | |
125 | FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode( |
126 | StringRef enableName, StringRef disableName, bool negated) { |
127 | auto enable = lookupUnitNode(name: enableName); |
128 | auto disable = lookupUnitNode(name: disableName); |
129 | if (failed(result: enable) || failed(result: disable)) |
130 | return failure(); |
131 | |
132 | if (*enable && *disable) |
133 | return emitWarning(loc) |
134 | << "expected metadata nodes " << enableName << " and " << disableName |
135 | << " to be mutually exclusive." ; |
136 | |
137 | if (*enable) |
138 | return BoolAttr::get(context: ctx, value: !negated); |
139 | |
140 | if (*disable) |
141 | return BoolAttr::get(context: ctx, value: negated); |
142 | return BoolAttr(nullptr); |
143 | } |
144 | |
145 | FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name, |
146 | bool negated) { |
147 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
148 | if (!property) |
149 | return BoolAttr(nullptr); |
150 | |
151 | auto emitNodeWarning = [&]() { |
152 | return emitWarning(loc) |
153 | << "expected metadata node " << name << " to hold a boolean value" ; |
154 | }; |
155 | |
156 | if (property->getNumOperands() != 2) |
157 | return emitNodeWarning(); |
158 | llvm::ConstantInt *val = |
159 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1)); |
160 | if (!val || val->getBitWidth() != 1) |
161 | return emitNodeWarning(); |
162 | |
163 | return BoolAttr::get(context: ctx, value: val->getValue().getLimitedValue(Limit: 1) ^ negated); |
164 | } |
165 | |
166 | FailureOr<BoolAttr> |
167 | LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) { |
168 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
169 | if (!property) |
170 | return BoolAttr(nullptr); |
171 | |
172 | auto emitNodeWarning = [&]() { |
173 | return emitWarning(loc) |
174 | << "expected metadata node " << name << " to hold an integer value" ; |
175 | }; |
176 | |
177 | if (property->getNumOperands() != 2) |
178 | return emitNodeWarning(); |
179 | llvm::ConstantInt *val = |
180 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1)); |
181 | if (!val || val->getBitWidth() != 32) |
182 | return emitNodeWarning(); |
183 | |
184 | return BoolAttr::get(context: ctx, value: val->getValue().getLimitedValue(Limit: 1)); |
185 | } |
186 | |
187 | FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) { |
188 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
189 | if (!property) |
190 | return IntegerAttr(nullptr); |
191 | |
192 | auto emitNodeWarning = [&]() { |
193 | return emitWarning(loc) |
194 | << "expected metadata node " << name << " to hold an i32 value" ; |
195 | }; |
196 | |
197 | if (property->getNumOperands() != 2) |
198 | return emitNodeWarning(); |
199 | |
200 | llvm::ConstantInt *val = |
201 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1)); |
202 | if (!val || val->getBitWidth() != 32) |
203 | return emitNodeWarning(); |
204 | |
205 | return IntegerAttr::get(IntegerType::get(ctx, 32), |
206 | val->getValue().getLimitedValue()); |
207 | } |
208 | |
209 | FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) { |
210 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
211 | if (!property) |
212 | return nullptr; |
213 | |
214 | auto emitNodeWarning = [&]() { |
215 | return emitWarning(loc) |
216 | << "expected metadata node " << name << " to hold an MDNode" ; |
217 | }; |
218 | |
219 | if (property->getNumOperands() != 2) |
220 | return emitNodeWarning(); |
221 | |
222 | auto *node = dyn_cast<llvm::MDNode>(Val: property->getOperand(I: 1)); |
223 | if (!node) |
224 | return emitNodeWarning(); |
225 | |
226 | return node; |
227 | } |
228 | |
229 | FailureOr<SmallVector<llvm::MDNode *>> |
230 | LoopMetadataConversion::lookupMDNodes(StringRef name) { |
231 | const llvm::MDNode *property = lookupAndEraseProperty(name); |
232 | SmallVector<llvm::MDNode *> res; |
233 | if (!property) |
234 | return res; |
235 | |
236 | auto emitNodeWarning = [&]() { |
237 | return emitWarning(loc) << "expected metadata node " << name |
238 | << " to hold one or multiple MDNodes" ; |
239 | }; |
240 | |
241 | if (property->getNumOperands() < 2) |
242 | return emitNodeWarning(); |
243 | |
244 | for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) { |
245 | auto *node = dyn_cast<llvm::MDNode>(Val: property->getOperand(I: i)); |
246 | if (!node) |
247 | return emitNodeWarning(); |
248 | res.push_back(Elt: node); |
249 | } |
250 | |
251 | return res; |
252 | } |
253 | |
254 | FailureOr<LoopAnnotationAttr> |
255 | LoopMetadataConversion::lookupFollowupNode(StringRef name) { |
256 | auto node = lookupMDNode(name); |
257 | if (failed(result: node)) |
258 | return failure(); |
259 | if (*node == nullptr) |
260 | return LoopAnnotationAttr(nullptr); |
261 | |
262 | return loopAnnotationImporter.translateLoopAnnotation(*node, loc); |
263 | } |
264 | |
265 | static bool isEmptyOrNull(const Attribute attr) { return !attr; } |
266 | |
267 | template <typename T> |
268 | static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) { |
269 | return vec.empty(); |
270 | } |
271 | |
272 | /// Helper function that only creates and attribute of type T if all argument |
273 | /// conversion were successfull and at least one of them holds a non-null value. |
274 | template <typename T, typename... P> |
275 | static T createIfNonNull(MLIRContext *ctx, const P &...args) { |
276 | bool anyFailed = (failed(args) || ...); |
277 | if (anyFailed) |
278 | return {}; |
279 | |
280 | bool allEmpty = (isEmptyOrNull(*args) && ...); |
281 | if (allEmpty) |
282 | return {}; |
283 | |
284 | return T::get(ctx, *args...); |
285 | } |
286 | |
287 | FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() { |
288 | FailureOr<BoolAttr> enable = |
289 | lookupBoolNode(name: "llvm.loop.vectorize.enable" , negated: true); |
290 | FailureOr<BoolAttr> predicateEnable = |
291 | lookupBoolNode(name: "llvm.loop.vectorize.predicate.enable" ); |
292 | FailureOr<BoolAttr> scalableEnable = |
293 | lookupBoolNode(name: "llvm.loop.vectorize.scalable.enable" ); |
294 | FailureOr<IntegerAttr> width = lookupIntNode("llvm.loop.vectorize.width" ); |
295 | FailureOr<LoopAnnotationAttr> followupVec = |
296 | lookupFollowupNode("llvm.loop.vectorize.followup_vectorized" ); |
297 | FailureOr<LoopAnnotationAttr> followupEpi = |
298 | lookupFollowupNode("llvm.loop.vectorize.followup_epilogue" ); |
299 | FailureOr<LoopAnnotationAttr> followupAll = |
300 | lookupFollowupNode("llvm.loop.vectorize.followup_all" ); |
301 | |
302 | return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable, |
303 | scalableEnable, width, followupVec, |
304 | followupEpi, followupAll); |
305 | } |
306 | |
307 | FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() { |
308 | FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.interleave.count" ); |
309 | return createIfNonNull<LoopInterleaveAttr>(ctx, count); |
310 | } |
311 | |
312 | FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() { |
313 | FailureOr<BoolAttr> disable = lookupBooleanUnitNode( |
314 | enableName: "llvm.loop.unroll.enable" , disableName: "llvm.loop.unroll.disable" , /*negated=*/true); |
315 | FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.unroll.count" ); |
316 | FailureOr<BoolAttr> runtimeDisable = |
317 | lookupUnitNode(name: "llvm.loop.unroll.runtime.disable" ); |
318 | FailureOr<BoolAttr> full = lookupUnitNode(name: "llvm.loop.unroll.full" ); |
319 | FailureOr<LoopAnnotationAttr> followupUnrolled = |
320 | lookupFollowupNode("llvm.loop.unroll.followup_unrolled" ); |
321 | FailureOr<LoopAnnotationAttr> followupRemainder = |
322 | lookupFollowupNode("llvm.loop.unroll.followup_remainder" ); |
323 | FailureOr<LoopAnnotationAttr> followupAll = |
324 | lookupFollowupNode("llvm.loop.unroll.followup_all" ); |
325 | |
326 | return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable, |
327 | full, followupUnrolled, |
328 | followupRemainder, followupAll); |
329 | } |
330 | |
331 | FailureOr<LoopUnrollAndJamAttr> |
332 | LoopMetadataConversion::convertUnrollAndJamAttr() { |
333 | FailureOr<BoolAttr> disable = lookupBooleanUnitNode( |
334 | enableName: "llvm.loop.unroll_and_jam.enable" , disableName: "llvm.loop.unroll_and_jam.disable" , |
335 | /*negated=*/true); |
336 | FailureOr<IntegerAttr> count = |
337 | lookupIntNode("llvm.loop.unroll_and_jam.count" ); |
338 | FailureOr<LoopAnnotationAttr> followupOuter = |
339 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer" ); |
340 | FailureOr<LoopAnnotationAttr> followupInner = |
341 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner" ); |
342 | FailureOr<LoopAnnotationAttr> followupRemainderOuter = |
343 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer" ); |
344 | FailureOr<LoopAnnotationAttr> followupRemainderInner = |
345 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner" ); |
346 | FailureOr<LoopAnnotationAttr> followupAll = |
347 | lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all" ); |
348 | return createIfNonNull<LoopUnrollAndJamAttr>( |
349 | ctx, disable, count, followupOuter, followupInner, followupRemainderOuter, |
350 | followupRemainderInner, followupAll); |
351 | } |
352 | |
353 | FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() { |
354 | FailureOr<BoolAttr> disable = lookupUnitNode(name: "llvm.licm.disable" ); |
355 | FailureOr<BoolAttr> versioningDisable = |
356 | lookupUnitNode(name: "llvm.loop.licm_versioning.disable" ); |
357 | return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable); |
358 | } |
359 | |
360 | FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() { |
361 | FailureOr<BoolAttr> disable = |
362 | lookupBoolNode(name: "llvm.loop.distribute.enable" , negated: true); |
363 | FailureOr<LoopAnnotationAttr> followupCoincident = |
364 | lookupFollowupNode("llvm.loop.distribute.followup_coincident" ); |
365 | FailureOr<LoopAnnotationAttr> followupSequential = |
366 | lookupFollowupNode("llvm.loop.distribute.followup_sequential" ); |
367 | FailureOr<LoopAnnotationAttr> followupFallback = |
368 | lookupFollowupNode("llvm.loop.distribute.followup_fallback" ); |
369 | FailureOr<LoopAnnotationAttr> followupAll = |
370 | lookupFollowupNode("llvm.loop.distribute.followup_all" ); |
371 | return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident, |
372 | followupSequential, |
373 | followupFallback, followupAll); |
374 | } |
375 | |
376 | FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() { |
377 | FailureOr<BoolAttr> disable = lookupBoolNode(name: "llvm.loop.pipeline.disable" ); |
378 | FailureOr<IntegerAttr> initiationinterval = |
379 | lookupIntNode("llvm.loop.pipeline.initiationinterval" ); |
380 | return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval); |
381 | } |
382 | |
383 | FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() { |
384 | FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.peeled.count" ); |
385 | return createIfNonNull<LoopPeeledAttr>(ctx, count); |
386 | } |
387 | |
388 | FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() { |
389 | FailureOr<BoolAttr> partialDisable = |
390 | lookupUnitNode(name: "llvm.loop.unswitch.partial.disable" ); |
391 | return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable); |
392 | } |
393 | |
394 | FailureOr<SmallVector<AccessGroupAttr>> |
395 | LoopMetadataConversion::convertParallelAccesses() { |
396 | FailureOr<SmallVector<llvm::MDNode *>> nodes = |
397 | lookupMDNodes(name: "llvm.loop.parallel_accesses" ); |
398 | if (failed(result: nodes)) |
399 | return failure(); |
400 | SmallVector<AccessGroupAttr> refs; |
401 | for (llvm::MDNode *node : *nodes) { |
402 | FailureOr<SmallVector<AccessGroupAttr>> accessGroups = |
403 | loopAnnotationImporter.lookupAccessGroupAttrs(node); |
404 | if (failed(accessGroups)) { |
405 | emitWarning(loc) << "could not lookup access group" ; |
406 | continue; |
407 | } |
408 | llvm::append_range(refs, *accessGroups); |
409 | } |
410 | return refs; |
411 | } |
412 | |
413 | FusedLoc LoopMetadataConversion::convertStartLoc() { |
414 | if (locations.empty()) |
415 | return {}; |
416 | return dyn_cast<FusedLoc>( |
417 | loopAnnotationImporter.moduleImport.translateLoc(locations[0])); |
418 | } |
419 | |
420 | FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() { |
421 | if (locations.size() < 2) |
422 | return FusedLoc(); |
423 | if (locations.size() > 2) |
424 | return emitError(loc) |
425 | << "expected loop metadata to have at most two DILocations" ; |
426 | return dyn_cast<FusedLoc>( |
427 | loopAnnotationImporter.moduleImport.translateLoc(locations[1])); |
428 | } |
429 | |
430 | LoopAnnotationAttr LoopMetadataConversion::convert() { |
431 | if (failed(result: initConversionState())) |
432 | return {}; |
433 | |
434 | FailureOr<BoolAttr> disableNonForced = |
435 | lookupUnitNode(name: "llvm.loop.disable_nonforced" ); |
436 | FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr(); |
437 | FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr(); |
438 | FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr(); |
439 | FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr(); |
440 | FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr(); |
441 | FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr(); |
442 | FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr(); |
443 | FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr(); |
444 | FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr(); |
445 | FailureOr<BoolAttr> mustProgress = lookupUnitNode(name: "llvm.loop.mustprogress" ); |
446 | FailureOr<BoolAttr> isVectorized = |
447 | lookupIntNodeAsBoolAttr(name: "llvm.loop.isvectorized" ); |
448 | FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses = |
449 | convertParallelAccesses(); |
450 | |
451 | // Drop the metadata if there are parts that cannot be imported. |
452 | if (!propertyMap.empty()) { |
453 | for (auto name : propertyMap.keys()) |
454 | emitWarning(loc) << "unknown loop annotation " << name; |
455 | return {}; |
456 | } |
457 | |
458 | FailureOr<FusedLoc> startLoc = convertStartLoc(); |
459 | FailureOr<FusedLoc> endLoc = convertEndLoc(); |
460 | |
461 | return createIfNonNull<LoopAnnotationAttr>( |
462 | ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr, |
463 | unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr, |
464 | unswitchAttr, mustProgress, isVectorized, startLoc, endLoc, |
465 | parallelAccesses); |
466 | } |
467 | |
468 | LoopAnnotationAttr |
469 | LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node, |
470 | Location loc) { |
471 | if (!node) |
472 | return {}; |
473 | |
474 | // Note: This check is necessary to distinguish between failed translations |
475 | // and not yet attempted translations. |
476 | auto it = loopMetadataMapping.find(node); |
477 | if (it != loopMetadataMapping.end()) |
478 | return it->getSecond(); |
479 | |
480 | LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert(); |
481 | |
482 | mapLoopMetadata(node, attr); |
483 | return attr; |
484 | } |
485 | |
486 | LogicalResult |
487 | LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node, |
488 | Location loc) { |
489 | SmallVector<const llvm::MDNode *> accessGroups; |
490 | if (!node->getNumOperands()) |
491 | accessGroups.push_back(Elt: node); |
492 | for (const llvm::MDOperand &operand : node->operands()) { |
493 | auto *childNode = dyn_cast<llvm::MDNode>(Val: operand); |
494 | if (!childNode) |
495 | return failure(); |
496 | accessGroups.push_back(Elt: cast<llvm::MDNode>(Val: operand.get())); |
497 | } |
498 | |
499 | // Convert all entries of the access group list to access group operations. |
500 | for (const llvm::MDNode *accessGroup : accessGroups) { |
501 | if (accessGroupMapping.count(accessGroup)) |
502 | continue; |
503 | // Verify the access group node is distinct and empty. |
504 | if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) |
505 | return emitWarning(loc) |
506 | << "expected an access group node to be empty and distinct" ; |
507 | |
508 | // Add a mapping from the access group node to the newly created attribute. |
509 | accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>(); |
510 | } |
511 | return success(); |
512 | } |
513 | |
514 | FailureOr<SmallVector<AccessGroupAttr>> |
515 | LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { |
516 | // An access group node is either a single access group or an access group |
517 | // list. |
518 | SmallVector<AccessGroupAttr> accessGroups; |
519 | if (!node->getNumOperands()) |
520 | accessGroups.push_back(accessGroupMapping.lookup(node)); |
521 | for (const llvm::MDOperand &operand : node->operands()) { |
522 | auto *node = cast<llvm::MDNode>(Val: operand.get()); |
523 | accessGroups.push_back(accessGroupMapping.lookup(node)); |
524 | } |
525 | // Exit if one of the access group node lookups failed. |
526 | if (llvm::is_contained(accessGroups, nullptr)) |
527 | return failure(); |
528 | return accessGroups; |
529 | } |
530 | |