1//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the MLIR AsmPrinter class, which is used to implement
10// the various print() methods on the core IR objects.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/IR/AffineExpr.h"
15#include "mlir/IR/AffineMap.h"
16#include "mlir/IR/AsmState.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/BuiltinDialect.h"
21#include "mlir/IR/BuiltinTypeInterfaces.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Dialect.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/DialectResourceBlobManager.h"
26#include "mlir/IR/IntegerSet.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/OpImplementation.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/Verifier.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/DenseMap.h"
34#include "llvm/ADT/MapVector.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/ScopedHashTable.h"
38#include "llvm/ADT/SetVector.h"
39#include "llvm/ADT/SmallString.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringSet.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/CommandLine.h"
44#include "llvm/Support/Debug.h"
45#include "llvm/Support/Endian.h"
46#include "llvm/Support/Regex.h"
47#include "llvm/Support/SaveAndRestore.h"
48#include "llvm/Support/Threading.h"
49#include "llvm/Support/raw_ostream.h"
50#include <type_traits>
51
52#include <optional>
53#include <tuple>
54
55using namespace mlir;
56using namespace mlir::detail;
57
58#define DEBUG_TYPE "mlir-asm-printer"
59
60void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
61
62void OperationName::dump() const { print(os&: llvm::errs()); }
63
64//===--------------------------------------------------------------------===//
65// AsmParser
66//===--------------------------------------------------------------------===//
67
68AsmParser::~AsmParser() = default;
69DialectAsmParser::~DialectAsmParser() = default;
70OpAsmParser::~OpAsmParser() = default;
71
72MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
73
74/// Parse a type list.
75/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
76ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
77 return parseCommaSeparatedList(
78 parseElementFn: [&]() { return parseType(result&: result.emplace_back()); });
79}
80
81//===----------------------------------------------------------------------===//
82// DialectAsmPrinter
83//===----------------------------------------------------------------------===//
84
// Defaulted destructor for the DialectAsmPrinter interface, defined out of line
// in this file.
DialectAsmPrinter::~DialectAsmPrinter() = default;
86
87//===----------------------------------------------------------------------===//
88// OpAsmPrinter
89//===----------------------------------------------------------------------===//
90
// Defaulted destructor for the OpAsmPrinter interface, defined out of line in
// this file.
OpAsmPrinter::~OpAsmPrinter() = default;
92
93void OpAsmPrinter::printFunctionalType(Operation *op) {
94 auto &os = getStream();
95 os << '(';
96 llvm::interleaveComma(c: op->getOperands(), os, each_fn: [&](Value operand) {
97 // Print the types of null values as <<NULL TYPE>>.
98 *this << (operand ? operand.getType() : Type());
99 });
100 os << ") -> ";
101
102 // Print the result list. We don't parenthesize single result types unless
103 // it is a function (avoiding a grammar ambiguity).
104 bool wrapped = op->getNumResults() != 1;
105 if (!wrapped && op->getResult(idx: 0).getType() &&
106 llvm::isa<FunctionType>(Val: op->getResult(idx: 0).getType()))
107 wrapped = true;
108
109 if (wrapped)
110 os << '(';
111
112 llvm::interleaveComma(c: op->getResults(), os, each_fn: [&](const OpResult &result) {
113 // Print the types of null values as <<NULL TYPE>>.
114 *this << (result ? result.getType() : Type());
115 });
116
117 if (wrapped)
118 os << ')';
119}
120
121//===----------------------------------------------------------------------===//
122// Operation OpAsm interface.
123//===----------------------------------------------------------------------===//
124
125/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
126#include "mlir/IR/OpAsmInterface.cpp.inc"
127
128LogicalResult
129OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
130 return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
131 << "' for dialect '" << getDialect()->getNamespace()
132 << "'";
133}
134
135//===----------------------------------------------------------------------===//
136// OpPrintingFlags
137//===----------------------------------------------------------------------===//
138
139namespace {
140/// This struct contains command line options that can be used to initialize
141/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
142/// for global command line options.
143struct AsmPrinterOptions {
144 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
145 "mlir-print-elementsattrs-with-hex-if-larger",
146 llvm::cl::desc(
147 "Print DenseElementsAttrs with a hex string that have "
148 "more elements than the given upper limit (use -1 to disable)")};
149
150 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
151 "mlir-elide-elementsattrs-if-larger",
152 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
153 "more elements than the given upper limit")};
154
155 llvm::cl::opt<unsigned> elideResourceStringsIfLarger{
156 "mlir-elide-resource-strings-if-larger",
157 llvm::cl::desc(
158 "Elide printing value of resources if string is too long in chars.")};
159
160 llvm::cl::opt<bool> printDebugInfoOpt{
161 "mlir-print-debuginfo", llvm::cl::init(Val: false),
162 llvm::cl::desc("Print debug info in MLIR output")};
163
164 llvm::cl::opt<bool> printPrettyDebugInfoOpt{
165 "mlir-pretty-debuginfo", llvm::cl::init(Val: false),
166 llvm::cl::desc("Print pretty debug info in MLIR output")};
167
168 // Use the generic op output form in the operation printer even if the custom
169 // form is defined.
170 llvm::cl::opt<bool> printGenericOpFormOpt{
171 "mlir-print-op-generic", llvm::cl::init(Val: false),
172 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
173
174 llvm::cl::opt<bool> assumeVerifiedOpt{
175 "mlir-print-assume-verified", llvm::cl::init(Val: false),
176 llvm::cl::desc("Skip op verification when using custom printers"),
177 llvm::cl::Hidden};
178
179 llvm::cl::opt<bool> printLocalScopeOpt{
180 "mlir-print-local-scope", llvm::cl::init(Val: false),
181 llvm::cl::desc("Print with local scope and inline information (eliding "
182 "aliases for attributes, types, and locations")};
183
184 llvm::cl::opt<bool> skipRegionsOpt{
185 "mlir-print-skip-regions", llvm::cl::init(Val: false),
186 llvm::cl::desc("Skip regions when printing ops.")};
187
188 llvm::cl::opt<bool> printValueUsers{
189 "mlir-print-value-users", llvm::cl::init(Val: false),
190 llvm::cl::desc(
191 "Print users of operation results and block arguments as a comment")};
192};
193} // namespace
194
195static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
196
197/// Register a set of useful command-line options that can be used to configure
198/// various flags within the AsmPrinter.
199void mlir::registerAsmPrinterCLOptions() {
200 // Make sure that the options struct has been initialized.
201 *clOptions;
202}
203
204/// Initialize the printing flags with default supplied by the cl::opts above.
205OpPrintingFlags::OpPrintingFlags()
206 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
207 printGenericOpFormFlag(false), skipRegionsFlag(false),
208 assumeVerifiedFlag(false), printLocalScope(false),
209 printValueUsersFlag(false) {
210 // Initialize based upon command line options, if they are available.
211 if (!clOptions.isConstructed())
212 return;
213 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
214 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
215 if (clOptions->elideResourceStringsIfLarger.getNumOccurrences())
216 resourceStringCharLimit = clOptions->elideResourceStringsIfLarger;
217 printDebugInfoFlag = clOptions->printDebugInfoOpt;
218 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
219 printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
220 assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
221 printLocalScope = clOptions->printLocalScopeOpt;
222 skipRegionsFlag = clOptions->skipRegionsOpt;
223 printValueUsersFlag = clOptions->printValueUsers;
224}
225
226/// Enable the elision of large elements attributes, by printing a '...'
227/// instead of the element data, when the number of elements is greater than
228/// `largeElementLimit`. Note: The IR generated with this option is not
229/// parsable.
230OpPrintingFlags &
231OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
232 elementsAttrElementLimit = largeElementLimit;
233 return *this;
234}
235
236OpPrintingFlags &
237OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) {
238 resourceStringCharLimit = largeResourceLimit;
239 return *this;
240}
241
242/// Enable printing of debug information. If 'prettyForm' is set to true,
243/// debug information is printed in a more readable 'pretty' form.
244OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
245 bool prettyForm) {
246 printDebugInfoFlag = enable;
247 printDebugInfoPrettyFormFlag = prettyForm;
248 return *this;
249}
250
251/// Always print operations in the generic form.
252OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
253 printGenericOpFormFlag = enable;
254 return *this;
255}
256
257/// Always skip Regions.
258OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
259 skipRegionsFlag = skip;
260 return *this;
261}
262
263/// Do not verify the operation when using custom operation printers.
264OpPrintingFlags &OpPrintingFlags::assumeVerified() {
265 assumeVerifiedFlag = true;
266 return *this;
267}
268
269/// Use local scope when printing the operation. This allows for using the
270/// printer in a more localized and thread-safe setting, but may not necessarily
271/// be identical of what the IR will look like when dumping the full module.
272OpPrintingFlags &OpPrintingFlags::useLocalScope() {
273 printLocalScope = true;
274 return *this;
275}
276
277/// Print users of values as comments.
278OpPrintingFlags &OpPrintingFlags::printValueUsers() {
279 printValueUsersFlag = true;
280 return *this;
281}
282
283/// Return if the given ElementsAttr should be elided.
284bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
285 return elementsAttrElementLimit &&
286 *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
287 !llvm::isa<SplatElementsAttr>(attr);
288}
289
290/// Return the size limit for printing large ElementsAttr.
291std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
292 return elementsAttrElementLimit;
293}
294
295/// Return the size limit for printing large ElementsAttr.
296std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const {
297 return resourceStringCharLimit;
298}
299
300/// Return if debug information should be printed.
301bool OpPrintingFlags::shouldPrintDebugInfo() const {
302 return printDebugInfoFlag;
303}
304
305/// Return if debug information should be printed in the pretty form.
306bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
307 return printDebugInfoPrettyFormFlag;
308}
309
310/// Return if operations should be printed in the generic form.
311bool OpPrintingFlags::shouldPrintGenericOpForm() const {
312 return printGenericOpFormFlag;
313}
314
315/// Return if Region should be skipped.
316bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
317
318/// Return if operation verification should be skipped.
319bool OpPrintingFlags::shouldAssumeVerified() const {
320 return assumeVerifiedFlag;
321}
322
323/// Return if the printer should use local scope when dumping the IR.
324bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
325
326/// Return if the printer should print users of values.
327bool OpPrintingFlags::shouldPrintValueUsers() const {
328 return printValueUsersFlag;
329}
330
331/// Returns true if an ElementsAttr with the given number of elements should be
332/// printed with hex.
333static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
334 // Check to see if a command line option was provided for the limit.
335 if (clOptions.isConstructed()) {
336 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
337 // -1 is used to disable hex printing.
338 if (clOptions->printElementsAttrWithHexIfLarger == -1)
339 return false;
340 return numElements > clOptions->printElementsAttrWithHexIfLarger;
341 }
342 }
343
344 // Otherwise, default to printing with hex if the number of elements is >100.
345 return numElements > 100;
346}
347
348//===----------------------------------------------------------------------===//
349// NewLineCounter
350//===----------------------------------------------------------------------===//
351
352namespace {
353/// This class is a simple formatter that emits a new line when inputted into a
354/// stream, that enables counting the number of newlines emitted. This class
355/// should be used whenever emitting newlines in the printer.
356struct NewLineCounter {
357 unsigned curLine = 1;
358};
359
360static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
361 ++newLine.curLine;
362 return os << '\n';
363}
364} // namespace
365
366//===----------------------------------------------------------------------===//
367// AsmPrinter::Impl
368//===----------------------------------------------------------------------===//
369
namespace mlir {
/// Core implementation of the MLIR assembly printer. An instance writes to the
/// output stream `os`, draws on the shared printer state in `state`, and is
/// configured by the flags stored in `printerFlags`.
class AsmPrinter::Impl {
public:
  /// Construct a printer that writes to `os` using the given printer state.
  Impl(raw_ostream &os, AsmStateImpl &state);
  /// Construct a printer that shares the output stream and printer state of
  /// `other`, by delegating to the constructor above.
  explicit Impl(Impl &other) : Impl(other.os, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print the elements of `c` to the output stream separated by commas,
  /// invoking `eachFn` to print each element.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided.
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute or an alias.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);
  /// Print the given attribute without considering an alias.
  void printAttributeImpl(Attribute attr,
                          AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Attribute attr);

  /// Print the given type or an alias.
  void printType(Type type);
  /// Print the given type without considering an alias.
  void printTypeImpl(Type type);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print a reference to the given resource that is owned by the given
  /// dialect.
  void printResourceHandle(const AsmDialectResourceHandle &resource);

  /// Print an affine map.
  void printAffineMap(AffineMap map);
  /// Print an affine expression. `printValueName`, when provided, is a hook
  /// for printing dimension/symbol operands — NOTE(review): the meaning of its
  /// (unsigned, bool) arguments is defined out of line; confirm there.
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  /// Print a single integer set constraint; `isEq` presumably selects the
  /// equality form over the inequality form — confirm against the definition.
  void printAffineConstraint(AffineExpr expr, bool isEq);
  /// Print an integer set.
  void printIntegerSet(IntegerSet set);

  /// Presumably records `opaquePointer` as currently being printed and fails
  /// if it already is, to break printing cycles — NOTE(review): confirm
  /// against the out-of-line definition.
  LogicalResult pushCyclicPrinting(const void *opaquePointer);

  /// Counterpart of pushCyclicPrinting.
  void popCyclicPrinting();

  /// Print the given shape as a dimension list (format defined out of line).
  void printDimensionList(ArrayRef<int64_t> shape);

protected:
  /// Print `attrs` as an attribute dictionary, skipping the names listed in
  /// `elidedAttrs`. `withKeyword` presumably requests a leading keyword —
  /// confirm against the definition.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  /// Print a single name/value attribute pair.
  void printNamedAttribute(NamedAttribute attr);
  /// Print `loc` as a trailing location; `allowAlias` permits an attribute
  /// alias for it.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Print the body of a location attribute. NOTE(review): the effect of
  /// `pretty` and `isTopLevel` is defined out of line.
  void printLocationInternal(LocationAttr loc, bool pretty = false,
                             bool isTopLevel = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  /// Print a dense array attribute.
  void printDenseArrayAttr(DenseArrayAttr attr);

  /// Print an attribute / type that belongs to a dialect — presumably through
  /// that dialect's printing hooks; confirm against the definitions.
  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// Print an escaped string, wrapped with "".
  void printEscapedString(StringRef str);

  /// Print a hex string, wrapped with "".
  void printHexString(StringRef str);
  void printHexString(ArrayRef<char> data);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  /// Print `expr` given the binding strength of its enclosing context.
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// An underlying assembly printer state.
  AsmStateImpl &state;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // namespace mlir
495
496//===----------------------------------------------------------------------===//
497// AliasInitializer
498//===----------------------------------------------------------------------===//
499
500namespace {
501/// This class represents a specific instance of a symbol Alias.
502class SymbolAlias {
503public:
504 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
505 bool isDeferrable)
506 : name(name), suffixIndex(suffixIndex), isType(isType),
507 isDeferrable(isDeferrable) {}
508
509 /// Print this alias to the given stream.
510 void print(raw_ostream &os) const {
511 os << (isType ? "!" : "#") << name;
512 if (suffixIndex)
513 os << suffixIndex;
514 }
515
516 /// Returns true if this is a type alias.
517 bool isTypeAlias() const { return isType; }
518
519 /// Returns true if this alias supports deferred resolution when parsing.
520 bool canBeDeferred() const { return isDeferrable; }
521
522private:
523 /// The main name of the alias.
524 StringRef name;
525 /// The suffix index of the alias.
526 uint32_t suffixIndex : 30;
527 /// A flag indicating whether this alias is for a type.
528 bool isType : 1;
529 /// A flag indicating whether this alias may be deferred or not.
530 bool isDeferrable : 1;
531};
532
533/// This class represents a utility that initializes the set of attribute and
534/// type aliases, without the need to store the extra information within the
535/// main AliasState class or pass it around via function arguments.
536class AliasInitializer {
537public:
538 AliasInitializer(
539 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
540 llvm::BumpPtrAllocator &aliasAllocator)
541 : interfaces(interfaces), aliasAllocator(aliasAllocator),
542 aliasOS(aliasBuffer) {}
543
544 void initialize(Operation *op, const OpPrintingFlags &printerFlags,
545 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
546
547 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
548 /// set to true if the originator of this attribute can resolve the alias
549 /// after parsing has completed (e.g. in the case of operation locations).
550 /// `elideType` indicates if the type of the attribute should be skipped when
551 /// looking for nested aliases. Returns the maximum alias depth of the
552 /// attribute, and the alias index of this attribute.
553 std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
554 bool elideType = false) {
555 return visitImpl(value: attr, aliases, canBeDeferred, printArgs&: elideType);
556 }
557
558 /// Visit the given type to see if it has an alias. `canBeDeferred` is
559 /// set to true if the originator of this attribute can resolve the alias
560 /// after parsing has completed. Returns the maximum alias depth of the type,
561 /// and the alias index of this type.
562 std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
563 return visitImpl(value: type, aliases, canBeDeferred);
564 }
565
566private:
567 struct InProgressAliasInfo {
568 InProgressAliasInfo()
569 : aliasDepth(0), isType(false), canBeDeferred(false) {}
570 InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
571 : alias(alias), aliasDepth(1), isType(isType),
572 canBeDeferred(canBeDeferred) {}
573
574 bool operator<(const InProgressAliasInfo &rhs) const {
575 // Order first by depth, then by attr/type kind, and then by name.
576 if (aliasDepth != rhs.aliasDepth)
577 return aliasDepth < rhs.aliasDepth;
578 if (isType != rhs.isType)
579 return isType;
580 return alias < rhs.alias;
581 }
582
583 /// The alias for the attribute or type, or std::nullopt if the value has no
584 /// alias.
585 std::optional<StringRef> alias;
586 /// The alias depth of this attribute or type, i.e. an indication of the
587 /// relative ordering of when to print this alias.
588 unsigned aliasDepth : 30;
589 /// If this alias represents a type or an attribute.
590 bool isType : 1;
591 /// If this alias can be deferred or not.
592 bool canBeDeferred : 1;
593 /// Indices for child aliases.
594 SmallVector<size_t> childIndices;
595 };
596
597 /// Visit the given attribute or type to see if it has an alias.
598 /// `canBeDeferred` is set to true if the originator of this value can resolve
599 /// the alias after parsing has completed (e.g. in the case of operation
600 /// locations). Returns the maximum alias depth of the value, and its alias
601 /// index.
602 template <typename T, typename... PrintArgs>
603 std::pair<size_t, size_t>
604 visitImpl(T value,
605 llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
606 bool canBeDeferred, PrintArgs &&...printArgs);
607
608 /// Mark the given alias as non-deferrable.
609 void markAliasNonDeferrable(size_t aliasIndex);
610
611 /// Try to generate an alias for the provided symbol. If an alias is
612 /// generated, the provided alias mapping and reverse mapping are updated.
613 template <typename T>
614 void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
615
616 /// Given a collection of aliases and symbols, initialize a mapping from a
617 /// symbol to a given alias.
618 static void initializeAliases(
619 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
620 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
621
622 /// The set of asm interfaces within the context.
623 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
624
625 /// An allocator used for alias names.
626 llvm::BumpPtrAllocator &aliasAllocator;
627
628 /// The set of built aliases.
629 llvm::MapVector<const void *, InProgressAliasInfo> aliases;
630
631 /// Storage and stream used when generating an alias.
632 SmallString<32> aliasBuffer;
633 llvm::raw_svector_ostream aliasOS;
634};
635
636/// This class implements a dummy OpAsmPrinter that doesn't print any output,
637/// and merely collects the attributes and types that *would* be printed in a
638/// normal print invocation so that we can generate proper aliases. This allows
639/// for us to generate aliases only for the attributes and types that would be
640/// in the output, and trims down unnecessary output.
641class DummyAliasOperationPrinter : private OpAsmPrinter {
642public:
643 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
644 AliasInitializer &initializer)
645 : printerFlags(printerFlags), initializer(initializer) {}
646
647 /// Prints the entire operation with the custom assembly form, if available,
648 /// or the generic assembly form, otherwise.
649 void printCustomOrGenericOp(Operation *op) override {
650 // Visit the operation location.
651 if (printerFlags.shouldPrintDebugInfo())
652 initializer.visit(attr: op->getLoc(), /*canBeDeferred=*/true);
653
654 // If requested, always print the generic form.
655 if (!printerFlags.shouldPrintGenericOpForm()) {
656 op->getName().printAssembly(op, p&: *this, /*defaultDialect=*/"");
657 return;
658 }
659
660 // Otherwise print with the generic assembly form.
661 printGenericOp(op);
662 }
663
664private:
665 /// Print the given operation in the generic form.
666 void printGenericOp(Operation *op, bool printOpName = true) override {
667 // Consider nested operations for aliases.
668 if (!printerFlags.shouldSkipRegions()) {
669 for (Region &region : op->getRegions())
670 printRegion(region, /*printEntryBlockArgs=*/true,
671 /*printBlockTerminators=*/true);
672 }
673
674 // Visit all the types used in the operation.
675 for (Type type : op->getOperandTypes())
676 printType(type);
677 for (Type type : op->getResultTypes())
678 printType(type);
679
680 // Consider the attributes of the operation for aliases.
681 for (const NamedAttribute &attr : op->getAttrs())
682 printAttribute(attr: attr.getValue());
683 }
684
685 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
686 /// block are not printed. If 'printBlockTerminator' is false, the terminator
687 /// operation of the block is not printed.
688 void print(Block *block, bool printBlockArgs = true,
689 bool printBlockTerminator = true) {
690 // Consider the types of the block arguments for aliases if 'printBlockArgs'
691 // is set to true.
692 if (printBlockArgs) {
693 for (BlockArgument arg : block->getArguments()) {
694 printType(type: arg.getType());
695
696 // Visit the argument location.
697 if (printerFlags.shouldPrintDebugInfo())
698 // TODO: Allow deferring argument locations.
699 initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false);
700 }
701 }
702
703 // Consider the operations within this block, ignoring the terminator if
704 // requested.
705 bool hasTerminator =
706 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
707 auto range = llvm::make_range(
708 x: block->begin(),
709 y: std::prev(x: block->end(),
710 n: (!hasTerminator || printBlockTerminator) ? 0 : 1));
711 for (Operation &op : range)
712 printCustomOrGenericOp(op: &op);
713 }
714
715 /// Print the given region.
716 void printRegion(Region &region, bool printEntryBlockArgs,
717 bool printBlockTerminators,
718 bool printEmptyBlock = false) override {
719 if (region.empty())
720 return;
721 if (printerFlags.shouldSkipRegions()) {
722 os << "{...}";
723 return;
724 }
725
726 auto *entryBlock = &region.front();
727 print(block: entryBlock, printBlockArgs: printEntryBlockArgs, printBlockTerminator: printBlockTerminators);
728 for (Block &b : llvm::drop_begin(RangeOrContainer&: region, N: 1))
729 print(block: &b);
730 }
731
732 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
733 bool omitType) override {
734 printType(type: arg.getType());
735 // Visit the argument location.
736 if (printerFlags.shouldPrintDebugInfo())
737 // TODO: Allow deferring argument locations.
738 initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false);
739 }
740
741 /// Consider the given type to be printed for an alias.
742 void printType(Type type) override { initializer.visit(type); }
743
744 /// Consider the given attribute to be printed for an alias.
745 void printAttribute(Attribute attr) override { initializer.visit(attr); }
746 void printAttributeWithoutType(Attribute attr) override {
747 printAttribute(attr);
748 }
749 LogicalResult printAlias(Attribute attr) override {
750 initializer.visit(attr);
751 return success();
752 }
753 LogicalResult printAlias(Type type) override {
754 initializer.visit(type);
755 return success();
756 }
757
758 /// Consider the given location to be printed for an alias.
759 void printOptionalLocationSpecifier(Location loc) override {
760 printAttribute(attr: loc);
761 }
762
763 /// Print the given set of attributes with names not included within
764 /// 'elidedAttrs'.
765 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
766 ArrayRef<StringRef> elidedAttrs = {}) override {
767 if (attrs.empty())
768 return;
769 if (elidedAttrs.empty()) {
770 for (const NamedAttribute &attr : attrs)
771 printAttribute(attr: attr.getValue());
772 return;
773 }
774 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
775 elidedAttrs.end());
776 for (const NamedAttribute &attr : attrs)
777 if (!elidedAttrsSet.contains(V: attr.getName().strref()))
778 printAttribute(attr: attr.getValue());
779 }
780 void printOptionalAttrDictWithKeyword(
781 ArrayRef<NamedAttribute> attrs,
782 ArrayRef<StringRef> elidedAttrs = {}) override {
783 printOptionalAttrDict(attrs, elidedAttrs);
784 }
785
786 /// Return a null stream as the output stream, this will ignore any data fed
787 /// to it.
788 raw_ostream &getStream() const override { return os; }
789
790 /// The following are hooks of `OpAsmPrinter` that are not necessary for
791 /// determining potential aliases.
792 void printFloat(const APFloat &) override {}
793 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
794 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
795 void printNewline() override {}
796 void increaseIndent() override {}
797 void decreaseIndent() override {}
798 void printOperand(Value) override {}
799 void printOperand(Value, raw_ostream &os) override {
800 // Users expect the output string to have at least the prefixed % to signal
801 // a value name. To maintain this invariant, emit a name even if it is
802 // guaranteed to go unused.
803 os << "%";
804 }
805 void printKeywordOrString(StringRef) override {}
806 void printString(StringRef) override {}
807 void printResourceHandle(const AsmDialectResourceHandle &) override {}
808 void printSymbolName(StringRef) override {}
809 void printSuccessor(Block *) override {}
810 void printSuccessorAndUseList(Block *, ValueRange) override {}
811 void shadowRegionArgs(Region &, ValueRange) override {}
812
813 /// The printer flags to use when determining potential aliases.
814 const OpPrintingFlags &printerFlags;
815
816 /// The initializer to use when identifying aliases.
817 AliasInitializer &initializer;
818
819 /// A dummy output stream.
820 mutable llvm::raw_null_ostream os;
821};
822
823class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
824public:
825 explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer,
826 bool canBeDeferred,
827 SmallVectorImpl<size_t> &childIndices)
828 : initializer(initializer), canBeDeferred(canBeDeferred),
829 childIndices(childIndices) {}
830
831 /// Print the given attribute/type, visiting any nested aliases that would be
832 /// generated as part of printing. Returns the maximum alias depth found while
833 /// printing the given value.
834 template <typename T, typename... PrintArgs>
835 size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) {
836 printAndVisitNestedAliasesImpl(value, printArgs...);
837 return maxAliasDepth;
838 }
839
840private:
841 /// Print the given attribute/type, visiting any nested aliases that would be
842 /// generated as part of printing.
843 void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) {
844 if (!isa<BuiltinDialect>(Val: attr.getDialect())) {
845 attr.getDialect().printAttribute(attr, *this);
846
847 // Process the builtin attributes.
848 } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
849 IntegerSetAttr, UnitAttr>(Val: attr)) {
850 return;
851 } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
852 printAttribute(attr: distinctAttr.getReferencedAttr());
853 } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
854 for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
855 printAttribute(nestedAttr.getName());
856 printAttribute(nestedAttr.getValue());
857 }
858 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
859 for (Attribute nestedAttr : arrayAttr.getValue())
860 printAttribute(nestedAttr);
861 } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
862 printType(type: typeAttr.getValue());
863 } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) {
864 printAttribute(attr: locAttr.getFallbackLocation());
865 } else if (auto locAttr = dyn_cast<NameLoc>(attr)) {
866 if (!isa<UnknownLoc>(locAttr.getChildLoc()))
867 printAttribute(attr: locAttr.getChildLoc());
868 } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) {
869 printAttribute(attr: locAttr.getCallee());
870 printAttribute(attr: locAttr.getCaller());
871 } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) {
872 if (Attribute metadata = locAttr.getMetadata())
873 printAttribute(attr: metadata);
874 for (Location nestedLoc : locAttr.getLocations())
875 printAttribute(nestedLoc);
876 }
877
878 // Don't print the type if we must elide it, or if it is a None type.
879 if (!elideType) {
880 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
881 Type attrType = typedAttr.getType();
882 if (!llvm::isa<NoneType>(Val: attrType))
883 printType(type: attrType);
884 }
885 }
886 }
887 void printAndVisitNestedAliasesImpl(Type type) {
888 if (!isa<BuiltinDialect>(Val: type.getDialect()))
889 return type.getDialect().printType(type, *this);
890
891 // Only visit the layout of memref if it isn't the identity.
892 if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
893 printType(type: memrefTy.getElementType());
894 MemRefLayoutAttrInterface layout = memrefTy.getLayout();
895 if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
896 printAttribute(attr: memrefTy.getLayout());
897 if (memrefTy.getMemorySpace())
898 printAttribute(attr: memrefTy.getMemorySpace());
899 return;
900 }
901
902 // For most builtin types, we can simply walk the sub elements.
903 auto visitFn = [&](auto element) {
904 if (element)
905 (void)printAlias(element);
906 };
907 type.walkImmediateSubElements(walkAttrsFn: visitFn, walkTypesFn: visitFn);
908 }
909
910 /// Consider the given type to be printed for an alias.
911 void printType(Type type) override {
912 recordAliasResult(aliasDepthAndIndex: initializer.visit(type, canBeDeferred));
913 }
914
915 /// Consider the given attribute to be printed for an alias.
916 void printAttribute(Attribute attr) override {
917 recordAliasResult(aliasDepthAndIndex: initializer.visit(attr, canBeDeferred));
918 }
919 void printAttributeWithoutType(Attribute attr) override {
920 recordAliasResult(
921 aliasDepthAndIndex: initializer.visit(attr, canBeDeferred, /*elideType=*/true));
922 }
923 LogicalResult printAlias(Attribute attr) override {
924 printAttribute(attr);
925 return success();
926 }
927 LogicalResult printAlias(Type type) override {
928 printType(type);
929 return success();
930 }
931
932 /// Record the alias result of a child element.
933 void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
934 childIndices.push_back(Elt: aliasDepthAndIndex.second);
935 if (aliasDepthAndIndex.first > maxAliasDepth)
936 maxAliasDepth = aliasDepthAndIndex.first;
937 }
938
939 /// Return a null stream as the output stream, this will ignore any data fed
940 /// to it.
941 raw_ostream &getStream() const override { return os; }
942
943 /// The following are hooks of `DialectAsmPrinter` that are not necessary for
944 /// determining potential aliases.
945 void printFloat(const APFloat &) override {}
946 void printKeywordOrString(StringRef) override {}
947 void printString(StringRef) override {}
948 void printSymbolName(StringRef) override {}
949 void printResourceHandle(const AsmDialectResourceHandle &) override {}
950
951 LogicalResult pushCyclicPrinting(const void *opaquePointer) override {
952 return success(isSuccess: cyclicPrintingStack.insert(X: opaquePointer));
953 }
954
955 void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }
956
957 /// Stack of potentially cyclic mutable attributes or type currently being
958 /// printed.
959 SetVector<const void *> cyclicPrintingStack;
960
961 /// The initializer to use when identifying aliases.
962 AliasInitializer &initializer;
963
964 /// If the aliases visited by this printer can be deferred.
965 bool canBeDeferred;
966
967 /// The indices of child aliases.
968 SmallVectorImpl<size_t> &childIndices;
969
970 /// The maximum alias depth found by the printer.
971 size_t maxAliasDepth = 0;
972
973 /// A dummy output stream.
974 mutable llvm::raw_null_ostream os;
975};
976} // namespace
977
978/// Sanitize the given name such that it can be used as a valid identifier. If
979/// the string needs to be modified in any way, the provided buffer is used to
980/// store the new copy,
981static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
982 StringRef allowedPunctChars = "$._-",
983 bool allowTrailingDigit = true) {
984 assert(!name.empty() && "Shouldn't have an empty name here");
985
986 auto copyNameToBuffer = [&] {
987 for (char ch : name) {
988 if (llvm::isAlnum(C: ch) || allowedPunctChars.contains(C: ch))
989 buffer.push_back(Elt: ch);
990 else if (ch == ' ')
991 buffer.push_back(Elt: '_');
992 else
993 buffer.append(RHS: llvm::utohexstr(X: (unsigned char)ch));
994 }
995 };
996
997 // Check to see if this name is valid. If it starts with a digit, then it
998 // could conflict with the autogenerated numeric ID's, so add an underscore
999 // prefix to avoid problems.
1000 if (isdigit(name[0])) {
1001 buffer.push_back(Elt: '_');
1002 copyNameToBuffer();
1003 return buffer;
1004 }
1005
1006 // If the name ends with a trailing digit, add a '_' to avoid potential
1007 // conflicts with autogenerated ID's.
1008 if (!allowTrailingDigit && isdigit(name.back())) {
1009 copyNameToBuffer();
1010 buffer.push_back(Elt: '_');
1011 return buffer;
1012 }
1013
1014 // Check to see that the name consists of only valid identifier characters.
1015 for (char ch : name) {
1016 if (!llvm::isAlnum(C: ch) && !allowedPunctChars.contains(C: ch)) {
1017 copyNameToBuffer();
1018 return buffer;
1019 }
1020 }
1021
1022 // If there are no invalid characters, return the original name.
1023 return name;
1024}
1025
1026/// Given a collection of aliases and symbols, initialize a mapping from a
1027/// symbol to a given alias.
1028void AliasInitializer::initializeAliases(
1029 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
1030 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
1031 SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
1032 unprocessedAliases = visitedSymbols.takeVector();
1033 llvm::stable_sort(Range&: unprocessedAliases, C: [](const auto &lhs, const auto &rhs) {
1034 return lhs.second < rhs.second;
1035 });
1036
1037 llvm::StringMap<unsigned> nameCounts;
1038 for (auto &[symbol, aliasInfo] : unprocessedAliases) {
1039 if (!aliasInfo.alias)
1040 continue;
1041 StringRef alias = *aliasInfo.alias;
1042 unsigned nameIndex = nameCounts[alias]++;
1043 symbolToAlias.insert(
1044 KV: {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
1045 aliasInfo.canBeDeferred)});
1046 }
1047}
1048
1049void AliasInitializer::initialize(
1050 Operation *op, const OpPrintingFlags &printerFlags,
1051 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
1052 // Use a dummy printer when walking the IR so that we can collect the
1053 // attributes/types that will actually be used during printing when
1054 // considering aliases.
1055 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
1056 aliasPrinter.printCustomOrGenericOp(op);
1057
1058 // Initialize the aliases.
1059 initializeAliases(visitedSymbols&: aliases, symbolToAlias&: attrTypeToAlias);
1060}
1061
1062template <typename T, typename... PrintArgs>
1063std::pair<size_t, size_t> AliasInitializer::visitImpl(
1064 T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
1065 bool canBeDeferred, PrintArgs &&...printArgs) {
1066 auto [it, inserted] =
1067 aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
1068 size_t aliasIndex = std::distance(aliases.begin(), it);
1069 if (!inserted) {
1070 // Make sure that the alias isn't deferred if we don't permit it.
1071 if (!canBeDeferred)
1072 markAliasNonDeferrable(aliasIndex);
1073 return {static_cast<size_t>(it->second.aliasDepth), aliasIndex};
1074 }
1075
1076 // Try to generate an alias for this value.
1077 generateAlias(value, it->second, canBeDeferred);
1078
1079 // Print the value, capturing any nested elements that require aliases.
1080 SmallVector<size_t> childAliases;
1081 DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases);
1082 size_t maxAliasDepth =
1083 printer.printAndVisitNestedAliases(value, printArgs...);
1084
1085 // Make sure to recompute `it` in case the map was reallocated.
1086 it = std::next(x: aliases.begin(), n: aliasIndex);
1087
1088 // If we had sub elements, update to account for the depth.
1089 it->second.childIndices = std::move(childAliases);
1090 if (maxAliasDepth)
1091 it->second.aliasDepth = maxAliasDepth + 1;
1092
1093 // Propagate the alias depth of the value.
1094 return {(size_t)it->second.aliasDepth, aliasIndex};
1095}
1096
1097void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
1098 auto it = std::next(x: aliases.begin(), n: aliasIndex);
1099
1100 // If already marked non-deferrable stop the recursion.
1101 // All children should already be marked non-deferrable as well.
1102 if (!it->second.canBeDeferred)
1103 return;
1104
1105 it->second.canBeDeferred = false;
1106
1107 // Propagate the non-deferrable flag to any child aliases.
1108 for (size_t childIndex : it->second.childIndices)
1109 markAliasNonDeferrable(aliasIndex: childIndex);
1110}
1111
1112template <typename T>
1113void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
1114 bool canBeDeferred) {
1115 SmallString<32> nameBuffer;
1116 for (const auto &interface : interfaces) {
1117 OpAsmDialectInterface::AliasResult result =
1118 interface.getAlias(symbol, aliasOS);
1119 if (result == OpAsmDialectInterface::AliasResult::NoAlias)
1120 continue;
1121 nameBuffer = std::move(aliasBuffer);
1122 assert(!nameBuffer.empty() && "expected valid alias name");
1123 if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
1124 break;
1125 }
1126
1127 if (nameBuffer.empty())
1128 return;
1129
1130 SmallString<16> tempBuffer;
1131 StringRef name =
1132 sanitizeIdentifier(name: nameBuffer, buffer&: tempBuffer, /*allowedPunctChars=*/"$_-",
1133 /*allowTrailingDigit=*/false);
1134 name = name.copy(A&: aliasAllocator);
1135 alias = InProgressAliasInfo(name, /*isType=*/std::is_base_of_v<Type, T>,
1136 canBeDeferred);
1137}
1138
1139//===----------------------------------------------------------------------===//
1140// AliasState
1141//===----------------------------------------------------------------------===//
1142
1143namespace {
1144/// This class manages the state for type and attribute aliases.
1145class AliasState {
1146public:
1147 // Initialize the internal aliases.
1148 void
1149 initialize(Operation *op, const OpPrintingFlags &printerFlags,
1150 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
1151
1152 /// Get an alias for the given attribute if it has one and print it in `os`.
1153 /// Returns success if an alias was printed, failure otherwise.
1154 LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
1155
1156 /// Get an alias for the given type if it has one and print it in `os`.
1157 /// Returns success if an alias was printed, failure otherwise.
1158 LogicalResult getAlias(Type ty, raw_ostream &os) const;
1159
1160 /// Print all of the referenced aliases that can not be resolved in a deferred
1161 /// manner.
1162 void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1163 printAliases(p, newLine, /*isDeferred=*/false);
1164 }
1165
1166 /// Print all of the referenced aliases that support deferred resolution.
1167 void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1168 printAliases(p, newLine, /*isDeferred=*/true);
1169 }
1170
1171private:
1172 /// Print all of the referenced aliases that support the provided resolution
1173 /// behavior.
1174 void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1175 bool isDeferred);
1176
1177 /// Mapping between attribute/type and alias.
1178 llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
1179
1180 /// An allocator used for alias names.
1181 llvm::BumpPtrAllocator aliasAllocator;
1182};
1183} // namespace
1184
1185void AliasState::initialize(
1186 Operation *op, const OpPrintingFlags &printerFlags,
1187 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
1188 AliasInitializer initializer(interfaces, aliasAllocator);
1189 initializer.initialize(op, printerFlags, attrTypeToAlias);
1190}
1191
1192LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
1193 auto it = attrTypeToAlias.find(Key: attr.getAsOpaquePointer());
1194 if (it == attrTypeToAlias.end())
1195 return failure();
1196 it->second.print(os);
1197 return success();
1198}
1199
1200LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
1201 auto it = attrTypeToAlias.find(Key: ty.getAsOpaquePointer());
1202 if (it == attrTypeToAlias.end())
1203 return failure();
1204
1205 it->second.print(os);
1206 return success();
1207}
1208
1209void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1210 bool isDeferred) {
1211 auto filterFn = [=](const auto &aliasIt) {
1212 return aliasIt.second.canBeDeferred() == isDeferred;
1213 };
1214 for (auto &[opaqueSymbol, alias] :
1215 llvm::make_filter_range(Range&: attrTypeToAlias, Pred: filterFn)) {
1216 alias.print(os&: p.getStream());
1217 p.getStream() << " = ";
1218
1219 if (alias.isTypeAlias()) {
1220 // TODO: Support nested aliases in mutable types.
1221 Type type = Type::getFromOpaquePointer(pointer: opaqueSymbol);
1222 if (type.hasTrait<TypeTrait::IsMutable>())
1223 p.getStream() << type;
1224 else
1225 p.printTypeImpl(type);
1226 } else {
1227 // TODO: Support nested aliases in mutable attributes.
1228 Attribute attr = Attribute::getFromOpaquePointer(ptr: opaqueSymbol);
1229 if (attr.hasTrait<AttributeTrait::IsMutable>())
1230 p.getStream() << attr;
1231 else
1232 p.printAttributeImpl(attr);
1233 }
1234
1235 p.getStream() << newLine;
1236 }
1237}
1238
1239//===----------------------------------------------------------------------===//
1240// SSANameState
1241//===----------------------------------------------------------------------===//
1242
1243namespace {
1244/// Info about block printing: a number which is its position in the visitation
1245/// order, and a name that is used to print reference to it, e.g. ^bb42.
1246struct BlockInfo {
1247 int ordering;
1248 StringRef name;
1249};
1250
1251/// This class manages the state of SSA value names.
1252class SSANameState {
1253public:
1254 /// A sentinel value used for values with names set.
1255 enum : unsigned { NameSentinel = ~0U };
1256
1257 SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
1258 SSANameState() = default;
1259
1260 /// Print the SSA identifier for the given value to 'stream'. If
1261 /// 'printResultNo' is true, it also presents the result number ('#' number)
1262 /// of this value.
1263 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
1264
1265 /// Print the operation identifier.
1266 void printOperationID(Operation *op, raw_ostream &stream) const;
1267
1268 /// Return the result indices for each of the result groups registered by this
1269 /// operation, or empty if none exist.
1270 ArrayRef<int> getOpResultGroups(Operation *op);
1271
1272 /// Get the info for the given block.
1273 BlockInfo getBlockInfo(Block *block);
1274
1275 /// Renumber the arguments for the specified region to the same names as the
1276 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
1277 /// details.
1278 void shadowRegionArgs(Region &region, ValueRange namesToUse);
1279
1280private:
1281 /// Number the SSA values within the given IR unit.
1282 void numberValuesInRegion(Region &region);
1283 void numberValuesInBlock(Block &block);
1284 void numberValuesInOp(Operation &op);
1285
1286 /// Given a result of an operation 'result', find the result group head
1287 /// 'lookupValue' and the result of 'result' within that group in
1288 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
1289 /// has more than 1 result.
1290 void getResultIDAndNumber(OpResult result, Value &lookupValue,
1291 std::optional<int> &lookupResultNo) const;
1292
1293 /// Set a special value name for the given value.
1294 void setValueName(Value value, StringRef name);
1295
1296 /// Uniques the given value name within the printer. If the given name
1297 /// conflicts, it is automatically renamed.
1298 StringRef uniqueValueName(StringRef name);
1299
1300 /// This is the value ID for each SSA value. If this returns NameSentinel,
1301 /// then the valueID has an entry in valueNames.
1302 DenseMap<Value, unsigned> valueIDs;
1303 DenseMap<Value, StringRef> valueNames;
1304
1305 /// When printing users of values, an operation without a result might
1306 /// be the user. This map holds ids for such operations.
1307 DenseMap<Operation *, unsigned> operationIDs;
1308
1309 /// This is a map of operations that contain multiple named result groups,
1310 /// i.e. there may be multiple names for the results of the operation. The
1311 /// value of this map are the result numbers that start a result group.
1312 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
1313
1314 /// This maps blocks to there visitation number in the current region as well
1315 /// as the string representing their name.
1316 DenseMap<Block *, BlockInfo> blockNames;
1317
1318 /// This keeps track of all of the non-numeric names that are in flight,
1319 /// allowing us to check for duplicates.
1320 /// Note: the value of the map is unused.
1321 llvm::ScopedHashTable<StringRef, char> usedNames;
1322 llvm::BumpPtrAllocator usedNameAllocator;
1323
1324 /// This is the next value ID to assign in numbering.
1325 unsigned nextValueID = 0;
1326 /// This is the next ID to assign to a region entry block argument.
1327 unsigned nextArgumentID = 0;
1328 /// This is the next ID to assign when a name conflict is detected.
1329 unsigned nextConflictID = 0;
1330
1331 /// These are the printing flags. They control, eg., whether to print in
1332 /// generic form.
1333 OpPrintingFlags printerFlags;
1334};
1335} // namespace
1336
1337SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
1338 : printerFlags(printerFlags) {
1339 llvm::SaveAndRestore valueIDSaver(nextValueID);
1340 llvm::SaveAndRestore argumentIDSaver(nextArgumentID);
1341 llvm::SaveAndRestore conflictIDSaver(nextConflictID);
1342
1343 // The naming context includes `nextValueID`, `nextArgumentID`,
1344 // `nextConflictID` and `usedNames` scoped HashTable. This information is
1345 // carried from the parent region.
1346 using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
1347 using NamingContext =
1348 std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;
1349
1350 // Allocator for UsedNamesScopeTy
1351 llvm::BumpPtrAllocator allocator;
1352
1353 // Add a scope for the top level operation.
1354 auto *topLevelNamesScope =
1355 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);
1356
1357 SmallVector<NamingContext, 8> nameContext;
1358 for (Region &region : op->getRegions())
1359 nameContext.push_back(Elt: std::make_tuple(args: &region, args&: nextValueID, args&: nextArgumentID,
1360 args&: nextConflictID, args&: topLevelNamesScope));
1361
1362 numberValuesInOp(op&: *op);
1363
1364 while (!nameContext.empty()) {
1365 Region *region;
1366 UsedNamesScopeTy *parentScope;
1367 std::tie(args&: region, args&: nextValueID, args&: nextArgumentID, args&: nextConflictID, args&: parentScope) =
1368 nameContext.pop_back_val();
1369
1370 // When we switch from one subtree to another, pop the scopes(needless)
1371 // until the parent scope.
1372 while (usedNames.getCurScope() != parentScope) {
1373 usedNames.getCurScope()->~UsedNamesScopeTy();
1374 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
1375 "top level parentScope must be a nullptr");
1376 }
1377
1378 // Add a scope for the current region.
1379 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
1380 UsedNamesScopeTy(usedNames);
1381
1382 numberValuesInRegion(region&: *region);
1383
1384 for (Operation &op : region->getOps())
1385 for (Region &region : op.getRegions())
1386 nameContext.push_back(Elt: std::make_tuple(args: &region, args&: nextValueID,
1387 args&: nextArgumentID, args&: nextConflictID,
1388 args&: curNamesScope));
1389 }
1390
1391 // Manually remove all the scopes.
1392 while (usedNames.getCurScope() != nullptr)
1393 usedNames.getCurScope()->~UsedNamesScopeTy();
1394}
1395
1396void SSANameState::printValueID(Value value, bool printResultNo,
1397 raw_ostream &stream) const {
1398 if (!value) {
1399 stream << "<<NULL VALUE>>";
1400 return;
1401 }
1402
1403 std::optional<int> resultNo;
1404 auto lookupValue = value;
1405
1406 // If this is an operation result, collect the head lookup value of the result
1407 // group and the result number of 'result' within that group.
1408 if (OpResult result = dyn_cast<OpResult>(Val&: value))
1409 getResultIDAndNumber(result, lookupValue, lookupResultNo&: resultNo);
1410
1411 auto it = valueIDs.find(Val: lookupValue);
1412 if (it == valueIDs.end()) {
1413 stream << "<<UNKNOWN SSA VALUE>>";
1414 return;
1415 }
1416
1417 stream << '%';
1418 if (it->second != NameSentinel) {
1419 stream << it->second;
1420 } else {
1421 auto nameIt = valueNames.find(Val: lookupValue);
1422 assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1423 stream << nameIt->second;
1424 }
1425
1426 if (resultNo && printResultNo)
1427 stream << '#' << *resultNo;
1428}
1429
1430void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1431 auto it = operationIDs.find(Val: op);
1432 if (it == operationIDs.end()) {
1433 stream << "<<UNKNOWN OPERATION>>";
1434 } else {
1435 stream << '%' << it->second;
1436 }
1437}
1438
1439ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1440 auto it = opResultGroups.find(Val: op);
1441 return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1442}
1443
1444BlockInfo SSANameState::getBlockInfo(Block *block) {
1445 auto it = blockNames.find(Val: block);
1446 BlockInfo invalidBlock{.ordering: -1, .name: "INVALIDBLOCK"};
1447 return it != blockNames.end() ? it->second : invalidBlock;
1448}
1449
1450void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1451 assert(!region.empty() && "cannot shadow arguments of an empty region");
1452 assert(region.getNumArguments() == namesToUse.size() &&
1453 "incorrect number of names passed in");
1454 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1455 "only KnownIsolatedFromAbove ops can shadow names");
1456
1457 SmallVector<char, 16> nameStr;
1458 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1459 auto nameToUse = namesToUse[i];
1460 if (nameToUse == nullptr)
1461 continue;
1462 auto nameToReplace = region.getArgument(i);
1463
1464 nameStr.clear();
1465 llvm::raw_svector_ostream nameStream(nameStr);
1466 printValueID(value: nameToUse, /*printResultNo=*/true, stream&: nameStream);
1467
1468 // Entry block arguments should already have a pretty "arg" name.
1469 assert(valueIDs[nameToReplace] == NameSentinel);
1470
1471 // Use the name without the leading %.
1472 auto name = StringRef(nameStream.str()).drop_front();
1473
1474 // Overwrite the name.
1475 valueNames[nameToReplace] = name.copy(A&: usedNameAllocator);
1476 }
1477}
1478
1479void SSANameState::numberValuesInRegion(Region &region) {
1480 auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1481 assert(!valueIDs.count(arg) && "arg numbered multiple times");
1482 assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
1483 "arg not defined in current region");
1484 setValueName(value: arg, name);
1485 };
1486
1487 if (!printerFlags.shouldPrintGenericOpForm()) {
1488 if (Operation *op = region.getParentOp()) {
1489 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1490 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1491 }
1492 }
1493
1494 // Number the values within this region in a breadth-first order.
1495 unsigned nextBlockID = 0;
1496 for (auto &block : region) {
1497 // Each block gets a unique ID, and all of the operations within it get
1498 // numbered as well.
1499 auto blockInfoIt = blockNames.insert(KV: {&block, {.ordering: -1, .name: ""}});
1500 if (blockInfoIt.second) {
1501 // This block hasn't been named through `getAsmBlockArgumentNames`, use
1502 // default `^bbNNN` format.
1503 std::string name;
1504 llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1505 blockInfoIt.first->second.name = StringRef(name).copy(A&: usedNameAllocator);
1506 }
1507 blockInfoIt.first->second.ordering = nextBlockID++;
1508
1509 numberValuesInBlock(block);
1510 }
1511}
1512
1513void SSANameState::numberValuesInBlock(Block &block) {
1514 // Number the block arguments. We give entry block arguments a special name
1515 // 'arg'.
1516 bool isEntryBlock = block.isEntryBlock();
1517 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1518 llvm::raw_svector_ostream specialName(specialNameBuffer);
1519 for (auto arg : block.getArguments()) {
1520 if (valueIDs.count(Val: arg))
1521 continue;
1522 if (isEntryBlock) {
1523 specialNameBuffer.resize(N: strlen(s: "arg"));
1524 specialName << nextArgumentID++;
1525 }
1526 setValueName(value: arg, name: specialName.str());
1527 }
1528
1529 // Number the operations in this block.
1530 for (auto &op : block)
1531 numberValuesInOp(op);
1532}
1533
1534void SSANameState::numberValuesInOp(Operation &op) {
1535 // Function used to set the special result names for the operation.
1536 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1537 auto setResultNameFn = [&](Value result, StringRef name) {
1538 assert(!valueIDs.count(result) && "result numbered multiple times");
1539 assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1540 setValueName(value: result, name);
1541
1542 // Record the result number for groups not anchored at 0.
1543 if (int resultNo = llvm::cast<OpResult>(Val&: result).getResultNumber())
1544 resultGroups.push_back(Elt: resultNo);
1545 };
1546 // Operations can customize the printing of block names in OpAsmOpInterface.
1547 auto setBlockNameFn = [&](Block *block, StringRef name) {
1548 assert(block->getParentOp() == &op &&
1549 "getAsmBlockArgumentNames callback invoked on a block not directly "
1550 "nested under the current operation");
1551 assert(!blockNames.count(block) && "block numbered multiple times");
1552 SmallString<16> tmpBuffer{"^"};
1553 name = sanitizeIdentifier(name, buffer&: tmpBuffer);
1554 if (name.data() != tmpBuffer.data()) {
1555 tmpBuffer.append(RHS: name);
1556 name = tmpBuffer.str();
1557 }
1558 name = name.copy(A&: usedNameAllocator);
1559 blockNames[block] = {.ordering: -1, .name: name};
1560 };
1561
1562 if (!printerFlags.shouldPrintGenericOpForm()) {
1563 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1564 asmInterface.getAsmBlockNames(setBlockNameFn);
1565 asmInterface.getAsmResultNames(setResultNameFn);
1566 }
1567 }
1568
1569 unsigned numResults = op.getNumResults();
1570 if (numResults == 0) {
1571 // If value users should be printed, operations with no result need an id.
1572 if (printerFlags.shouldPrintValueUsers()) {
1573 if (operationIDs.try_emplace(Key: &op, Args&: nextValueID).second)
1574 ++nextValueID;
1575 }
1576 return;
1577 }
1578 Value resultBegin = op.getResult(idx: 0);
1579
1580 // If the first result wasn't numbered, give it a default number.
1581 if (valueIDs.try_emplace(Key: resultBegin, Args&: nextValueID).second)
1582 ++nextValueID;
1583
1584 // If this operation has multiple result groups, mark it.
1585 if (resultGroups.size() != 1) {
1586 llvm::array_pod_sort(Start: resultGroups.begin(), End: resultGroups.end());
1587 opResultGroups.try_emplace(Key: &op, Args: std::move(resultGroups));
1588 }
1589}
1590
1591void SSANameState::getResultIDAndNumber(
1592 OpResult result, Value &lookupValue,
1593 std::optional<int> &lookupResultNo) const {
1594 Operation *owner = result.getOwner();
1595 if (owner->getNumResults() == 1)
1596 return;
1597 int resultNo = result.getResultNumber();
1598
1599 // If this operation has multiple result groups, we will need to find the
1600 // one corresponding to this result.
1601 auto resultGroupIt = opResultGroups.find(Val: owner);
1602 if (resultGroupIt == opResultGroups.end()) {
1603 // If not, just use the first result.
1604 lookupResultNo = resultNo;
1605 lookupValue = owner->getResult(idx: 0);
1606 return;
1607 }
1608
1609 // Find the correct index using a binary search, as the groups are ordered.
1610 ArrayRef<int> resultGroups = resultGroupIt->second;
1611 const auto *it = llvm::upper_bound(Range&: resultGroups, Value&: resultNo);
1612 int groupResultNo = 0, groupSize = 0;
1613
1614 // If there are no smaller elements, the last result group is the lookup.
1615 if (it == resultGroups.end()) {
1616 groupResultNo = resultGroups.back();
1617 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1618 } else {
1619 // Otherwise, the previous element is the lookup.
1620 groupResultNo = *std::prev(x: it);
1621 groupSize = *it - groupResultNo;
1622 }
1623
1624 // We only record the result number for a group of size greater than 1.
1625 if (groupSize != 1)
1626 lookupResultNo = resultNo - groupResultNo;
1627 lookupValue = owner->getResult(idx: groupResultNo);
1628}
1629
1630void SSANameState::setValueName(Value value, StringRef name) {
1631 // If the name is empty, the value uses the default numbering.
1632 if (name.empty()) {
1633 valueIDs[value] = nextValueID++;
1634 return;
1635 }
1636
1637 valueIDs[value] = NameSentinel;
1638 valueNames[value] = uniqueValueName(name);
1639}
1640
1641StringRef SSANameState::uniqueValueName(StringRef name) {
1642 SmallString<16> tmpBuffer;
1643 name = sanitizeIdentifier(name, buffer&: tmpBuffer);
1644
1645 // Check to see if this name is already unique.
1646 if (!usedNames.count(Key: name)) {
1647 name = name.copy(A&: usedNameAllocator);
1648 } else {
1649 // Otherwise, we had a conflict - probe until we find a unique name. This
1650 // is guaranteed to terminate (and usually in a single iteration) because it
1651 // generates new names by incrementing nextConflictID.
1652 SmallString<64> probeName(name);
1653 probeName.push_back(Elt: '_');
1654 while (true) {
1655 probeName += llvm::utostr(X: nextConflictID++);
1656 if (!usedNames.count(Key: probeName)) {
1657 name = probeName.str().copy(A&: usedNameAllocator);
1658 break;
1659 }
1660 probeName.resize(N: name.size() + 1);
1661 }
1662 }
1663
1664 usedNames.insert(Key: name, Val: char());
1665 return name;
1666}
1667
1668//===----------------------------------------------------------------------===//
1669// DistinctState
1670//===----------------------------------------------------------------------===//
1671
namespace {
/// This class manages the state for distinct attributes.
///
/// Each distinct attribute is assigned a numeric identifier the first time
/// `getId` is asked about it. Identifiers count up from 0 in first-request
/// order.
class DistinctState {
public:
  /// Returns a unique identifier for the given distinct attribute. Repeated
  /// queries for the same attribute return the same identifier.
  uint64_t getId(DistinctAttr distinctAttr);

private:
  /// Identifier to hand out to the next attribute not yet in the map.
  uint64_t distinctCounter = 0;
  /// Identifiers already assigned, keyed by attribute.
  DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
};
} // namespace
1684
1685uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
1686 auto [it, inserted] =
1687 distinctAttrMap.try_emplace(Key: distinctAttr, Args&: distinctCounter);
1688 if (inserted)
1689 distinctCounter++;
1690 return it->getSecond();
1691}
1692
1693//===----------------------------------------------------------------------===//
1694// Resources
1695//===----------------------------------------------------------------------===//
1696
// Out-of-line defaulted destructors for the resource interface classes.
// NOTE(review): presumably defined here (rather than in the header) to anchor
// each class's vtable in this translation unit; confirm against the class
// declarations.
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
1701
1702StringRef mlir::toString(AsmResourceEntryKind kind) {
1703 switch (kind) {
1704 case AsmResourceEntryKind::Blob:
1705 return "blob";
1706 case AsmResourceEntryKind::Bool:
1707 return "bool";
1708 case AsmResourceEntryKind::String:
1709 return "string";
1710 }
1711 llvm_unreachable("unknown AsmResourceEntryKind");
1712}
1713
1714AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
1715 std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
1716 if (!collection)
1717 collection = std::make_unique<ResourceCollection>(args&: key);
1718 return *collection;
1719}
1720
1721std::vector<std::unique_ptr<AsmResourcePrinter>>
1722FallbackAsmResourceMap::getPrinters() {
1723 std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
1724 for (auto &it : keyToResources) {
1725 ResourceCollection *collection = it.second.get();
1726 auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
1727 return collection->buildResources(op, builder);
1728 };
1729 printers.emplace_back(
1730 args: AsmResourcePrinter::fromCallable(name: collection->getName(), printFn&: buildValues));
1731 }
1732 return printers;
1733}
1734
1735LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
1736 AsmParsedResourceEntry &entry) {
1737 switch (entry.getKind()) {
1738 case AsmResourceEntryKind::Blob: {
1739 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
1740 if (failed(result: blob))
1741 return failure();
1742 resources.emplace_back(Args: entry.getKey(), Args: std::move(*blob));
1743 return success();
1744 }
1745 case AsmResourceEntryKind::Bool: {
1746 FailureOr<bool> value = entry.parseAsBool();
1747 if (failed(result: value))
1748 return failure();
1749 resources.emplace_back(Args: entry.getKey(), Args&: *value);
1750 break;
1751 }
1752 case AsmResourceEntryKind::String: {
1753 FailureOr<std::string> str = entry.parseAsString();
1754 if (failed(result: str))
1755 return failure();
1756 resources.emplace_back(Args: entry.getKey(), Args: std::move(*str));
1757 break;
1758 }
1759 }
1760 return success();
1761}
1762
1763void FallbackAsmResourceMap::ResourceCollection::buildResources(
1764 Operation *op, AsmResourceBuilder &builder) const {
1765 for (const auto &entry : resources) {
1766 if (const auto *value = std::get_if<AsmResourceBlob>(ptr: &entry.value))
1767 builder.buildBlob(key: entry.key, blob: *value);
1768 else if (const auto *value = std::get_if<bool>(ptr: &entry.value))
1769 builder.buildBool(key: entry.key, data: *value);
1770 else if (const auto *value = std::get_if<std::string>(ptr: &entry.value))
1771 builder.buildString(key: entry.key, data: *value);
1772 else
1773 llvm_unreachable("unknown AsmResourceEntryKind");
1774 }
1775}
1776
1777//===----------------------------------------------------------------------===//
1778// AsmState
1779//===----------------------------------------------------------------------===//
1780
1781namespace mlir {
1782namespace detail {
/// Implementation storage behind AsmState. It holds:
///   - the OpAsmDialectInterface collection,
///   - the alias, SSA-name, and distinct-attribute numbering state,
///   - the printer flags and the referenced dialect resources,
///   - externally attached resource printers,
///   - an optional operation location map,
///   - a stack used to detect cyclic attribute/type printing.
class AsmStateImpl {
public:
  /// Build state for printing `op`. The SSA name state is initialized from
  /// `op` and the (constructor argument) `printerFlags`.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}
  /// Build state from a bare context. The SSA name state is left
  /// default-constructed.
  explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Get the state used for distinct attribute identifiers.
  DistinctState &getDistinctState() { return distinctState; }

  /// Return the dialects within the context that implement
  /// OpAsmDialectInterface.
  DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
    return interfaces;
  }

  /// Return the non-dialect resource printers, as a range of references
  /// (the owning unique_ptrs are dereferenced).
  auto getResourcePrinters() {
    return llvm::make_pointee_range(Range&: externalResourcePrinters);
  }

  /// Get the printer flags.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. This is a no-op when no location map was
  /// supplied at construction.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(x&: line, y&: col);
  }

  /// Return the referenced dialect resources within the printer.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
  getDialectResources() {
    return dialectResources;
  }

  /// Push `opaquePointer` onto the stack of attributes/types currently being
  /// printed. Returns failure if it is already on the stack, i.e. printing it
  /// again would recurse through a cycle.
  LogicalResult pushCyclicPrinting(const void *opaquePointer) {
    return success(isSuccess: cyclicPrintingStack.insert(X: opaquePointer));
  }

  /// Pop the most recently pushed entry from the cyclic printing stack.
  void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// A collection of non-dialect resource printers.
  SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;

  /// A set of dialect resources that were referenced during printing.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// The state used for distinct attribute identifiers.
  DistinctState distinctState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated (may be null).
  AsmState::LocationMap *locationMap;

  /// Stack of potentially cyclic mutable attributes or type currently being
  /// printed.
  SetVector<const void *> cyclicPrintingStack;

  // Allow direct access to the impl fields.
  friend AsmState;
};
1872
1873template <typename Range>
1874void printDimensionList(raw_ostream &stream, Range &&shape) {
1875 llvm::interleave(
1876 shape, stream,
1877 [&stream](const auto &dimSize) {
1878 if (ShapedType::isDynamic(dimSize))
1879 stream << "?";
1880 else
1881 stream << dimSize;
1882 },
1883 "x");
1884}
1885
1886} // namespace detail
1887} // namespace mlir
1888
1889/// Verifies the operation and switches to generic op printing if verification
1890/// fails. We need to do this because custom print functions may fail for
1891/// invalid ops.
1892static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1893 OpPrintingFlags printerFlags) {
1894 if (printerFlags.shouldPrintGenericOpForm() ||
1895 printerFlags.shouldAssumeVerified())
1896 return printerFlags;
1897
1898 LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
1899 << op->getName() << "\n");
1900
1901 // Ignore errors emitted by the verifier. We check the thread id to avoid
1902 // consuming other threads' errors.
1903 auto parentThreadId = llvm::get_threadid();
1904 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1905 if (parentThreadId == llvm::get_threadid()) {
1906 LLVM_DEBUG({
1907 diag.print(llvm::dbgs());
1908 llvm::dbgs() << "\n";
1909 });
1910 return success();
1911 }
1912 return failure();
1913 });
1914 if (failed(result: verify(op))) {
1915 LLVM_DEBUG(llvm::dbgs()
1916 << DEBUG_TYPE << ": '" << op->getName()
1917 << "' failed to verify and will be printed in generic form\n");
1918 printerFlags.printGenericOpForm();
1919 }
1920
1921 return printerFlags;
1922}
1923
1924AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1925 LocationMap *locationMap, FallbackAsmResourceMap *map)
1926 : impl(std::make_unique<AsmStateImpl>(
1927 args&: op, args: verifyOpAndAdjustFlags(op, printerFlags), args&: locationMap)) {
1928 if (map)
1929 attachFallbackResourcePrinter(map&: *map);
1930}
1931AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
1932 LocationMap *locationMap, FallbackAsmResourceMap *map)
1933 : impl(std::make_unique<AsmStateImpl>(args&: ctx, args: printerFlags, args&: locationMap)) {
1934 if (map)
1935 attachFallbackResourcePrinter(map&: *map);
1936}
1937AsmState::~AsmState() = default;
1938
1939const OpPrintingFlags &AsmState::getPrinterFlags() const {
1940 return impl->getPrinterFlags();
1941}
1942
1943void AsmState::attachResourcePrinter(
1944 std::unique_ptr<AsmResourcePrinter> printer) {
1945 impl->externalResourcePrinters.emplace_back(Args: std::move(printer));
1946}
1947
1948DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
1949AsmState::getDialectResources() const {
1950 return impl->getDialectResources();
1951}
1952
1953//===----------------------------------------------------------------------===//
1954// AsmPrinter::Impl
1955//===----------------------------------------------------------------------===//
1956
1957AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
1958 : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
1959
1960void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1961 // Check to see if we are printing debug information.
1962 if (!printerFlags.shouldPrintDebugInfo())
1963 return;
1964
1965 os << " ";
1966 printLocation(loc, /*allowAlias=*/allowAlias);
1967}
1968
1969void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
1970 bool isTopLevel) {
1971 // If this isn't a top-level location, check for an alias.
1972 if (!isTopLevel && succeeded(result: state.getAliasState().getAlias(attr: loc, os)))
1973 return;
1974
1975 TypeSwitch<LocationAttr>(loc)
1976 .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1977 printLocationInternal(loc.getFallbackLocation(), pretty);
1978 })
1979 .Case<UnknownLoc>([&](UnknownLoc loc) {
1980 if (pretty)
1981 os << "[unknown]";
1982 else
1983 os << "unknown";
1984 })
1985 .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1986 if (pretty)
1987 os << loc.getFilename().getValue();
1988 else
1989 printEscapedString(loc.getFilename());
1990 os << ':' << loc.getLine() << ':' << loc.getColumn();
1991 })
1992 .Case<NameLoc>([&](NameLoc loc) {
1993 printEscapedString(loc.getName());
1994
1995 // Print the child if it isn't unknown.
1996 auto childLoc = loc.getChildLoc();
1997 if (!llvm::isa<UnknownLoc>(childLoc)) {
1998 os << '(';
1999 printLocationInternal(childLoc, pretty);
2000 os << ')';
2001 }
2002 })
2003 .Case<CallSiteLoc>([&](CallSiteLoc loc) {
2004 Location caller = loc.getCaller();
2005 Location callee = loc.getCallee();
2006 if (!pretty)
2007 os << "callsite(";
2008 printLocationInternal(callee, pretty);
2009 if (pretty) {
2010 if (llvm::isa<NameLoc>(callee)) {
2011 if (llvm::isa<FileLineColLoc>(caller)) {
2012 os << " at ";
2013 } else {
2014 os << newLine << " at ";
2015 }
2016 } else {
2017 os << newLine << " at ";
2018 }
2019 } else {
2020 os << " at ";
2021 }
2022 printLocationInternal(caller, pretty);
2023 if (!pretty)
2024 os << ")";
2025 })
2026 .Case<FusedLoc>([&](FusedLoc loc) {
2027 if (!pretty)
2028 os << "fused";
2029 if (Attribute metadata = loc.getMetadata()) {
2030 os << '<';
2031 printAttribute(metadata);
2032 os << '>';
2033 }
2034 os << '[';
2035 interleave(
2036 loc.getLocations(),
2037 [&](Location loc) { printLocationInternal(loc, pretty); },
2038 [&]() { os << ", "; });
2039 os << ']';
2040 });
2041}
2042
2043/// Print a floating point value in a way that the parser will be able to
2044/// round-trip losslessly.
2045static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
2046 // We would like to output the FP constant value in exponential notation,
2047 // but we cannot do this if doing so will lose precision. Check here to
2048 // make sure that we only output it in exponential format if we can parse
2049 // the value back and get the same value.
2050 bool isInf = apValue.isInfinity();
2051 bool isNaN = apValue.isNaN();
2052 if (!isInf && !isNaN) {
2053 SmallString<128> strValue;
2054 apValue.toString(Str&: strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
2055 /*TruncateZero=*/false);
2056
2057 // Check to make sure that the stringized number is not some string like
2058 // "Inf" or NaN, that atof will accept, but the lexer will not. Check
2059 // that the string matches the "[-+]?[0-9]" regex.
2060 assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
2061 ((strValue[0] == '-' || strValue[0] == '+') &&
2062 (strValue[1] >= '0' && strValue[1] <= '9'))) &&
2063 "[-+]?[0-9] regex does not match!");
2064
2065 // Parse back the stringized version and check that the value is equal
2066 // (i.e., there is no precision loss).
2067 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(RHS: apValue)) {
2068 os << strValue;
2069 return;
2070 }
2071
2072 // If it is not, use the default format of APFloat instead of the
2073 // exponential notation.
2074 strValue.clear();
2075 apValue.toString(Str&: strValue);
2076
2077 // Make sure that we can parse the default form as a float.
2078 if (strValue.str().contains(C: '.')) {
2079 os << strValue;
2080 return;
2081 }
2082 }
2083
2084 // Print special values in hexadecimal format. The sign bit should be included
2085 // in the literal.
2086 SmallVector<char, 16> str;
2087 APInt apInt = apValue.bitcastToAPInt();
2088 apInt.toString(Str&: str, /*Radix=*/16, /*Signed=*/false,
2089 /*formatAsCLiteral=*/true);
2090 os << str;
2091}
2092
2093void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
2094 if (printerFlags.shouldPrintDebugInfoPrettyForm())
2095 return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true);
2096
2097 os << "loc(";
2098 if (!allowAlias || failed(result: printAlias(attr: loc)))
2099 printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true);
2100 os << ')';
2101}
2102
2103void AsmPrinter::Impl::printResourceHandle(
2104 const AsmDialectResourceHandle &resource) {
2105 auto *interface = cast<OpAsmDialectInterface>(Val: resource.getDialect());
2106 os << interface->getResourceKey(handle: resource);
2107 state.getDialectResources()[resource.getDialect()].insert(X: resource);
2108}
2109
2110/// Returns true if the given dialect symbol data is simple enough to print in
2111/// the pretty form. This is essentially when the symbol takes the form:
2112/// identifier (`<` body `>`)?
2113static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
2114 // The name must start with an identifier.
2115 if (symName.empty() || !isalpha(symName.front()))
2116 return false;
2117
2118 // Ignore all the characters that are valid in an identifier in the symbol
2119 // name.
2120 symName = symName.drop_while(
2121 F: [](char c) { return llvm::isAlnum(C: c) || c == '.' || c == '_'; });
2122 if (symName.empty())
2123 return true;
2124
2125 // If we got to an unexpected character, then it must be a <>. Check that the
2126 // rest of the symbol is wrapped within <>.
2127 return symName.front() == '<' && symName.back() == '>';
2128}
2129
2130/// Print the given dialect symbol to the stream.
2131static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
2132 StringRef dialectName, StringRef symString) {
2133 os << symPrefix << dialectName;
2134
2135 // If this symbol name is simple enough, print it directly in pretty form,
2136 // otherwise, we print it as an escaped string.
2137 if (isDialectSymbolSimpleEnoughForPrettyForm(symName: symString)) {
2138 os << '.' << symString;
2139 return;
2140 }
2141
2142 os << '<' << symString << '>';
2143}
2144
2145/// Returns true if the given string can be represented as a bare identifier.
2146static bool isBareIdentifier(StringRef name) {
2147 // By making this unsigned, the value passed in to isalnum will always be
2148 // in the range 0-255. This is important when building with MSVC because
2149 // its implementation will assert. This situation can arise when dealing
2150 // with UTF-8 multibyte characters.
2151 if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
2152 return false;
2153 return llvm::all_of(Range: name.drop_front(), P: [](unsigned char c) {
2154 return isalnum(c) || c == '_' || c == '$' || c == '.';
2155 });
2156}
2157
2158/// Print the given string as a keyword, or a quoted and escaped string if it
2159/// has any special or non-printable characters in it.
2160static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
2161 // If it can be represented as a bare identifier, write it directly.
2162 if (isBareIdentifier(name: keyword)) {
2163 os << keyword;
2164 return;
2165 }
2166
2167 // Otherwise, output the keyword wrapped in quotes with proper escaping.
2168 os << "\"";
2169 printEscapedString(Name: keyword, Out&: os);
2170 os << '"';
2171}
2172
2173/// Print the given string as a symbol reference. A symbol reference is
2174/// represented as a string prefixed with '@'. The reference is surrounded with
2175/// ""'s and escaped if it has any special or non-printable characters in it.
2176static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
2177 if (symbolRef.empty()) {
2178 os << "@<<INVALID EMPTY SYMBOL>>";
2179 return;
2180 }
2181 os << '@';
2182 printKeywordOrString(keyword: symbolRef, os);
2183}
2184
2185// Print out a valid ElementsAttr that is succinct and can represent any
2186// potential shape/type, for use when eliding a large ElementsAttr.
2187//
2188// We choose to use a dense resource ElementsAttr literal with conspicuous
2189// content to hopefully alert readers to the fact that this has been elided.
2190static void printElidedElementsAttr(raw_ostream &os) {
2191 os << R"(dense_resource<__elided__>)";
2192}
2193
2194LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
2195 return state.getAliasState().getAlias(attr, os);
2196}
2197
2198LogicalResult AsmPrinter::Impl::printAlias(Type type) {
2199 return state.getAliasState().getAlias(ty: type, os);
2200}
2201
2202void AsmPrinter::Impl::printAttribute(Attribute attr,
2203 AttrTypeElision typeElision) {
2204 if (!attr) {
2205 os << "<<NULL ATTRIBUTE>>";
2206 return;
2207 }
2208
2209 // Try to print an alias for this attribute.
2210 if (succeeded(result: printAlias(attr)))
2211 return;
2212 return printAttributeImpl(attr, typeElision);
2213}
2214
2215void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
2216 AttrTypeElision typeElision) {
2217 if (!isa<BuiltinDialect>(Val: attr.getDialect())) {
2218 printDialectAttribute(attr);
2219 } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
2220 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
2221 opaqueAttr.getAttrData());
2222 } else if (llvm::isa<UnitAttr>(Val: attr)) {
2223 os << "unit";
2224 return;
2225 } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
2226 os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
2227 if (!llvm::isa<UnitAttr>(Val: distinctAttr.getReferencedAttr())) {
2228 printAttribute(attr: distinctAttr.getReferencedAttr());
2229 }
2230 os << '>';
2231 return;
2232 } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
2233 os << '{';
2234 interleaveComma(dictAttr.getValue(),
2235 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2236 os << '}';
2237
2238 } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
2239 Type intType = intAttr.getType();
2240 if (intType.isSignlessInteger(width: 1)) {
2241 os << (intAttr.getValue().getBoolValue() ? "true" : "false");
2242
2243 // Boolean integer attributes always elides the type.
2244 return;
2245 }
2246
2247 // Only print attributes as unsigned if they are explicitly unsigned or are
2248 // signless 1-bit values. Indexes, signed values, and multi-bit signless
2249 // values print as signed.
2250 bool isUnsigned =
2251 intType.isUnsignedInteger() || intType.isSignlessInteger(width: 1);
2252 intAttr.getValue().print(os, !isUnsigned);
2253
2254 // IntegerAttr elides the type if I64.
2255 if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(width: 64))
2256 return;
2257
2258 } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2259 printFloatValue(floatAttr.getValue(), os);
2260
2261 // FloatAttr elides the type if F64.
2262 if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64())
2263 return;
2264
2265 } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
2266 printEscapedString(str: strAttr.getValue());
2267
2268 } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
2269 os << '[';
2270 interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
2271 printAttribute(attr, typeElision: AttrTypeElision::May);
2272 });
2273 os << ']';
2274
2275 } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
2276 os << "affine_map<";
2277 affineMapAttr.getValue().print(os);
2278 os << '>';
2279
2280 // AffineMap always elides the type.
2281 return;
2282
2283 } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
2284 os << "affine_set<";
2285 integerSetAttr.getValue().print(os);
2286 os << '>';
2287
2288 // IntegerSet always elides the type.
2289 return;
2290
2291 } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
2292 printType(type: typeAttr.getValue());
2293
2294 } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
2295 printSymbolReference(refAttr.getRootReference().getValue(), os);
2296 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
2297 os << "::";
2298 printSymbolReference(nestedRef.getValue(), os);
2299 }
2300
2301 } else if (auto intOrFpEltAttr =
2302 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
2303 if (printerFlags.shouldElideElementsAttr(attr: intOrFpEltAttr)) {
2304 printElidedElementsAttr(os);
2305 } else {
2306 os << "dense<";
2307 printDenseIntOrFPElementsAttr(attr: intOrFpEltAttr, /*allowHex=*/true);
2308 os << '>';
2309 }
2310
2311 } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
2312 if (printerFlags.shouldElideElementsAttr(attr: strEltAttr)) {
2313 printElidedElementsAttr(os);
2314 } else {
2315 os << "dense<";
2316 printDenseStringElementsAttr(attr: strEltAttr);
2317 os << '>';
2318 }
2319
2320 } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
2321 if (printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getIndices()) ||
2322 printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getValues())) {
2323 printElidedElementsAttr(os);
2324 } else {
2325 os << "sparse<";
2326 DenseIntElementsAttr indices = sparseEltAttr.getIndices();
2327 if (indices.getNumElements() != 0) {
2328 printDenseIntOrFPElementsAttr(attr: indices, /*allowHex=*/false);
2329 os << ", ";
2330 printDenseElementsAttr(attr: sparseEltAttr.getValues(), /*allowHex=*/true);
2331 }
2332 os << '>';
2333 }
2334 } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
2335 stridedLayoutAttr.print(os);
2336 } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
2337 os << "array<";
2338 printType(type: denseArrayAttr.getElementType());
2339 if (!denseArrayAttr.empty()) {
2340 os << ": ";
2341 printDenseArrayAttr(attr: denseArrayAttr);
2342 }
2343 os << ">";
2344 return;
2345 } else if (auto resourceAttr =
2346 llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
2347 os << "dense_resource<";
2348 printResourceHandle(resource: resourceAttr.getRawHandle());
2349 os << ">";
2350 } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(Val&: attr)) {
2351 printLocation(loc: locAttr);
2352 } else {
2353 llvm::report_fatal_error(reason: "Unknown builtin attribute");
2354 }
2355 // Don't print the type if we must elide it, or if it is a None type.
2356 if (typeElision != AttrTypeElision::Must) {
2357 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
2358 Type attrType = typedAttr.getType();
2359 if (!llvm::isa<NoneType>(Val: attrType)) {
2360 os << " : ";
2361 printType(type: attrType);
2362 }
2363 }
2364 }
2365}
2366
2367/// Print the integer element of a DenseElementsAttr.
2368static void printDenseIntElement(const APInt &value, raw_ostream &os,
2369 Type type) {
2370 if (type.isInteger(width: 1))
2371 os << (value.getBoolValue() ? "true" : "false");
2372 else
2373 value.print(OS&: os, isSigned: !type.isUnsignedInteger());
2374}
2375
/// Print the elements of a dense elements attribute, nested in `[...]`
/// brackets that follow `type`'s shape.
///
/// `printEltFn(i)` is expected to print the element at linear index `i`.
/// Elements are visited in increasing linear index, with the last dimension
/// varying fastest. A splat prints only element 0, with no brackets. An
/// attribute with zero elements prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  // NOTE(review): a non-splat rank-0 input would index counter[rank - 1] out of
  // bounds below. This relies on 0-d attributes always reporting isSplat.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment. Dimension 0 never rolls
    // over here; the outermost bracket is closed by the final loop below.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every closed bracket. The post-increment overshoots by one on
    // the failing comparison, hence the reset to `rank` right after.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close all brackets still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
2425
2426void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
2427 bool allowHex) {
2428 if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
2429 return printDenseStringElementsAttr(attr: stringAttr);
2430
2431 printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
2432 allowHex);
2433}
2434
/// Print the elements of an integer, float, or complex dense elements
/// attribute (without the surrounding `dense<...>`).
///
/// The raw storage is printed as a hex string when all of the following hold:
///   - `allowHex` is set;
///   - the attribute is not a splat;
///   - shouldPrintElementsAttrWithHex() approves the element count.
/// The hex string is always in little-endian byte order. Otherwise each
/// element is printed individually through printDenseElementsAttrImpl.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::endianness::native == llvm::endianness::big) {
      // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      printHexString(data: convRawData);
    } else {
      printHexString(data: rawData);
    }

    return;
  }

  // Complex elements print as `(real,imag)`.
  if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    if (llvm::isa<IntegerType>(Val: complexElementType)) {
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, complexElementType);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, complexElementType);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    // Integer/index elements: see printDenseIntElement for bool/sign rules.
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, elementType);
    });
  } else {
    // Anything else must be a float element type.
    assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}
2499
2500void AsmPrinter::Impl::printDenseStringElementsAttr(
2501 DenseStringElementsAttr attr) {
2502 ArrayRef<StringRef> data = attr.getRawStringData();
2503 auto printFn = [&](unsigned index) { printEscapedString(str: data[index]); };
2504 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2505}
2506
2507void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
2508 Type type = attr.getElementType();
2509 unsigned bitwidth = type.isInteger(width: 1) ? 8 : type.getIntOrFloatBitWidth();
2510 unsigned byteSize = bitwidth / 8;
2511 ArrayRef<char> data = attr.getRawData();
2512
2513 auto printElementAt = [&](unsigned i) {
2514 APInt value(bitwidth, 0);
2515 if (bitwidth) {
2516 llvm::LoadIntFromMemory(
2517 IntVal&: value, Src: reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
2518 LoadBytes: byteSize);
2519 }
2520 // Print the data as-is or as a float.
2521 if (type.isIntOrIndex()) {
2522 printDenseIntElement(value, os&: getStream(), type);
2523 } else {
2524 APFloat fltVal(llvm::cast<FloatType>(Val&: type).getFloatSemantics(), value);
2525 printFloatValue(apValue: fltVal, os&: getStream());
2526 }
2527 };
2528 llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
2529 printElementAt);
2530}
2531
2532void AsmPrinter::Impl::printType(Type type) {
2533 if (!type) {
2534 os << "<<NULL TYPE>>";
2535 return;
2536 }
2537
2538 // Try to print an alias for this type.
2539 if (succeeded(result: printAlias(type)))
2540 return;
2541 return printTypeImpl(type);
2542}
2543
/// Print the full (non-aliased) textual form of `type`. Builtin types are
/// spelled directly here. Opaque types are printed via printDialectSymbol with
/// the "!" prefix, and any other type is delegated to its dialect through
/// printDialectType.
void AsmPrinter::Impl::printTypeImpl(Type type) {
  TypeSwitch<Type>(type)
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      // Builtin index and floating point types print as fixed keywords.
      .Case<IndexType>([&](Type) { os << "index"; })
      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      .Case<IntegerType>([&](IntegerType integerTy) {
        // Prefix 's' for signed, 'u' for unsigned, and nothing otherwise
        // (i.e. signless), followed by 'i' and the bit width.
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        // A single non-function result is printed bare; zero, multiple, or a
        // function-typed result are wrapped in parentheses.
        ArrayRef<Type> results = funcTy.getResults();
        if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        // Each dimension is followed by 'x'; scalable dimensions are wrapped
        // in square brackets.
        auto scalableDims = vectorTy.getScalableDims();
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned dimIdx = 0;
        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << '[';
          os << vShape[dimIdx];
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << ']';
          os << 'x';
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        printDimensionList(tensorTy.getShape());
        // The 'x' separator is only needed when there is at least one dim.
        if (!tensorTy.getShape().empty())
          os << 'x';
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        printDimensionList(memrefTy.getShape());
        if (!memrefTy.getShape().empty())
          os << 'x';
        printType(memrefTy.getElementType());
        // The layout is omitted only when it is an identity affine map.
        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
        if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Anything not handled above is printed by its owning dialect.
      .Default([&](Type type) { return printDialectType(type); });
}
2660
2661void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2662 ArrayRef<StringRef> elidedAttrs,
2663 bool withKeyword) {
2664 // If there are no attributes, then there is nothing to be done.
2665 if (attrs.empty())
2666 return;
2667
2668 // Functor used to print a filtered attribute list.
2669 auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2670 // Print the 'attributes' keyword if necessary.
2671 if (withKeyword)
2672 os << " attributes";
2673
2674 // Otherwise, print them all out in braces.
2675 os << " {";
2676 interleaveComma(filteredAttrs,
2677 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2678 os << '}';
2679 };
2680
2681 // If no attributes are elided, we can directly print with no filtering.
2682 if (elidedAttrs.empty())
2683 return printFilteredAttributesFn(attrs);
2684
2685 // Otherwise, filter out any attributes that shouldn't be included.
2686 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2687 elidedAttrs.end());
2688 auto filteredAttrs = llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) {
2689 return !elidedAttrsSet.contains(attr.getName().strref());
2690 });
2691 if (!filteredAttrs.empty())
2692 printFilteredAttributesFn(filteredAttrs);
2693}
2694void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2695 // Print the name without quotes if possible.
2696 ::printKeywordOrString(keyword: attr.getName().strref(), os);
2697
2698 // Pretty printing elides the attribute value for unit attributes.
2699 if (llvm::isa<UnitAttr>(Val: attr.getValue()))
2700 return;
2701
2702 os << " = ";
2703 printAttribute(attr: attr.getValue());
2704}
2705
2706void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2707 auto &dialect = attr.getDialect();
2708
2709 // Ask the dialect to serialize the attribute to a string.
2710 std::string attrName;
2711 {
2712 llvm::raw_string_ostream attrNameStr(attrName);
2713 Impl subPrinter(attrNameStr, state);
2714 DialectAsmPrinter printer(subPrinter);
2715 dialect.printAttribute(attr, printer);
2716 }
2717 printDialectSymbol(os, symPrefix: "#", dialectName: dialect.getNamespace(), symString: attrName);
2718}
2719
2720void AsmPrinter::Impl::printDialectType(Type type) {
2721 auto &dialect = type.getDialect();
2722
2723 // Ask the dialect to serialize the type to a string.
2724 std::string typeName;
2725 {
2726 llvm::raw_string_ostream typeNameStr(typeName);
2727 Impl subPrinter(typeNameStr, state);
2728 DialectAsmPrinter printer(subPrinter);
2729 dialect.printType(type, printer);
2730 }
2731 printDialectSymbol(os, symPrefix: "!", dialectName: dialect.getNamespace(), symString: typeName);
2732}
2733
2734void AsmPrinter::Impl::printEscapedString(StringRef str) {
2735 os << "\"";
2736 llvm::printEscapedString(Name: str, Out&: os);
2737 os << "\"";
2738}
2739
2740void AsmPrinter::Impl::printHexString(StringRef str) {
2741 os << "\"0x" << llvm::toHex(Input: str) << "\"";
2742}
2743void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
2744 printHexString(str: StringRef(data.data(), data.size()));
2745}
2746
2747LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
2748 return state.pushCyclicPrinting(opaquePointer);
2749}
2750
2751void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
2752
2753void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
2754 detail::printDimensionList(stream&: os, shape);
2755}
2756
2757//===--------------------------------------------------------------------===//
2758// AsmPrinter
2759//===--------------------------------------------------------------------===//
2760
// Out-of-line defaulted destructor. NOTE(review): presumably defined here to
// anchor the class's vtable in this translation unit (common LLVM practice);
// confirm against the header declaration.
AsmPrinter::~AsmPrinter() = default;
2762
2763raw_ostream &AsmPrinter::getStream() const {
2764 assert(impl && "expected AsmPrinter::getStream to be overriden");
2765 return impl->getStream();
2766}
2767
2768/// Print the given floating point value in a stablized form.
2769void AsmPrinter::printFloat(const APFloat &value) {
2770 assert(impl && "expected AsmPrinter::printFloat to be overriden");
2771 printFloatValue(apValue: value, os&: impl->getStream());
2772}
2773
2774void AsmPrinter::printType(Type type) {
2775 assert(impl && "expected AsmPrinter::printType to be overriden");
2776 impl->printType(type);
2777}
2778
2779void AsmPrinter::printAttribute(Attribute attr) {
2780 assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2781 impl->printAttribute(attr);
2782}
2783
2784LogicalResult AsmPrinter::printAlias(Attribute attr) {
2785 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2786 return impl->printAlias(attr);
2787}
2788
2789LogicalResult AsmPrinter::printAlias(Type type) {
2790 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2791 return impl->printAlias(type);
2792}
2793
2794void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2795 assert(impl &&
2796 "expected AsmPrinter::printAttributeWithoutType to be overriden");
2797 impl->printAttribute(attr, typeElision: Impl::AttrTypeElision::Must);
2798}
2799
2800void AsmPrinter::printKeywordOrString(StringRef keyword) {
2801 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2802 ::printKeywordOrString(keyword, os&: impl->getStream());
2803}
2804
2805void AsmPrinter::printString(StringRef keyword) {
2806 assert(impl && "expected AsmPrinter::printString to be overriden");
2807 *this << '"';
2808 printEscapedString(Name: keyword, Out&: getStream());
2809 *this << '"';
2810}
2811
2812void AsmPrinter::printSymbolName(StringRef symbolRef) {
2813 assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2814 ::printSymbolReference(symbolRef, os&: impl->getStream());
2815}
2816
2817void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
2818 assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2819 impl->printResourceHandle(resource);
2820}
2821
2822void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
2823 detail::printDimensionList(stream&: getStream(), shape);
2824}
2825
2826LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
2827 return impl->pushCyclicPrinting(opaquePointer);
2828}
2829
2830void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }
2831
2832//===----------------------------------------------------------------------===//
2833// Affine expressions and maps
2834//===----------------------------------------------------------------------===//
2835
2836void AsmPrinter::Impl::printAffineExpr(
2837 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2838 printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak, printValueName);
2839}
2840
2841void AsmPrinter::Impl::printAffineExprInternal(
2842 AffineExpr expr, BindingStrength enclosingTightness,
2843 function_ref<void(unsigned, bool)> printValueName) {
2844 const char *binopSpelling = nullptr;
2845 switch (expr.getKind()) {
2846 case AffineExprKind::SymbolId: {
2847 unsigned pos = cast<AffineSymbolExpr>(Val&: expr).getPosition();
2848 if (printValueName)
2849 printValueName(pos, /*isSymbol=*/true);
2850 else
2851 os << 's' << pos;
2852 return;
2853 }
2854 case AffineExprKind::DimId: {
2855 unsigned pos = cast<AffineDimExpr>(Val&: expr).getPosition();
2856 if (printValueName)
2857 printValueName(pos, /*isSymbol=*/false);
2858 else
2859 os << 'd' << pos;
2860 return;
2861 }
2862 case AffineExprKind::Constant:
2863 os << cast<AffineConstantExpr>(Val&: expr).getValue();
2864 return;
2865 case AffineExprKind::Add:
2866 binopSpelling = " + ";
2867 break;
2868 case AffineExprKind::Mul:
2869 binopSpelling = " * ";
2870 break;
2871 case AffineExprKind::FloorDiv:
2872 binopSpelling = " floordiv ";
2873 break;
2874 case AffineExprKind::CeilDiv:
2875 binopSpelling = " ceildiv ";
2876 break;
2877 case AffineExprKind::Mod:
2878 binopSpelling = " mod ";
2879 break;
2880 }
2881
2882 auto binOp = cast<AffineBinaryOpExpr>(Val&: expr);
2883 AffineExpr lhsExpr = binOp.getLHS();
2884 AffineExpr rhsExpr = binOp.getRHS();
2885
2886 // Handle tightly binding binary operators.
2887 if (binOp.getKind() != AffineExprKind::Add) {
2888 if (enclosingTightness == BindingStrength::Strong)
2889 os << '(';
2890
2891 // Pretty print multiplication with -1.
2892 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr);
2893 if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2894 rhsConst.getValue() == -1) {
2895 os << "-";
2896 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName);
2897 if (enclosingTightness == BindingStrength::Strong)
2898 os << ')';
2899 return;
2900 }
2901
2902 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName);
2903
2904 os << binopSpelling;
2905 printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Strong, printValueName);
2906
2907 if (enclosingTightness == BindingStrength::Strong)
2908 os << ')';
2909 return;
2910 }
2911
2912 // Print out special "pretty" forms for add.
2913 if (enclosingTightness == BindingStrength::Strong)
2914 os << '(';
2915
2916 // Pretty print addition to a product that has a negative operand as a
2917 // subtraction.
2918 if (auto rhs = dyn_cast<AffineBinaryOpExpr>(Val&: rhsExpr)) {
2919 if (rhs.getKind() == AffineExprKind::Mul) {
2920 AffineExpr rrhsExpr = rhs.getRHS();
2921 if (auto rrhs = dyn_cast<AffineConstantExpr>(Val&: rrhsExpr)) {
2922 if (rrhs.getValue() == -1) {
2923 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak,
2924 printValueName);
2925 os << " - ";
2926 if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2927 printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong,
2928 printValueName);
2929 } else {
2930 printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Weak,
2931 printValueName);
2932 }
2933
2934 if (enclosingTightness == BindingStrength::Strong)
2935 os << ')';
2936 return;
2937 }
2938
2939 if (rrhs.getValue() < -1) {
2940 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak,
2941 printValueName);
2942 os << " - ";
2943 printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong,
2944 printValueName);
2945 os << " * " << -rrhs.getValue();
2946 if (enclosingTightness == BindingStrength::Strong)
2947 os << ')';
2948 return;
2949 }
2950 }
2951 }
2952 }
2953
2954 // Pretty print addition to a negative number as a subtraction.
2955 if (auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr)) {
2956 if (rhsConst.getValue() < 0) {
2957 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName);
2958 os << " - " << -rhsConst.getValue();
2959 if (enclosingTightness == BindingStrength::Strong)
2960 os << ')';
2961 return;
2962 }
2963 }
2964
2965 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName);
2966
2967 os << " + ";
2968 printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Weak, printValueName);
2969
2970 if (enclosingTightness == BindingStrength::Strong)
2971 os << ')';
2972}
2973
2974void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2975 printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak);
2976 isEq ? os << " == 0" : os << " >= 0";
2977}
2978
2979void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2980 // Dimension identifiers.
2981 os << '(';
2982 for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2983 os << 'd' << i << ", ";
2984 if (map.getNumDims() >= 1)
2985 os << 'd' << map.getNumDims() - 1;
2986 os << ')';
2987
2988 // Symbolic identifiers.
2989 if (map.getNumSymbols() != 0) {
2990 os << '[';
2991 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2992 os << 's' << i << ", ";
2993 if (map.getNumSymbols() >= 1)
2994 os << 's' << map.getNumSymbols() - 1;
2995 os << ']';
2996 }
2997
2998 // Result affine expressions.
2999 os << " -> (";
3000 interleaveComma(c: map.getResults(),
3001 eachFn: [&](AffineExpr expr) { printAffineExpr(expr); });
3002 os << ')';
3003}
3004
3005void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
3006 // Dimension identifiers.
3007 os << '(';
3008 for (unsigned i = 1; i < set.getNumDims(); ++i)
3009 os << 'd' << i - 1 << ", ";
3010 if (set.getNumDims() >= 1)
3011 os << 'd' << set.getNumDims() - 1;
3012 os << ')';
3013
3014 // Symbolic identifiers.
3015 if (set.getNumSymbols() != 0) {
3016 os << '[';
3017 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
3018 os << 's' << i << ", ";
3019 if (set.getNumSymbols() >= 1)
3020 os << 's' << set.getNumSymbols() - 1;
3021 os << ']';
3022 }
3023
3024 // Print constraints.
3025 os << " : (";
3026 int numConstraints = set.getNumConstraints();
3027 for (int i = 1; i < numConstraints; ++i) {
3028 printAffineConstraint(expr: set.getConstraint(idx: i - 1), isEq: set.isEq(idx: i - 1));
3029 os << ", ";
3030 }
3031 if (numConstraints >= 1)
3032 printAffineConstraint(expr: set.getConstraint(idx: numConstraints - 1),
3033 isEq: set.isEq(idx: numConstraints - 1));
3034 os << ')';
3035}
3036
3037//===----------------------------------------------------------------------===//
3038// OperationPrinter
3039//===----------------------------------------------------------------------===//
3040
3041namespace {
3042/// This class contains the logic for printing operations, regions, and blocks.
3043class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
3044public:
3045 using Impl = AsmPrinter::Impl;
3046 using Impl::printType;
3047
3048 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
3049 : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
3050
3051 /// Print the given top-level operation.
3052 void printTopLevelOperation(Operation *op);
3053
3054 /// Print the given operation, including its left-hand side and its right-hand
3055 /// side, with its indent and location.
3056 void printFullOpWithIndentAndLoc(Operation *op);
3057 /// Print the given operation, including its left-hand side and its right-hand
3058 /// side, but not including indentation and location.
3059 void printFullOp(Operation *op);
3060 /// Print the right-hand size of the given operation in the custom or generic
3061 /// form.
3062 void printCustomOrGenericOp(Operation *op) override;
3063 /// Print the right-hand side of the given operation in the generic form.
3064 void printGenericOp(Operation *op, bool printOpName) override;
3065
3066 /// Print the name of the given block.
3067 void printBlockName(Block *block);
3068
3069 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
3070 /// block are not printed. If 'printBlockTerminator' is false, the terminator
3071 /// operation of the block is not printed.
3072 void print(Block *block, bool printBlockArgs = true,
3073 bool printBlockTerminator = true);
3074
3075 /// Print the ID of the given value, optionally with its result number.
3076 void printValueID(Value value, bool printResultNo = true,
3077 raw_ostream *streamOverride = nullptr) const;
3078
3079 /// Print the ID of the given operation.
3080 void printOperationID(Operation *op,
3081 raw_ostream *streamOverride = nullptr) const;
3082
3083 //===--------------------------------------------------------------------===//
3084 // OpAsmPrinter methods
3085 //===--------------------------------------------------------------------===//
3086
3087 /// Print a loc(...) specifier if printing debug info is enabled. Locations
3088 /// may be deferred with an alias.
3089 void printOptionalLocationSpecifier(Location loc) override {
3090 printTrailingLocation(loc);
3091 }
3092
3093 /// Print a newline and indent the printer to the start of the current
3094 /// operation.
3095 void printNewline() override {
3096 os << newLine;
3097 os.indent(NumSpaces: currentIndent);
3098 }
3099
3100 /// Increase indentation.
3101 void increaseIndent() override { currentIndent += indentWidth; }
3102
3103 /// Decrease indentation.
3104 void decreaseIndent() override { currentIndent -= indentWidth; }
3105
3106 /// Print a block argument in the usual format of:
3107 /// %ssaName : type {attr1=42} loc("here")
3108 /// where location printing is controlled by the standard internal option.
3109 /// You may pass omitType=true to not print a type, and pass an empty
3110 /// attribute list if you don't care for attributes.
3111 void printRegionArgument(BlockArgument arg,
3112 ArrayRef<NamedAttribute> argAttrs = {},
3113 bool omitType = false) override;
3114
3115 /// Print the ID for the given value.
3116 void printOperand(Value value) override { printValueID(value); }
3117 void printOperand(Value value, raw_ostream &os) override {
3118 printValueID(value, /*printResultNo=*/true, streamOverride: &os);
3119 }
3120
3121 /// Print an optional attribute dictionary with a given set of elided values.
3122 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
3123 ArrayRef<StringRef> elidedAttrs = {}) override {
3124 Impl::printOptionalAttrDict(attrs, elidedAttrs);
3125 }
3126 void printOptionalAttrDictWithKeyword(
3127 ArrayRef<NamedAttribute> attrs,
3128 ArrayRef<StringRef> elidedAttrs = {}) override {
3129 Impl::printOptionalAttrDict(attrs, elidedAttrs,
3130 /*withKeyword=*/true);
3131 }
3132
3133 /// Print the given successor.
3134 void printSuccessor(Block *successor) override;
3135
3136 /// Print an operation successor with the operands used for the block
3137 /// arguments.
3138 void printSuccessorAndUseList(Block *successor,
3139 ValueRange succOperands) override;
3140
3141 /// Print the given region.
3142 void printRegion(Region &region, bool printEntryBlockArgs,
3143 bool printBlockTerminators, bool printEmptyBlock) override;
3144
3145 /// Renumber the arguments for the specified region to the same names as the
3146 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
3147 /// operations. If any entry in namesToUse is null, the corresponding
3148 /// argument name is left alone.
3149 void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
3150 state.getSSANameState().shadowRegionArgs(region, namesToUse);
3151 }
3152
3153 /// Print the given affine map with the symbol and dimension operands printed
3154 /// inline with the map.
3155 void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3156 ValueRange operands) override;
3157
3158 /// Print the given affine expression with the symbol and dimension operands
3159 /// printed inline with the expression.
3160 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
3161 ValueRange symOperands) override;
3162
3163 /// Print users of this operation or id of this operation if it has no result.
3164 void printUsersComment(Operation *op);
3165
3166 /// Print users of this block arg.
3167 void printUsersComment(BlockArgument arg);
3168
3169 /// Print the users of a value.
3170 void printValueUsers(Value value);
3171
3172 /// Print either the ids of the result values or the id of the operation if
3173 /// the operation has no results.
3174 void printUserIDs(Operation *user, bool prefixComma = false);
3175
3176private:
3177 /// This class represents a resource builder implementation for the MLIR
3178 /// textual assembly format.
3179 class ResourceBuilder : public AsmResourceBuilder {
3180 public:
3181 using ValueFn = function_ref<void(raw_ostream &)>;
3182 using PrintFn = function_ref<void(StringRef, ValueFn)>;
3183
3184 ResourceBuilder(PrintFn printFn) : printFn(printFn) {}
3185 ~ResourceBuilder() override = default;
3186
3187 void buildBool(StringRef key, bool data) final {
3188 printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); });
3189 }
3190
3191 void buildString(StringRef key, StringRef data) final {
3192 printFn(key, [&](raw_ostream &os) {
3193 os << "\"";
3194 llvm::printEscapedString(Name: data, Out&: os);
3195 os << "\"";
3196 });
3197 }
3198
3199 void buildBlob(StringRef key, ArrayRef<char> data,
3200 uint32_t dataAlignment) final {
3201 printFn(key, [&](raw_ostream &os) {
3202 // Store the blob in a hex string containing the alignment and the data.
3203 llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
3204 os << "\"0x"
3205 << llvm::toHex(Input: StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
3206 sizeof(dataAlignment)))
3207 << llvm::toHex(Input: StringRef(data.data(), data.size())) << "\"";
3208 });
3209 }
3210
3211 private:
3212 PrintFn printFn;
3213 };
3214
3215 /// Print the metadata dictionary for the file, eliding it if it is empty.
3216 void printFileMetadataDictionary(Operation *op);
3217
3218 /// Print the resource sections for the file metadata dictionary.
3219 /// `checkAddMetadataDict` is used to indicate that metadata is going to be
3220 /// added, and the file metadata dictionary should be started if it hasn't
3221 /// yet.
3222 void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
3223 Operation *op);
3224
3225 // Contains the stack of default dialects to use when printing regions.
3226 // A new dialect is pushed to the stack before parsing regions nested under an
3227 // operation implementing `OpAsmOpInterface`, and popped when done. At the
3228 // top-level we start with "builtin" as the default, so that the top-level
3229 // `module` operation prints as-is.
3230 SmallVector<StringRef> defaultDialectStack{"builtin"};
3231
3232 /// The number of spaces used for indenting nested operations.
3233 const static unsigned indentWidth = 2;
3234
3235 // This is the current indentation level for nested structures.
3236 unsigned currentIndent = 0;
3237};
3238} // namespace
3239
3240void OperationPrinter::printTopLevelOperation(Operation *op) {
3241 // Output the aliases at the top level that can't be deferred.
3242 state.getAliasState().printNonDeferredAliases(p&: *this, newLine);
3243
3244 // Print the module.
3245 printFullOpWithIndentAndLoc(op);
3246 os << newLine;
3247
3248 // Output the aliases at the top level that can be deferred.
3249 state.getAliasState().printDeferredAliases(p&: *this, newLine);
3250
3251 // Output any file level metadata.
3252 printFileMetadataDictionary(op);
3253}
3254
3255void OperationPrinter::printFileMetadataDictionary(Operation *op) {
3256 bool sawMetadataEntry = false;
3257 auto checkAddMetadataDict = [&] {
3258 if (!std::exchange(obj&: sawMetadataEntry, new_val: true))
3259 os << newLine << "{-#" << newLine;
3260 };
3261
3262 // Add the various types of metadata.
3263 printResourceFileMetadata(checkAddMetadataDict, op);
3264
3265 // If the file dictionary exists, close it.
3266 if (sawMetadataEntry)
3267 os << newLine << "#-}" << newLine;
3268}
3269
3270void OperationPrinter::printResourceFileMetadata(
3271 function_ref<void()> checkAddMetadataDict, Operation *op) {
3272 // Functor used to add data entries to the file metadata dictionary.
3273 bool hadResource = false;
3274 bool needResourceComma = false;
3275 bool needEntryComma = false;
3276 auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
3277 auto &&...providerArgs) {
3278 bool hadEntry = false;
3279 auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
3280 checkAddMetadataDict();
3281
3282 auto printFormatting = [&]() {
3283 // Emit the top-level resource entry if we haven't yet.
3284 if (!std::exchange(obj&: hadResource, new_val: true)) {
3285 if (needResourceComma)
3286 os << "," << newLine;
3287 os << " " << dictName << "_resources: {" << newLine;
3288 }
3289 // Emit the parent resource entry if we haven't yet.
3290 if (!std::exchange(obj&: hadEntry, new_val: true)) {
3291 if (needEntryComma)
3292 os << "," << newLine;
3293 os << " " << name << ": {" << newLine;
3294 } else {
3295 os << "," << newLine;
3296 }
3297 };
3298
3299 std::optional<uint64_t> charLimit =
3300 printerFlags.getLargeResourceStringLimit();
3301 if (charLimit.has_value()) {
3302 std::string resourceStr;
3303 llvm::raw_string_ostream ss(resourceStr);
3304 valueFn(ss);
3305
3306 // Only print entry if it's string is small enough
3307 if (resourceStr.size() > charLimit.value())
3308 return;
3309
3310 printFormatting();
3311 os << " " << key << ": " << resourceStr;
3312 } else {
3313 printFormatting();
3314 os << " " << key << ": ";
3315 valueFn(os);
3316 }
3317 };
3318 ResourceBuilder entryBuilder(printFn);
3319 provider.buildResources(op, providerArgs..., entryBuilder);
3320
3321 needEntryComma |= hadEntry;
3322 if (hadEntry)
3323 os << newLine << " }";
3324 };
3325
3326 // Print the `dialect_resources` section if we have any dialects with
3327 // resources.
3328 for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
3329 auto &dialectResources = state.getDialectResources();
3330 StringRef name = interface.getDialect()->getNamespace();
3331 auto it = dialectResources.find(Val: interface.getDialect());
3332 if (it != dialectResources.end())
3333 processProvider("dialect", name, interface, it->second);
3334 else
3335 processProvider("dialect", name, interface,
3336 SetVector<AsmDialectResourceHandle>());
3337 }
3338 if (hadResource)
3339 os << newLine << " }";
3340
3341 // Print the `external_resources` section if we have any external clients with
3342 // resources.
3343 needEntryComma = false;
3344 needResourceComma = hadResource;
3345 hadResource = false;
3346 for (const auto &printer : state.getResourcePrinters())
3347 processProvider("external", printer.getName(), printer);
3348 if (hadResource)
3349 os << newLine << " }";
3350}
3351
3352/// Print a block argument in the usual format of:
3353/// %ssaName : type {attr1=42} loc("here")
3354/// where location printing is controlled by the standard internal option.
3355/// You may pass omitType=true to not print a type, and pass an empty
3356/// attribute list if you don't care for attributes.
3357void OperationPrinter::printRegionArgument(BlockArgument arg,
3358 ArrayRef<NamedAttribute> argAttrs,
3359 bool omitType) {
3360 printOperand(value: arg);
3361 if (!omitType) {
3362 os << ": ";
3363 printType(type: arg.getType());
3364 }
3365 printOptionalAttrDict(attrs: argAttrs);
3366 // TODO: We should allow location aliases on block arguments.
3367 printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false);
3368}
3369
3370void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
3371 // Track the location of this operation.
3372 state.registerOperationLocation(op, line: newLine.curLine, col: currentIndent);
3373
3374 os.indent(NumSpaces: currentIndent);
3375 printFullOp(op);
3376 printTrailingLocation(loc: op->getLoc());
3377 if (printerFlags.shouldPrintValueUsers())
3378 printUsersComment(op);
3379}
3380
3381void OperationPrinter::printFullOp(Operation *op) {
3382 if (size_t numResults = op->getNumResults()) {
3383 auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
3384 printValueID(value: op->getResult(idx: resultNo), /*printResultNo=*/false);
3385 if (resultCount > 1)
3386 os << ':' << resultCount;
3387 };
3388
3389 // Check to see if this operation has multiple result groups.
3390 ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
3391 if (!resultGroups.empty()) {
3392 // Interleave the groups excluding the last one, this one will be handled
3393 // separately.
3394 interleaveComma(c: llvm::seq<int>(Begin: 0, End: resultGroups.size() - 1), eachFn: [&](int i) {
3395 printResultGroup(resultGroups[i],
3396 resultGroups[i + 1] - resultGroups[i]);
3397 });
3398 os << ", ";
3399 printResultGroup(resultGroups.back(), numResults - resultGroups.back());
3400
3401 } else {
3402 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
3403 }
3404
3405 os << " = ";
3406 }
3407
3408 printCustomOrGenericOp(op);
3409}
3410
3411void OperationPrinter::printUsersComment(Operation *op) {
3412 unsigned numResults = op->getNumResults();
3413 if (!numResults && op->getNumOperands()) {
3414 os << " // id: ";
3415 printOperationID(op);
3416 } else if (numResults && op->use_empty()) {
3417 os << " // unused";
3418 } else if (numResults && !op->use_empty()) {
3419 // Print "user" if the operation has one result used to compute one other
3420 // result, or is used in one operation with no result.
3421 unsigned usedInNResults = 0;
3422 unsigned usedInNOperations = 0;
3423 SmallPtrSet<Operation *, 1> userSet;
3424 for (Operation *user : op->getUsers()) {
3425 if (userSet.insert(Ptr: user).second) {
3426 ++usedInNOperations;
3427 usedInNResults += user->getNumResults();
3428 }
3429 }
3430
3431 // We already know that users is not empty.
3432 bool exactlyOneUniqueUse =
3433 usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
3434 os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
3435 bool shouldPrintBrackets = numResults > 1;
3436 auto printOpResult = [&](OpResult opResult) {
3437 if (shouldPrintBrackets)
3438 os << "(";
3439 printValueUsers(value: opResult);
3440 if (shouldPrintBrackets)
3441 os << ")";
3442 };
3443
3444 interleaveComma(c: op->getResults(), eachFn: printOpResult);
3445 }
3446}
3447
3448void OperationPrinter::printUsersComment(BlockArgument arg) {
3449 os << "// ";
3450 printValueID(value: arg);
3451 if (arg.use_empty()) {
3452 os << " is unused";
3453 } else {
3454 os << " is used by ";
3455 printValueUsers(value: arg);
3456 }
3457 os << newLine;
3458}
3459
3460void OperationPrinter::printValueUsers(Value value) {
3461 if (value.use_empty())
3462 os << "unused";
3463
3464 // One value might be used as the operand of an operation more than once.
3465 // Only print the operations results once in that case.
3466 SmallPtrSet<Operation *, 1> userSet;
3467 for (auto [index, user] : enumerate(First: value.getUsers())) {
3468 if (userSet.insert(Ptr: user).second)
3469 printUserIDs(user, prefixComma: index);
3470 }
3471}
3472
3473void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
3474 if (prefixComma)
3475 os << ", ";
3476
3477 if (!user->getNumResults()) {
3478 printOperationID(op: user);
3479 } else {
3480 interleaveComma(c: user->getResults(),
3481 eachFn: [this](Value result) { printValueID(value: result); });
3482 }
3483}
3484
3485void OperationPrinter::printCustomOrGenericOp(Operation *op) {
3486 // If requested, always print the generic form.
3487 if (!printerFlags.shouldPrintGenericOpForm()) {
3488 // Check to see if this is a known operation. If so, use the registered
3489 // custom printer hook.
3490 if (auto opInfo = op->getRegisteredInfo()) {
3491 opInfo->printAssembly(op, p&: *this, defaultDialect: defaultDialectStack.back());
3492 return;
3493 }
3494 // Otherwise try to dispatch to the dialect, if available.
3495 if (Dialect *dialect = op->getDialect()) {
3496 if (auto opPrinter = dialect->getOperationPrinter(op)) {
3497 // Print the op name first.
3498 StringRef name = op->getName().getStringRef();
3499 // Only drop the default dialect prefix when it cannot lead to
3500 // ambiguities.
3501 if (name.count(C: '.') == 1)
3502 name.consume_front(Prefix: (defaultDialectStack.back() + ".").str());
3503 os << name;
3504
3505 // Print the rest of the op now.
3506 opPrinter(op, *this);
3507 return;
3508 }
3509 }
3510 }
3511
3512 // Otherwise print with the generic assembly form.
3513 printGenericOp(op, /*printOpName=*/true);
3514}
3515
3516void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
3517 if (printOpName)
3518 printEscapedString(str: op->getName().getStringRef());
3519 os << '(';
3520 interleaveComma(c: op->getOperands(), eachFn: [&](Value value) { printValueID(value); });
3521 os << ')';
3522
3523 // For terminators, print the list of successors and their operands.
3524 if (op->getNumSuccessors() != 0) {
3525 os << '[';
3526 interleaveComma(c: op->getSuccessors(),
3527 eachFn: [&](Block *successor) { printBlockName(block: successor); });
3528 os << ']';
3529 }
3530
3531 // Print the properties.
3532 if (Attribute prop = op->getPropertiesAsAttribute()) {
3533 os << " <";
3534 Impl::printAttribute(attr: prop);
3535 os << '>';
3536 }
3537
3538 // Print regions.
3539 if (op->getNumRegions() != 0) {
3540 os << " (";
3541 interleaveComma(c: op->getRegions(), eachFn: [&](Region &region) {
3542 printRegion(region, /*printEntryBlockArgs=*/true,
3543 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
3544 });
3545 os << ')';
3546 }
3547
3548 printOptionalAttrDict(attrs: op->getPropertiesStorage()
3549 ? llvm::to_vector(op->getDiscardableAttrs())
3550 : op->getAttrs());
3551
3552 // Print the type signature of the operation.
3553 os << " : ";
3554 printFunctionalType(op);
3555}
3556
3557void OperationPrinter::printBlockName(Block *block) {
3558 os << state.getSSANameState().getBlockInfo(block).name;
3559}
3560
3561void OperationPrinter::print(Block *block, bool printBlockArgs,
3562 bool printBlockTerminator) {
3563 // Print the block label and argument list if requested.
3564 if (printBlockArgs) {
3565 os.indent(NumSpaces: currentIndent);
3566 printBlockName(block);
3567
3568 // Print the argument list if non-empty.
3569 if (!block->args_empty()) {
3570 os << '(';
3571 interleaveComma(c: block->getArguments(), eachFn: [&](BlockArgument arg) {
3572 printValueID(value: arg);
3573 os << ": ";
3574 printType(type: arg.getType());
3575 // TODO: We should allow location aliases on block arguments.
3576 printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false);
3577 });
3578 os << ')';
3579 }
3580 os << ':';
3581
3582 // Print out some context information about the predecessors of this block.
3583 if (!block->getParent()) {
3584 os << " // block is not in a region!";
3585 } else if (block->hasNoPredecessors()) {
3586 if (!block->isEntryBlock())
3587 os << " // no predecessors";
3588 } else if (auto *pred = block->getSinglePredecessor()) {
3589 os << " // pred: ";
3590 printBlockName(block: pred);
3591 } else {
3592 // We want to print the predecessors in a stable order, not in
3593 // whatever order the use-list is in, so gather and sort them.
3594 SmallVector<BlockInfo, 4> predIDs;
3595 for (auto *pred : block->getPredecessors())
3596 predIDs.push_back(Elt: state.getSSANameState().getBlockInfo(block: pred));
3597 llvm::sort(C&: predIDs, Comp: [](BlockInfo lhs, BlockInfo rhs) {
3598 return lhs.ordering < rhs.ordering;
3599 });
3600
3601 os << " // " << predIDs.size() << " preds: ";
3602
3603 interleaveComma(c: predIDs, eachFn: [&](BlockInfo pred) { os << pred.name; });
3604 }
3605 os << newLine;
3606 }
3607
3608 currentIndent += indentWidth;
3609
3610 if (printerFlags.shouldPrintValueUsers()) {
3611 for (BlockArgument arg : block->getArguments()) {
3612 os.indent(NumSpaces: currentIndent);
3613 printUsersComment(arg);
3614 }
3615 }
3616
3617 bool hasTerminator =
3618 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
3619 auto range = llvm::make_range(
3620 x: block->begin(),
3621 y: std::prev(x: block->end(),
3622 n: (!hasTerminator || printBlockTerminator) ? 0 : 1));
3623 for (auto &op : range) {
3624 printFullOpWithIndentAndLoc(op: &op);
3625 os << newLine;
3626 }
3627 currentIndent -= indentWidth;
3628}
3629
3630void OperationPrinter::printValueID(Value value, bool printResultNo,
3631 raw_ostream *streamOverride) const {
3632 state.getSSANameState().printValueID(value, printResultNo,
3633 stream&: streamOverride ? *streamOverride : os);
3634}
3635
3636void OperationPrinter::printOperationID(Operation *op,
3637 raw_ostream *streamOverride) const {
3638 state.getSSANameState().printOperationID(op, stream&: streamOverride ? *streamOverride
3639 : os);
3640}
3641
3642void OperationPrinter::printSuccessor(Block *successor) {
3643 printBlockName(block: successor);
3644}
3645
3646void OperationPrinter::printSuccessorAndUseList(Block *successor,
3647 ValueRange succOperands) {
3648 printBlockName(block: successor);
3649 if (succOperands.empty())
3650 return;
3651
3652 os << '(';
3653 interleaveComma(c: succOperands,
3654 eachFn: [this](Value operand) { printValueID(value: operand); });
3655 os << " : ";
3656 interleaveComma(c: succOperands,
3657 eachFn: [this](Value operand) { printType(type: operand.getType()); });
3658 os << ')';
3659}
3660
3661void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
3662 bool printBlockTerminators,
3663 bool printEmptyBlock) {
3664 if (printerFlags.shouldSkipRegions()) {
3665 os << "{...}";
3666 return;
3667 }
3668 os << "{" << newLine;
3669 if (!region.empty()) {
3670 auto restoreDefaultDialect =
3671 llvm::make_scope_exit(F: [&]() { defaultDialectStack.pop_back(); });
3672 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3673 defaultDialectStack.push_back(Elt: iface.getDefaultDialect());
3674 else
3675 defaultDialectStack.push_back(Elt: "");
3676
3677 auto *entryBlock = &region.front();
3678 // Force printing the block header if printEmptyBlock is set and the block
3679 // is empty or if printEntryBlockArgs is set and there are arguments to
3680 // print.
3681 bool shouldAlwaysPrintBlockHeader =
3682 (printEmptyBlock && entryBlock->empty()) ||
3683 (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3684 print(block: entryBlock, printBlockArgs: shouldAlwaysPrintBlockHeader, printBlockTerminator: printBlockTerminators);
3685 for (auto &b : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1))
3686 print(block: &b);
3687 }
3688 os.indent(NumSpaces: currentIndent) << "}";
3689}
3690
3691void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3692 ValueRange operands) {
3693 if (!mapAttr) {
3694 os << "<<NULL AFFINE MAP>>";
3695 return;
3696 }
3697 AffineMap map = mapAttr.getValue();
3698 unsigned numDims = map.getNumDims();
3699 auto printValueName = [&](unsigned pos, bool isSymbol) {
3700 unsigned index = isSymbol ? numDims + pos : pos;
3701 assert(index < operands.size());
3702 if (isSymbol)
3703 os << "symbol(";
3704 printValueID(value: operands[index]);
3705 if (isSymbol)
3706 os << ')';
3707 };
3708
3709 interleaveComma(c: map.getResults(), eachFn: [&](AffineExpr expr) {
3710 printAffineExpr(expr, printValueName);
3711 });
3712}
3713
3714void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3715 ValueRange dimOperands,
3716 ValueRange symOperands) {
3717 auto printValueName = [&](unsigned pos, bool isSymbol) {
3718 if (!isSymbol)
3719 return printValueID(value: dimOperands[pos]);
3720 os << "symbol(";
3721 printValueID(value: symOperands[pos]);
3722 os << ')';
3723 };
3724 printAffineExpr(expr, printValueName);
3725}
3726
3727//===----------------------------------------------------------------------===//
3728// print and dump methods
3729//===----------------------------------------------------------------------===//
3730
3731void Attribute::print(raw_ostream &os, bool elideType) const {
3732 if (!*this) {
3733 os << "<<NULL ATTRIBUTE>>";
3734 return;
3735 }
3736
3737 AsmState state(getContext());
3738 print(os, state, elideType);
3739}
3740void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const {
3741 using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision;
3742 AsmPrinter::Impl(os, state.getImpl())
3743 .printAttribute(attr: *this, typeElision: elideType ? AttrTypeElision::Must
3744 : AttrTypeElision::Never);
3745}
3746
3747void Attribute::dump() const {
3748 print(os&: llvm::errs());
3749 llvm::errs() << "\n";
3750}
3751
3752void Attribute::printStripped(raw_ostream &os, AsmState &state) const {
3753 if (!*this) {
3754 os << "<<NULL ATTRIBUTE>>";
3755 return;
3756 }
3757
3758 AsmPrinter::Impl subPrinter(os, state.getImpl());
3759 if (succeeded(result: subPrinter.printAlias(attr: *this)))
3760 return;
3761
3762 auto &dialect = this->getDialect();
3763 uint64_t posPrior = os.tell();
3764 DialectAsmPrinter printer(subPrinter);
3765 dialect.printAttribute(*this, printer);
3766 if (posPrior != os.tell())
3767 return;
3768
3769 // Fallback to printing with prefix if the above failed to write anything
3770 // to the output stream.
3771 print(os, state);
3772}
3773void Attribute::printStripped(raw_ostream &os) const {
3774 if (!*this) {
3775 os << "<<NULL ATTRIBUTE>>";
3776 return;
3777 }
3778
3779 AsmState state(getContext());
3780 printStripped(os, state);
3781}
3782
3783void Type::print(raw_ostream &os) const {
3784 if (!*this) {
3785 os << "<<NULL TYPE>>";
3786 return;
3787 }
3788
3789 AsmState state(getContext());
3790 print(os, state);
3791}
3792void Type::print(raw_ostream &os, AsmState &state) const {
3793 AsmPrinter::Impl(os, state.getImpl()).printType(type: *this);
3794}
3795
3796void Type::dump() const {
3797 print(os&: llvm::errs());
3798 llvm::errs() << "\n";
3799}
3800
3801void AffineMap::dump() const {
3802 print(os&: llvm::errs());
3803 llvm::errs() << "\n";
3804}
3805
3806void IntegerSet::dump() const {
3807 print(os&: llvm::errs());
3808 llvm::errs() << "\n";
3809}
3810
3811void AffineExpr::print(raw_ostream &os) const {
3812 if (!expr) {
3813 os << "<<NULL AFFINE EXPR>>";
3814 return;
3815 }
3816 AsmState state(getContext());
3817 AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(expr: *this);
3818}
3819
3820void AffineExpr::dump() const {
3821 print(os&: llvm::errs());
3822 llvm::errs() << "\n";
3823}
3824
3825void AffineMap::print(raw_ostream &os) const {
3826 if (!map) {
3827 os << "<<NULL AFFINE MAP>>";
3828 return;
3829 }
3830 AsmState state(getContext());
3831 AsmPrinter::Impl(os, state.getImpl()).printAffineMap(map: *this);
3832}
3833
3834void IntegerSet::print(raw_ostream &os) const {
3835 AsmState state(getContext());
3836 AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(set: *this);
3837}
3838
3839void Value::print(raw_ostream &os) const { print(os, flags: OpPrintingFlags()); }
3840void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const {
3841 if (!impl) {
3842 os << "<<NULL VALUE>>";
3843 return;
3844 }
3845
3846 if (auto *op = getDefiningOp())
3847 return op->print(os, flags);
3848 // TODO: Improve BlockArgument print'ing.
3849 BlockArgument arg = llvm::cast<BlockArgument>(Val: *this);
3850 os << "<block argument> of type '" << arg.getType()
3851 << "' at index: " << arg.getArgNumber();
3852}
3853void Value::print(raw_ostream &os, AsmState &state) const {
3854 if (!impl) {
3855 os << "<<NULL VALUE>>";
3856 return;
3857 }
3858
3859 if (auto *op = getDefiningOp())
3860 return op->print(os, state);
3861
3862 // TODO: Improve BlockArgument print'ing.
3863 BlockArgument arg = llvm::cast<BlockArgument>(Val: *this);
3864 os << "<block argument> of type '" << arg.getType()
3865 << "' at index: " << arg.getArgNumber();
3866}
3867
3868void Value::dump() const {
3869 print(os&: llvm::errs());
3870 llvm::errs() << "\n";
3871}
3872
3873void Value::printAsOperand(raw_ostream &os, AsmState &state) const {
3874 // TODO: This doesn't necessarily capture all potential cases.
3875 // Currently, region arguments can be shadowed when printing the main
3876 // operation. If the IR hasn't been printed, this will produce the old SSA
3877 // name and not the shadowed name.
3878 state.getImpl().getSSANameState().printValueID(value: *this, /*printResultNo=*/true,
3879 stream&: os);
3880}
3881
3882static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
3883 do {
3884 // If we are printing local scope, stop at the first operation that is
3885 // isolated from above.
3886 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3887 break;
3888
3889 // Otherwise, traverse up to the next parent.
3890 Operation *parentOp = op->getParentOp();
3891 if (!parentOp)
3892 break;
3893 op = parentOp;
3894 } while (true);
3895 return op;
3896}
3897
3898void Value::printAsOperand(raw_ostream &os,
3899 const OpPrintingFlags &flags) const {
3900 Operation *op;
3901 if (auto result = llvm::dyn_cast<OpResult>(Val: *this)) {
3902 op = result.getOwner();
3903 } else {
3904 op = llvm::cast<BlockArgument>(Val: *this).getOwner()->getParentOp();
3905 if (!op) {
3906 os << "<<UNKNOWN SSA VALUE>>";
3907 return;
3908 }
3909 }
3910 op = findParent(op, shouldUseLocalScope: flags.shouldUseLocalScope());
3911 AsmState state(op, flags);
3912 printAsOperand(os, state);
3913}
3914
3915void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
3916 // Find the operation to number from based upon the provided flags.
3917 Operation *op = findParent(op: this, shouldUseLocalScope: printerFlags.shouldUseLocalScope());
3918 AsmState state(op, printerFlags);
3919 print(os, state);
3920}
3921void Operation::print(raw_ostream &os, AsmState &state) {
3922 OperationPrinter printer(os, state.getImpl());
3923 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
3924 state.getImpl().initializeAliases(op: this);
3925 printer.printTopLevelOperation(op: this);
3926 } else {
3927 printer.printFullOpWithIndentAndLoc(op: this);
3928 }
3929}
3930
3931void Operation::dump() {
3932 print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope());
3933 llvm::errs() << "\n";
3934}
3935
3936void Block::print(raw_ostream &os) {
3937 Operation *parentOp = getParentOp();
3938 if (!parentOp) {
3939 os << "<<UNLINKED BLOCK>>\n";
3940 return;
3941 }
3942 // Get the top-level op.
3943 while (auto *nextOp = parentOp->getParentOp())
3944 parentOp = nextOp;
3945
3946 AsmState state(parentOp);
3947 print(os, state);
3948}
3949void Block::print(raw_ostream &os, AsmState &state) {
3950 OperationPrinter(os, state.getImpl()).print(block: this);
3951}
3952
3953void Block::dump() { print(os&: llvm::errs()); }
3954
3955/// Print out the name of the block without printing its body.
3956void Block::printAsOperand(raw_ostream &os, bool printType) {
3957 Operation *parentOp = getParentOp();
3958 if (!parentOp) {
3959 os << "<<UNLINKED BLOCK>>\n";
3960 return;
3961 }
3962 AsmState state(parentOp);
3963 printAsOperand(os, state);
3964}
3965void Block::printAsOperand(raw_ostream &os, AsmState &state) {
3966 OperationPrinter printer(os, state.getImpl());
3967 printer.printBlockName(block: this);
3968}
3969
3970//===--------------------------------------------------------------------===//
3971// Custom printers
3972//===--------------------------------------------------------------------===//
3973namespace mlir {
3974
3975void printDimensionList(OpAsmPrinter &printer, Operation *op,
3976 ArrayRef<int64_t> dimensions) {
3977 if (dimensions.empty())
3978 printer << "[";
3979 printer.printDimensionList(shape: dimensions);
3980 if (dimensions.empty())
3981 printer << "]";
3982}
3983
3984ParseResult parseDimensionList(OpAsmParser &parser,
3985 DenseI64ArrayAttr &dimensions) {
3986 // Empty list case denoted by "[]".
3987 if (succeeded(result: parser.parseOptionalLSquare())) {
3988 if (failed(result: parser.parseRSquare())) {
3989 return parser.emitError(loc: parser.getCurrentLocation())
3990 << "Failed parsing dimension list.";
3991 }
3992 dimensions =
3993 DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>());
3994 return success();
3995 }
3996
3997 // Non-empty list case.
3998 SmallVector<int64_t> shapeArr;
3999 if (failed(result: parser.parseDimensionList(dimensions&: shapeArr, allowDynamic: true, withTrailingX: false))) {
4000 return parser.emitError(loc: parser.getCurrentLocation())
4001 << "Failed parsing dimension list.";
4002 }
4003 if (shapeArr.empty()) {
4004 return parser.emitError(loc: parser.getCurrentLocation())
4005 << "Failed parsing dimension list. Did you mean an empty list? It "
4006 "must be denoted by \"[]\".";
4007 }
4008 dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
4009 return success();
4010}
4011
4012} // namespace mlir
4013

source code of mlir/lib/IR/AsmPrinter.cpp