1//===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "LoopAnnotationImporter.h"
10#include "llvm/IR/Constants.h"
11
12using namespace mlir;
13using namespace mlir::LLVM;
14using namespace mlir::LLVM::detail;
15
16namespace {
17/// Helper class that keeps the state of one metadata to attribute conversion.
18struct LoopMetadataConversion {
19 LoopMetadataConversion(const llvm::MDNode *node, Location loc,
20 LoopAnnotationImporter &loopAnnotationImporter)
21 : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
22 ctx(loc->getContext()){};
23 /// Converts this structs loop metadata node into a LoopAnnotationAttr.
24 LoopAnnotationAttr convert();
25
26 /// Initializes the shared state for the conversion member functions.
27 LogicalResult initConversionState();
28
29 /// Helper function to get and erase a property.
30 const llvm::MDNode *lookupAndEraseProperty(StringRef name);
31
32 /// Helper functions to lookup and convert MDNodes into a specifc attribute
33 /// kind. These functions return null-attributes if there is no node with the
34 /// specified name, or failure, if the node is ill-formatted.
35 FailureOr<BoolAttr> lookupUnitNode(StringRef name);
36 FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false);
37 FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
38 FailureOr<IntegerAttr> lookupIntNode(StringRef name);
39 FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
40 FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
41 FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name);
42 FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName,
43 StringRef disableName,
44 bool negated = false);
45
46 /// Conversion functions for sub-attributes.
47 FailureOr<LoopVectorizeAttr> convertVectorizeAttr();
48 FailureOr<LoopInterleaveAttr> convertInterleaveAttr();
49 FailureOr<LoopUnrollAttr> convertUnrollAttr();
50 FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr();
51 FailureOr<LoopLICMAttr> convertLICMAttr();
52 FailureOr<LoopDistributeAttr> convertDistributeAttr();
53 FailureOr<LoopPipelineAttr> convertPipelineAttr();
54 FailureOr<LoopPeeledAttr> convertPeeledAttr();
55 FailureOr<LoopUnswitchAttr> convertUnswitchAttr();
56 FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses();
57 FusedLoc convertStartLoc();
58 FailureOr<FusedLoc> convertEndLoc();
59
60 llvm::SmallVector<llvm::DILocation *, 2> locations;
61 llvm::StringMap<const llvm::MDNode *> propertyMap;
62 const llvm::MDNode *node;
63 Location loc;
64 LoopAnnotationImporter &loopAnnotationImporter;
65 MLIRContext *ctx;
66};
67} // namespace
68
69LogicalResult LoopMetadataConversion::initConversionState() {
70 // Check if it's a valid node.
71 if (node->getNumOperands() == 0 ||
72 dyn_cast<llvm::MDNode>(Val: node->getOperand(I: 0)) != node)
73 return emitWarning(loc) << "invalid loop node";
74
75 for (const llvm::MDOperand &operand : llvm::drop_begin(RangeOrContainer: node->operands())) {
76 if (auto *diLoc = dyn_cast<llvm::DILocation>(Val: operand)) {
77 locations.push_back(Elt: diLoc);
78 continue;
79 }
80
81 auto *property = dyn_cast<llvm::MDNode>(Val: operand);
82 if (!property)
83 return emitWarning(loc) << "expected all loop properties to be either "
84 "debug locations or metadata nodes";
85
86 if (property->getNumOperands() == 0)
87 return emitWarning(loc) << "cannot import empty loop property";
88
89 auto *nameNode = dyn_cast<llvm::MDString>(Val: property->getOperand(I: 0));
90 if (!nameNode)
91 return emitWarning(loc) << "cannot import loop property without a name";
92 StringRef name = nameNode->getString();
93
94 bool succ = propertyMap.try_emplace(Key: name, Args&: property).second;
95 if (!succ)
96 return emitWarning(loc)
97 << "cannot import loop properties with duplicated names " << name;
98 }
99
100 return success();
101}
102
103const llvm::MDNode *
104LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
105 auto it = propertyMap.find(Key: name);
106 if (it == propertyMap.end())
107 return nullptr;
108 const llvm::MDNode *property = it->getValue();
109 propertyMap.erase(I: it);
110 return property;
111}
112
113FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) {
114 const llvm::MDNode *property = lookupAndEraseProperty(name);
115 if (!property)
116 return BoolAttr(nullptr);
117
118 if (property->getNumOperands() != 1)
119 return emitWarning(loc)
120 << "expected metadata node " << name << " to hold no value";
121
122 return BoolAttr::get(context: ctx, value: true);
123}
124
125FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode(
126 StringRef enableName, StringRef disableName, bool negated) {
127 auto enable = lookupUnitNode(name: enableName);
128 auto disable = lookupUnitNode(name: disableName);
129 if (failed(result: enable) || failed(result: disable))
130 return failure();
131
132 if (*enable && *disable)
133 return emitWarning(loc)
134 << "expected metadata nodes " << enableName << " and " << disableName
135 << " to be mutually exclusive.";
136
137 if (*enable)
138 return BoolAttr::get(context: ctx, value: !negated);
139
140 if (*disable)
141 return BoolAttr::get(context: ctx, value: negated);
142 return BoolAttr(nullptr);
143}
144
145FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
146 bool negated) {
147 const llvm::MDNode *property = lookupAndEraseProperty(name);
148 if (!property)
149 return BoolAttr(nullptr);
150
151 auto emitNodeWarning = [&]() {
152 return emitWarning(loc)
153 << "expected metadata node " << name << " to hold a boolean value";
154 };
155
156 if (property->getNumOperands() != 2)
157 return emitNodeWarning();
158 llvm::ConstantInt *val =
159 llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1));
160 if (!val || val->getBitWidth() != 1)
161 return emitNodeWarning();
162
163 return BoolAttr::get(context: ctx, value: val->getValue().getLimitedValue(Limit: 1) ^ negated);
164}
165
166FailureOr<BoolAttr>
167LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
168 const llvm::MDNode *property = lookupAndEraseProperty(name);
169 if (!property)
170 return BoolAttr(nullptr);
171
172 auto emitNodeWarning = [&]() {
173 return emitWarning(loc)
174 << "expected metadata node " << name << " to hold an integer value";
175 };
176
177 if (property->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt *val =
180 llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1));
181 if (!val || val->getBitWidth() != 32)
182 return emitNodeWarning();
183
184 return BoolAttr::get(context: ctx, value: val->getValue().getLimitedValue(Limit: 1));
185}
186
187FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
188 const llvm::MDNode *property = lookupAndEraseProperty(name);
189 if (!property)
190 return IntegerAttr(nullptr);
191
192 auto emitNodeWarning = [&]() {
193 return emitWarning(loc)
194 << "expected metadata node " << name << " to hold an i32 value";
195 };
196
197 if (property->getNumOperands() != 2)
198 return emitNodeWarning();
199
200 llvm::ConstantInt *val =
201 llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: property->getOperand(I: 1));
202 if (!val || val->getBitWidth() != 32)
203 return emitNodeWarning();
204
205 return IntegerAttr::get(IntegerType::get(ctx, 32),
206 val->getValue().getLimitedValue());
207}
208
209FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) {
210 const llvm::MDNode *property = lookupAndEraseProperty(name);
211 if (!property)
212 return nullptr;
213
214 auto emitNodeWarning = [&]() {
215 return emitWarning(loc)
216 << "expected metadata node " << name << " to hold an MDNode";
217 };
218
219 if (property->getNumOperands() != 2)
220 return emitNodeWarning();
221
222 auto *node = dyn_cast<llvm::MDNode>(Val: property->getOperand(I: 1));
223 if (!node)
224 return emitNodeWarning();
225
226 return node;
227}
228
229FailureOr<SmallVector<llvm::MDNode *>>
230LoopMetadataConversion::lookupMDNodes(StringRef name) {
231 const llvm::MDNode *property = lookupAndEraseProperty(name);
232 SmallVector<llvm::MDNode *> res;
233 if (!property)
234 return res;
235
236 auto emitNodeWarning = [&]() {
237 return emitWarning(loc) << "expected metadata node " << name
238 << " to hold one or multiple MDNodes";
239 };
240
241 if (property->getNumOperands() < 2)
242 return emitNodeWarning();
243
244 for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) {
245 auto *node = dyn_cast<llvm::MDNode>(Val: property->getOperand(I: i));
246 if (!node)
247 return emitNodeWarning();
248 res.push_back(Elt: node);
249 }
250
251 return res;
252}
253
254FailureOr<LoopAnnotationAttr>
255LoopMetadataConversion::lookupFollowupNode(StringRef name) {
256 auto node = lookupMDNode(name);
257 if (failed(result: node))
258 return failure();
259 if (*node == nullptr)
260 return LoopAnnotationAttr(nullptr);
261
262 return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
263}
264
265static bool isEmptyOrNull(const Attribute attr) { return !attr; }
266
267template <typename T>
268static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) {
269 return vec.empty();
270}
271
272/// Helper function that only creates and attribute of type T if all argument
273/// conversion were successfull and at least one of them holds a non-null value.
274template <typename T, typename... P>
275static T createIfNonNull(MLIRContext *ctx, const P &...args) {
276 bool anyFailed = (failed(args) || ...);
277 if (anyFailed)
278 return {};
279
280 bool allEmpty = (isEmptyOrNull(*args) && ...);
281 if (allEmpty)
282 return {};
283
284 return T::get(ctx, *args...);
285}
286
287FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() {
288 FailureOr<BoolAttr> enable =
289 lookupBoolNode(name: "llvm.loop.vectorize.enable", negated: true);
290 FailureOr<BoolAttr> predicateEnable =
291 lookupBoolNode(name: "llvm.loop.vectorize.predicate.enable");
292 FailureOr<BoolAttr> scalableEnable =
293 lookupBoolNode(name: "llvm.loop.vectorize.scalable.enable");
294 FailureOr<IntegerAttr> width = lookupIntNode("llvm.loop.vectorize.width");
295 FailureOr<LoopAnnotationAttr> followupVec =
296 lookupFollowupNode("llvm.loop.vectorize.followup_vectorized");
297 FailureOr<LoopAnnotationAttr> followupEpi =
298 lookupFollowupNode("llvm.loop.vectorize.followup_epilogue");
299 FailureOr<LoopAnnotationAttr> followupAll =
300 lookupFollowupNode("llvm.loop.vectorize.followup_all");
301
302 return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable,
303 scalableEnable, width, followupVec,
304 followupEpi, followupAll);
305}
306
307FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() {
308 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.interleave.count");
309 return createIfNonNull<LoopInterleaveAttr>(ctx, count);
310}
311
312FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
313 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
314 enableName: "llvm.loop.unroll.enable", disableName: "llvm.loop.unroll.disable", /*negated=*/true);
315 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.unroll.count");
316 FailureOr<BoolAttr> runtimeDisable =
317 lookupUnitNode(name: "llvm.loop.unroll.runtime.disable");
318 FailureOr<BoolAttr> full = lookupUnitNode(name: "llvm.loop.unroll.full");
319 FailureOr<LoopAnnotationAttr> followupUnrolled =
320 lookupFollowupNode("llvm.loop.unroll.followup_unrolled");
321 FailureOr<LoopAnnotationAttr> followupRemainder =
322 lookupFollowupNode("llvm.loop.unroll.followup_remainder");
323 FailureOr<LoopAnnotationAttr> followupAll =
324 lookupFollowupNode("llvm.loop.unroll.followup_all");
325
326 return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
327 full, followupUnrolled,
328 followupRemainder, followupAll);
329}
330
331FailureOr<LoopUnrollAndJamAttr>
332LoopMetadataConversion::convertUnrollAndJamAttr() {
333 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
334 enableName: "llvm.loop.unroll_and_jam.enable", disableName: "llvm.loop.unroll_and_jam.disable",
335 /*negated=*/true);
336 FailureOr<IntegerAttr> count =
337 lookupIntNode("llvm.loop.unroll_and_jam.count");
338 FailureOr<LoopAnnotationAttr> followupOuter =
339 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer");
340 FailureOr<LoopAnnotationAttr> followupInner =
341 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner");
342 FailureOr<LoopAnnotationAttr> followupRemainderOuter =
343 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer");
344 FailureOr<LoopAnnotationAttr> followupRemainderInner =
345 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner");
346 FailureOr<LoopAnnotationAttr> followupAll =
347 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all");
348 return createIfNonNull<LoopUnrollAndJamAttr>(
349 ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
350 followupRemainderInner, followupAll);
351}
352
353FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() {
354 FailureOr<BoolAttr> disable = lookupUnitNode(name: "llvm.licm.disable");
355 FailureOr<BoolAttr> versioningDisable =
356 lookupUnitNode(name: "llvm.loop.licm_versioning.disable");
357 return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable);
358}
359
360FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() {
361 FailureOr<BoolAttr> disable =
362 lookupBoolNode(name: "llvm.loop.distribute.enable", negated: true);
363 FailureOr<LoopAnnotationAttr> followupCoincident =
364 lookupFollowupNode("llvm.loop.distribute.followup_coincident");
365 FailureOr<LoopAnnotationAttr> followupSequential =
366 lookupFollowupNode("llvm.loop.distribute.followup_sequential");
367 FailureOr<LoopAnnotationAttr> followupFallback =
368 lookupFollowupNode("llvm.loop.distribute.followup_fallback");
369 FailureOr<LoopAnnotationAttr> followupAll =
370 lookupFollowupNode("llvm.loop.distribute.followup_all");
371 return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident,
372 followupSequential,
373 followupFallback, followupAll);
374}
375
376FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() {
377 FailureOr<BoolAttr> disable = lookupBoolNode(name: "llvm.loop.pipeline.disable");
378 FailureOr<IntegerAttr> initiationinterval =
379 lookupIntNode("llvm.loop.pipeline.initiationinterval");
380 return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval);
381}
382
383FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() {
384 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.peeled.count");
385 return createIfNonNull<LoopPeeledAttr>(ctx, count);
386}
387
388FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() {
389 FailureOr<BoolAttr> partialDisable =
390 lookupUnitNode(name: "llvm.loop.unswitch.partial.disable");
391 return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable);
392}
393
394FailureOr<SmallVector<AccessGroupAttr>>
395LoopMetadataConversion::convertParallelAccesses() {
396 FailureOr<SmallVector<llvm::MDNode *>> nodes =
397 lookupMDNodes(name: "llvm.loop.parallel_accesses");
398 if (failed(result: nodes))
399 return failure();
400 SmallVector<AccessGroupAttr> refs;
401 for (llvm::MDNode *node : *nodes) {
402 FailureOr<SmallVector<AccessGroupAttr>> accessGroups =
403 loopAnnotationImporter.lookupAccessGroupAttrs(node);
404 if (failed(accessGroups)) {
405 emitWarning(loc) << "could not lookup access group";
406 continue;
407 }
408 llvm::append_range(refs, *accessGroups);
409 }
410 return refs;
411}
412
413FusedLoc LoopMetadataConversion::convertStartLoc() {
414 if (locations.empty())
415 return {};
416 return dyn_cast<FusedLoc>(
417 loopAnnotationImporter.moduleImport.translateLoc(locations[0]));
418}
419
420FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() {
421 if (locations.size() < 2)
422 return FusedLoc();
423 if (locations.size() > 2)
424 return emitError(loc)
425 << "expected loop metadata to have at most two DILocations";
426 return dyn_cast<FusedLoc>(
427 loopAnnotationImporter.moduleImport.translateLoc(locations[1]));
428}
429
430LoopAnnotationAttr LoopMetadataConversion::convert() {
431 if (failed(result: initConversionState()))
432 return {};
433
434 FailureOr<BoolAttr> disableNonForced =
435 lookupUnitNode(name: "llvm.loop.disable_nonforced");
436 FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr();
437 FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr();
438 FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr();
439 FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr();
440 FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr();
441 FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
442 FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
443 FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr();
444 FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr();
445 FailureOr<BoolAttr> mustProgress = lookupUnitNode(name: "llvm.loop.mustprogress");
446 FailureOr<BoolAttr> isVectorized =
447 lookupIntNodeAsBoolAttr(name: "llvm.loop.isvectorized");
448 FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses =
449 convertParallelAccesses();
450
451 // Drop the metadata if there are parts that cannot be imported.
452 if (!propertyMap.empty()) {
453 for (auto name : propertyMap.keys())
454 emitWarning(loc) << "unknown loop annotation " << name;
455 return {};
456 }
457
458 FailureOr<FusedLoc> startLoc = convertStartLoc();
459 FailureOr<FusedLoc> endLoc = convertEndLoc();
460
461 return createIfNonNull<LoopAnnotationAttr>(
462 ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
463 unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
464 unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
465 parallelAccesses);
466}
467
468LoopAnnotationAttr
469LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node,
470 Location loc) {
471 if (!node)
472 return {};
473
474 // Note: This check is necessary to distinguish between failed translations
475 // and not yet attempted translations.
476 auto it = loopMetadataMapping.find(node);
477 if (it != loopMetadataMapping.end())
478 return it->getSecond();
479
480 LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert();
481
482 mapLoopMetadata(node, attr);
483 return attr;
484}
485
486LogicalResult
487LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node,
488 Location loc) {
489 SmallVector<const llvm::MDNode *> accessGroups;
490 if (!node->getNumOperands())
491 accessGroups.push_back(Elt: node);
492 for (const llvm::MDOperand &operand : node->operands()) {
493 auto *childNode = dyn_cast<llvm::MDNode>(Val: operand);
494 if (!childNode)
495 return failure();
496 accessGroups.push_back(Elt: cast<llvm::MDNode>(Val: operand.get()));
497 }
498
499 // Convert all entries of the access group list to access group operations.
500 for (const llvm::MDNode *accessGroup : accessGroups) {
501 if (accessGroupMapping.count(accessGroup))
502 continue;
503 // Verify the access group node is distinct and empty.
504 if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
505 return emitWarning(loc)
506 << "expected an access group node to be empty and distinct";
507
508 // Add a mapping from the access group node to the newly created attribute.
509 accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>();
510 }
511 return success();
512}
513
514FailureOr<SmallVector<AccessGroupAttr>>
515LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
516 // An access group node is either a single access group or an access group
517 // list.
518 SmallVector<AccessGroupAttr> accessGroups;
519 if (!node->getNumOperands())
520 accessGroups.push_back(accessGroupMapping.lookup(node));
521 for (const llvm::MDOperand &operand : node->operands()) {
522 auto *node = cast<llvm::MDNode>(Val: operand.get());
523 accessGroups.push_back(accessGroupMapping.lookup(node));
524 }
525 // Exit if one of the access group node lookups failed.
526 if (llvm::is_contained(accessGroups, nullptr))
527 return failure();
528 return accessGroups;
529}
530

source code of mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp