| 1 | //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the MLIR AsmPrinter class, which is used to implement |
| 10 | // the various print() methods on the core IR objects. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/IR/AffineExpr.h" |
| 15 | #include "mlir/IR/AffineMap.h" |
| 16 | #include "mlir/IR/AsmState.h" |
| 17 | #include "mlir/IR/Attributes.h" |
| 18 | #include "mlir/IR/Builders.h" |
| 19 | #include "mlir/IR/BuiltinAttributes.h" |
| 20 | #include "mlir/IR/BuiltinDialect.h" |
| 21 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 22 | #include "mlir/IR/BuiltinTypes.h" |
| 23 | #include "mlir/IR/Dialect.h" |
| 24 | #include "mlir/IR/DialectImplementation.h" |
| 25 | #include "mlir/IR/DialectResourceBlobManager.h" |
| 26 | #include "mlir/IR/IntegerSet.h" |
| 27 | #include "mlir/IR/MLIRContext.h" |
| 28 | #include "mlir/IR/OpImplementation.h" |
| 29 | #include "mlir/IR/Operation.h" |
| 30 | #include "mlir/IR/Verifier.h" |
| 31 | #include "llvm/ADT/APFloat.h" |
| 32 | #include "llvm/ADT/ArrayRef.h" |
| 33 | #include "llvm/ADT/DenseMap.h" |
| 34 | #include "llvm/ADT/MapVector.h" |
| 35 | #include "llvm/ADT/STLExtras.h" |
| 36 | #include "llvm/ADT/ScopeExit.h" |
| 37 | #include "llvm/ADT/ScopedHashTable.h" |
| 38 | #include "llvm/ADT/SetVector.h" |
| 39 | #include "llvm/ADT/StringExtras.h" |
| 40 | #include "llvm/ADT/StringSet.h" |
| 41 | #include "llvm/ADT/TypeSwitch.h" |
| 42 | #include "llvm/Support/CommandLine.h" |
| 43 | #include "llvm/Support/Debug.h" |
| 44 | #include "llvm/Support/Endian.h" |
| 45 | #include "llvm/Support/ManagedStatic.h" |
| 46 | #include "llvm/Support/Regex.h" |
| 47 | #include "llvm/Support/SaveAndRestore.h" |
| 48 | #include "llvm/Support/Threading.h" |
| 49 | #include "llvm/Support/raw_ostream.h" |
| 50 | #include <type_traits> |
| 51 | |
| 52 | #include <optional> |
| 53 | #include <tuple> |
| 54 | |
| 55 | using namespace mlir; |
| 56 | using namespace mlir::detail; |
| 57 | |
| 58 | #define DEBUG_TYPE "mlir-asm-printer" |
| 59 | |
| 60 | void OperationName::print(raw_ostream &os) const { os << getStringRef(); } |
| 61 | |
| 62 | void OperationName::dump() const { print(os&: llvm::errs()); } |
| 63 | |
| 64 | //===--------------------------------------------------------------------===// |
| 65 | // AsmParser |
| 66 | //===--------------------------------------------------------------------===// |
| 67 | |
| 68 | AsmParser::~AsmParser() = default; |
| 69 | DialectAsmParser::~DialectAsmParser() = default; |
| 70 | OpAsmParser::~OpAsmParser() = default; |
| 71 | |
| 72 | MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } |
| 73 | |
| 74 | /// Parse a type list. |
| 75 | /// This is out-of-line to work-around |
| 76 | /// https://github.com/llvm/llvm-project/issues/62918 |
| 77 | ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) { |
| 78 | return parseCommaSeparatedList( |
| 79 | parseElementFn: [&]() { return parseType(result&: result.emplace_back()); }); |
| 80 | } |
| 81 | |
| 82 | //===----------------------------------------------------------------------===// |
| 83 | // DialectAsmPrinter |
| 84 | //===----------------------------------------------------------------------===// |
| 85 | |
| 86 | DialectAsmPrinter::~DialectAsmPrinter() = default; |
| 87 | |
| 88 | //===----------------------------------------------------------------------===// |
| 89 | // OpAsmPrinter |
| 90 | //===----------------------------------------------------------------------===// |
| 91 | |
| 92 | OpAsmPrinter::~OpAsmPrinter() = default; |
| 93 | |
| 94 | void OpAsmPrinter::printFunctionalType(Operation *op) { |
| 95 | auto &os = getStream(); |
| 96 | os << '('; |
| 97 | llvm::interleaveComma(c: op->getOperands(), os, each_fn: [&](Value operand) { |
| 98 | // Print the types of null values as <<NULL TYPE>>. |
| 99 | *this << (operand ? operand.getType() : Type()); |
| 100 | }); |
| 101 | os << ") -> " ; |
| 102 | |
| 103 | // Print the result list. We don't parenthesize single result types unless |
| 104 | // it is a function (avoiding a grammar ambiguity). |
| 105 | bool wrapped = op->getNumResults() != 1; |
| 106 | if (!wrapped && op->getResult(idx: 0).getType() && |
| 107 | llvm::isa<FunctionType>(Val: op->getResult(idx: 0).getType())) |
| 108 | wrapped = true; |
| 109 | |
| 110 | if (wrapped) |
| 111 | os << '('; |
| 112 | |
| 113 | llvm::interleaveComma(c: op->getResults(), os, each_fn: [&](const OpResult &result) { |
| 114 | // Print the types of null values as <<NULL TYPE>>. |
| 115 | *this << (result ? result.getType() : Type()); |
| 116 | }); |
| 117 | |
| 118 | if (wrapped) |
| 119 | os << ')'; |
| 120 | } |
| 121 | |
| 122 | //===----------------------------------------------------------------------===// |
| 123 | // Operation OpAsm interface. |
| 124 | //===----------------------------------------------------------------------===// |
| 125 | |
/// Generated definitions of OpAsmAttrInterface, OpAsmOpInterface, and
/// OpAsmTypeInterface; see OpAsmInterface.td for more details.
| 127 | #include "mlir/IR/OpAsmAttrInterface.cpp.inc" |
| 128 | #include "mlir/IR/OpAsmOpInterface.cpp.inc" |
| 129 | #include "mlir/IR/OpAsmTypeInterface.cpp.inc" |
| 130 | |
| 131 | LogicalResult |
| 132 | OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { |
| 133 | return entry.emitError() << "unknown 'resource' key '" << entry.getKey() |
| 134 | << "' for dialect '" << getDialect()->getNamespace() |
| 135 | << "'" ; |
| 136 | } |
| 137 | |
| 138 | //===----------------------------------------------------------------------===// |
| 139 | // OpPrintingFlags |
| 140 | //===----------------------------------------------------------------------===// |
| 141 | |
| 142 | namespace { |
| 143 | /// This struct contains command line options that can be used to initialize |
| 144 | /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need |
| 145 | /// for global command line options. |
| 146 | struct AsmPrinterOptions { |
| 147 | llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ |
| 148 | "mlir-print-elementsattrs-with-hex-if-larger" , |
| 149 | llvm::cl::desc( |
| 150 | "Print DenseElementsAttrs with a hex string that have " |
| 151 | "more elements than the given upper limit (use -1 to disable)" )}; |
| 152 | |
| 153 | llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ |
| 154 | "mlir-elide-elementsattrs-if-larger" , |
| 155 | llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " |
| 156 | "more elements than the given upper limit" )}; |
| 157 | |
| 158 | llvm::cl::opt<unsigned> elideResourceStringsIfLarger{ |
| 159 | "mlir-elide-resource-strings-if-larger" , |
| 160 | llvm::cl::desc( |
| 161 | "Elide printing value of resources if string is too long in chars." )}; |
| 162 | |
| 163 | llvm::cl::opt<bool> printDebugInfoOpt{ |
| 164 | "mlir-print-debuginfo" , llvm::cl::init(Val: false), |
| 165 | llvm::cl::desc("Print debug info in MLIR output" )}; |
| 166 | |
| 167 | llvm::cl::opt<bool> printPrettyDebugInfoOpt{ |
| 168 | "mlir-pretty-debuginfo" , llvm::cl::init(Val: false), |
| 169 | llvm::cl::desc("Print pretty debug info in MLIR output" )}; |
| 170 | |
| 171 | // Use the generic op output form in the operation printer even if the custom |
| 172 | // form is defined. |
| 173 | llvm::cl::opt<bool> printGenericOpFormOpt{ |
| 174 | "mlir-print-op-generic" , llvm::cl::init(Val: false), |
| 175 | llvm::cl::desc("Print the generic op form" ), llvm::cl::Hidden}; |
| 176 | |
| 177 | llvm::cl::opt<bool> assumeVerifiedOpt{ |
| 178 | "mlir-print-assume-verified" , llvm::cl::init(Val: false), |
| 179 | llvm::cl::desc("Skip op verification when using custom printers" ), |
| 180 | llvm::cl::Hidden}; |
| 181 | |
| 182 | llvm::cl::opt<bool> printLocalScopeOpt{ |
| 183 | "mlir-print-local-scope" , llvm::cl::init(Val: false), |
| 184 | llvm::cl::desc("Print with local scope and inline information (eliding " |
| 185 | "aliases for attributes, types, and locations)" )}; |
| 186 | |
| 187 | llvm::cl::opt<bool> skipRegionsOpt{ |
| 188 | "mlir-print-skip-regions" , llvm::cl::init(Val: false), |
| 189 | llvm::cl::desc("Skip regions when printing ops." )}; |
| 190 | |
| 191 | llvm::cl::opt<bool> printValueUsers{ |
| 192 | "mlir-print-value-users" , llvm::cl::init(Val: false), |
| 193 | llvm::cl::desc( |
| 194 | "Print users of operation results and block arguments as a comment" )}; |
| 195 | |
| 196 | llvm::cl::opt<bool> printUniqueSSAIDs{ |
| 197 | "mlir-print-unique-ssa-ids" , llvm::cl::init(Val: false), |
| 198 | llvm::cl::desc("Print unique SSA ID numbers for values, block arguments " |
| 199 | "and naming conflicts across all regions" )}; |
| 200 | |
| 201 | llvm::cl::opt<bool> useNameLocAsPrefix{ |
| 202 | "mlir-use-nameloc-as-prefix" , llvm::cl::init(Val: false), |
| 203 | llvm::cl::desc("Print SSA IDs using NameLocs as prefixes" )}; |
| 204 | }; |
| 205 | } // namespace |
| 206 | |
| 207 | static llvm::ManagedStatic<AsmPrinterOptions> clOptions; |
| 208 | |
| 209 | /// Register a set of useful command-line options that can be used to configure |
| 210 | /// various flags within the AsmPrinter. |
| 211 | void mlir::registerAsmPrinterCLOptions() { |
| 212 | // Make sure that the options struct has been initialized. |
| 213 | *clOptions; |
| 214 | } |
| 215 | |
| 216 | /// Initialize the printing flags with default supplied by the cl::opts above. |
| 217 | OpPrintingFlags::OpPrintingFlags() |
| 218 | : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), |
| 219 | printGenericOpFormFlag(false), skipRegionsFlag(false), |
| 220 | assumeVerifiedFlag(false), printLocalScope(false), |
| 221 | printValueUsersFlag(false), printUniqueSSAIDsFlag(false), |
| 222 | useNameLocAsPrefix(false) { |
| 223 | // Initialize based upon command line options, if they are available. |
| 224 | if (!clOptions.isConstructed()) |
| 225 | return; |
| 226 | if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) |
| 227 | elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; |
| 228 | if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) |
| 229 | elementsAttrHexElementLimit = |
| 230 | clOptions->printElementsAttrWithHexIfLarger.getValue(); |
| 231 | if (clOptions->elideResourceStringsIfLarger.getNumOccurrences()) |
| 232 | resourceStringCharLimit = clOptions->elideResourceStringsIfLarger; |
| 233 | printDebugInfoFlag = clOptions->printDebugInfoOpt; |
| 234 | printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; |
| 235 | printGenericOpFormFlag = clOptions->printGenericOpFormOpt; |
| 236 | assumeVerifiedFlag = clOptions->assumeVerifiedOpt; |
| 237 | printLocalScope = clOptions->printLocalScopeOpt; |
| 238 | skipRegionsFlag = clOptions->skipRegionsOpt; |
| 239 | printValueUsersFlag = clOptions->printValueUsers; |
| 240 | printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; |
| 241 | useNameLocAsPrefix = clOptions->useNameLocAsPrefix; |
| 242 | } |
| 243 | |
| 244 | /// Enable the elision of large elements attributes, by printing a '...' |
| 245 | /// instead of the element data, when the number of elements is greater than |
| 246 | /// `largeElementLimit`. Note: The IR generated with this option is not |
| 247 | /// parsable. |
| 248 | OpPrintingFlags & |
| 249 | OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { |
| 250 | elementsAttrElementLimit = largeElementLimit; |
| 251 | return *this; |
| 252 | } |
| 253 | |
| 254 | OpPrintingFlags & |
| 255 | OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) { |
| 256 | elementsAttrHexElementLimit = largeElementLimit; |
| 257 | return *this; |
| 258 | } |
| 259 | |
| 260 | OpPrintingFlags & |
| 261 | OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) { |
| 262 | resourceStringCharLimit = largeResourceLimit; |
| 263 | return *this; |
| 264 | } |
| 265 | |
| 266 | /// Enable printing of debug information. If 'prettyForm' is set to true, |
| 267 | /// debug information is printed in a more readable 'pretty' form. |
| 268 | OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable, |
| 269 | bool prettyForm) { |
| 270 | printDebugInfoFlag = enable; |
| 271 | printDebugInfoPrettyFormFlag = prettyForm; |
| 272 | return *this; |
| 273 | } |
| 274 | |
| 275 | /// Always print operations in the generic form. |
| 276 | OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { |
| 277 | printGenericOpFormFlag = enable; |
| 278 | return *this; |
| 279 | } |
| 280 | |
| 281 | /// Always skip Regions. |
| 282 | OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { |
| 283 | skipRegionsFlag = skip; |
| 284 | return *this; |
| 285 | } |
| 286 | |
| 287 | /// Do not verify the operation when using custom operation printers. |
| 288 | OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) { |
| 289 | assumeVerifiedFlag = enable; |
| 290 | return *this; |
| 291 | } |
| 292 | |
| 293 | /// Use local scope when printing the operation. This allows for using the |
| 294 | /// printer in a more localized and thread-safe setting, but may not necessarily |
| 295 | /// be identical of what the IR will look like when dumping the full module. |
| 296 | OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) { |
| 297 | printLocalScope = enable; |
| 298 | return *this; |
| 299 | } |
| 300 | |
| 301 | /// Print users of values as comments. |
| 302 | OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) { |
| 303 | printValueUsersFlag = enable; |
| 304 | return *this; |
| 305 | } |
| 306 | |
| 307 | /// Print unique SSA ID numbers for values, block arguments and naming conflicts |
| 308 | /// across all regions |
| 309 | OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) { |
| 310 | printUniqueSSAIDsFlag = enable; |
| 311 | return *this; |
| 312 | } |
| 313 | |
| 314 | /// Return if the given ElementsAttr should be elided. |
| 315 | bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { |
| 316 | return elementsAttrElementLimit && |
| 317 | *elementsAttrElementLimit < int64_t(attr.getNumElements()) && |
| 318 | !llvm::isa<SplatElementsAttr>(Val: attr); |
| 319 | } |
| 320 | |
| 321 | /// Return if the given ElementsAttr should be printed as hex string. |
| 322 | bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const { |
| 323 | // -1 is used to disable hex printing. |
| 324 | return (elementsAttrHexElementLimit != -1) && |
| 325 | (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) && |
| 326 | !llvm::isa<SplatElementsAttr>(Val: attr); |
| 327 | } |
| 328 | |
| 329 | OpPrintingFlags &OpPrintingFlags::printNameLocAsPrefix(bool enable) { |
| 330 | useNameLocAsPrefix = enable; |
| 331 | return *this; |
| 332 | } |
| 333 | |
| 334 | /// Return the size limit for printing large ElementsAttr. |
| 335 | std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { |
| 336 | return elementsAttrElementLimit; |
| 337 | } |
| 338 | |
| 339 | /// Return the size limit for printing large ElementsAttr as hex string. |
| 340 | int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const { |
| 341 | return elementsAttrHexElementLimit; |
| 342 | } |
| 343 | |
| 344 | /// Return the size limit for printing large ElementsAttr. |
| 345 | std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const { |
| 346 | return resourceStringCharLimit; |
| 347 | } |
| 348 | |
| 349 | /// Return if debug information should be printed. |
| 350 | bool OpPrintingFlags::shouldPrintDebugInfo() const { |
| 351 | return printDebugInfoFlag; |
| 352 | } |
| 353 | |
| 354 | /// Return if debug information should be printed in the pretty form. |
| 355 | bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { |
| 356 | return printDebugInfoPrettyFormFlag; |
| 357 | } |
| 358 | |
| 359 | /// Return if operations should be printed in the generic form. |
| 360 | bool OpPrintingFlags::shouldPrintGenericOpForm() const { |
| 361 | return printGenericOpFormFlag; |
| 362 | } |
| 363 | |
| 364 | /// Return if Region should be skipped. |
| 365 | bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; } |
| 366 | |
| 367 | /// Return if operation verification should be skipped. |
| 368 | bool OpPrintingFlags::shouldAssumeVerified() const { |
| 369 | return assumeVerifiedFlag; |
| 370 | } |
| 371 | |
| 372 | /// Return if the printer should use local scope when dumping the IR. |
| 373 | bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } |
| 374 | |
| 375 | /// Return if the printer should print users of values. |
| 376 | bool OpPrintingFlags::shouldPrintValueUsers() const { |
| 377 | return printValueUsersFlag; |
| 378 | } |
| 379 | |
| 380 | /// Return if the printer should use unique IDs. |
| 381 | bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { |
| 382 | return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); |
| 383 | } |
| 384 | |
| 385 | /// Return if the printer should use NameLocs as prefixes when printing SSA IDs. |
| 386 | bool OpPrintingFlags::shouldUseNameLocAsPrefix() const { |
| 387 | return useNameLocAsPrefix; |
| 388 | } |
| 389 | |
| 390 | //===----------------------------------------------------------------------===// |
| 391 | // NewLineCounter |
| 392 | //===----------------------------------------------------------------------===// |
| 393 | |
| 394 | namespace { |
| 395 | /// This class is a simple formatter that emits a new line when inputted into a |
| 396 | /// stream, that enables counting the number of newlines emitted. This class |
| 397 | /// should be used whenever emitting newlines in the printer. |
| 398 | struct NewLineCounter { |
| 399 | unsigned curLine = 1; |
| 400 | }; |
| 401 | |
| 402 | static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { |
| 403 | ++newLine.curLine; |
| 404 | return os << '\n'; |
| 405 | } |
| 406 | } // namespace |
| 407 | |
| 408 | //===----------------------------------------------------------------------===// |
| 409 | // AsmPrinter::Impl |
| 410 | //===----------------------------------------------------------------------===// |
| 411 | |
| 412 | namespace mlir { |
| 413 | class AsmPrinter::Impl { |
| 414 | public: |
| 415 | Impl(raw_ostream &os, AsmStateImpl &state); |
| 416 | explicit Impl(Impl &other) : Impl(other.os, other.state) {} |
| 417 | |
| 418 | /// Returns the output stream of the printer. |
| 419 | raw_ostream &getStream() { return os; } |
| 420 | |
| 421 | template <typename Container, typename UnaryFunctor> |
| 422 | inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { |
| 423 | llvm::interleaveComma(c, os, eachFn); |
| 424 | } |
| 425 | |
| 426 | /// This enum describes the different kinds of elision for the type of an |
| 427 | /// attribute when printing it. |
| 428 | enum class AttrTypeElision { |
| 429 | /// The type must not be elided, |
| 430 | Never, |
| 431 | /// The type may be elided when it matches the default used in the parser |
| 432 | /// (for example i64 is the default for integer attributes). |
| 433 | May, |
| 434 | /// The type must be elided. |
| 435 | Must |
| 436 | }; |
| 437 | |
| 438 | /// Print the given attribute or an alias. |
| 439 | void printAttribute(Attribute attr, |
| 440 | AttrTypeElision typeElision = AttrTypeElision::Never); |
| 441 | /// Print the given attribute without considering an alias. |
| 442 | void printAttributeImpl(Attribute attr, |
| 443 | AttrTypeElision typeElision = AttrTypeElision::Never); |
| 444 | void printNamedAttribute(NamedAttribute attr); |
| 445 | |
| 446 | /// Print the alias for the given attribute, return failure if no alias could |
| 447 | /// be printed. |
| 448 | LogicalResult printAlias(Attribute attr); |
| 449 | |
| 450 | /// Print the given type or an alias. |
| 451 | void printType(Type type); |
| 452 | /// Print the given type. |
| 453 | void printTypeImpl(Type type); |
| 454 | |
| 455 | /// Print the alias for the given type, return failure if no alias could |
| 456 | /// be printed. |
| 457 | LogicalResult printAlias(Type type); |
| 458 | |
| 459 | /// Print the given location to the stream. If `allowAlias` is true, this |
| 460 | /// allows for the internal location to use an attribute alias. |
| 461 | void printLocation(LocationAttr loc, bool allowAlias = false); |
| 462 | |
| 463 | /// Print a reference to the given resource that is owned by the given |
| 464 | /// dialect. |
| 465 | void printResourceHandle(const AsmDialectResourceHandle &resource); |
| 466 | |
| 467 | void printAffineMap(AffineMap map); |
| 468 | void |
| 469 | printAffineExpr(AffineExpr expr, |
| 470 | function_ref<void(unsigned, bool)> printValueName = nullptr); |
| 471 | void printAffineConstraint(AffineExpr expr, bool isEq); |
| 472 | void printIntegerSet(IntegerSet set); |
| 473 | |
| 474 | LogicalResult pushCyclicPrinting(const void *opaquePointer); |
| 475 | |
| 476 | void popCyclicPrinting(); |
| 477 | |
| 478 | void printDimensionList(ArrayRef<int64_t> shape); |
| 479 | |
| 480 | protected: |
| 481 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 482 | ArrayRef<StringRef> elidedAttrs = {}, |
| 483 | bool withKeyword = false); |
| 484 | void printTrailingLocation(Location loc, bool allowAlias = true); |
| 485 | void printLocationInternal(LocationAttr loc, bool pretty = false, |
| 486 | bool isTopLevel = false); |
| 487 | |
| 488 | /// Print a dense elements attribute. If 'allowHex' is true, a hex string is |
| 489 | /// used instead of individual elements when the elements attr is large. |
| 490 | void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); |
| 491 | |
| 492 | /// Print a dense string elements attribute. |
| 493 | void printDenseStringElementsAttr(DenseStringElementsAttr attr); |
| 494 | |
| 495 | /// Print a dense elements attribute. If 'allowHex' is true, a hex string is |
| 496 | /// used instead of individual elements when the elements attr is large. |
| 497 | void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, |
| 498 | bool allowHex); |
| 499 | |
| 500 | /// Print a dense array attribute. |
| 501 | void printDenseArrayAttr(DenseArrayAttr attr); |
| 502 | |
| 503 | void printDialectAttribute(Attribute attr); |
| 504 | void printDialectType(Type type); |
| 505 | |
| 506 | /// Print an escaped string, wrapped with "". |
| 507 | void printEscapedString(StringRef str); |
| 508 | |
| 509 | /// Print a hex string, wrapped with "". |
| 510 | void printHexString(StringRef str); |
| 511 | void printHexString(ArrayRef<char> data); |
| 512 | |
| 513 | /// This enum is used to represent the binding strength of the enclosing |
| 514 | /// context that an AffineExprStorage is being printed in, so we can |
| 515 | /// intelligently produce parens. |
| 516 | enum class BindingStrength { |
| 517 | Weak, // + and - |
| 518 | Strong, // All other binary operators. |
| 519 | }; |
| 520 | void printAffineExprInternal( |
| 521 | AffineExpr expr, BindingStrength enclosingTightness, |
| 522 | function_ref<void(unsigned, bool)> printValueName = nullptr); |
| 523 | |
| 524 | /// The output stream for the printer. |
| 525 | raw_ostream &os; |
| 526 | |
| 527 | /// An underlying assembly printer state. |
| 528 | AsmStateImpl &state; |
| 529 | |
| 530 | /// A set of flags to control the printer's behavior. |
| 531 | OpPrintingFlags printerFlags; |
| 532 | |
| 533 | /// A tracker for the number of new lines emitted during printing. |
| 534 | NewLineCounter newLine; |
| 535 | }; |
| 536 | } // namespace mlir |
| 537 | |
| 538 | //===----------------------------------------------------------------------===// |
| 539 | // AliasInitializer |
| 540 | //===----------------------------------------------------------------------===// |
| 541 | |
| 542 | namespace { |
| 543 | /// This class represents a specific instance of a symbol Alias. |
| 544 | class SymbolAlias { |
| 545 | public: |
| 546 | SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, |
| 547 | bool isDeferrable) |
| 548 | : name(name), suffixIndex(suffixIndex), isType(isType), |
| 549 | isDeferrable(isDeferrable) {} |
| 550 | |
| 551 | /// Print this alias to the given stream. |
| 552 | void print(raw_ostream &os) const { |
| 553 | os << (isType ? "!" : "#" ) << name; |
| 554 | if (suffixIndex) { |
| 555 | if (isdigit(name.back())) |
| 556 | os << '_'; |
| 557 | os << suffixIndex; |
| 558 | } |
| 559 | } |
| 560 | |
| 561 | /// Returns true if this is a type alias. |
| 562 | bool isTypeAlias() const { return isType; } |
| 563 | |
| 564 | /// Returns true if this alias supports deferred resolution when parsing. |
| 565 | bool canBeDeferred() const { return isDeferrable; } |
| 566 | |
| 567 | private: |
| 568 | /// The main name of the alias. |
| 569 | StringRef name; |
| 570 | /// The suffix index of the alias. |
| 571 | uint32_t suffixIndex : 30; |
| 572 | /// A flag indicating whether this alias is for a type. |
| 573 | bool isType : 1; |
| 574 | /// A flag indicating whether this alias may be deferred or not. |
| 575 | bool isDeferrable : 1; |
| 576 | |
| 577 | public: |
| 578 | /// Used to avoid printing incomplete aliases for recursive types. |
| 579 | bool isPrinted = false; |
| 580 | }; |
| 581 | |
| 582 | /// This class represents a utility that initializes the set of attribute and |
| 583 | /// type aliases, without the need to store the extra information within the |
| 584 | /// main AliasState class or pass it around via function arguments. |
| 585 | class AliasInitializer { |
| 586 | public: |
| 587 | AliasInitializer( |
| 588 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, |
| 589 | llvm::BumpPtrAllocator &aliasAllocator) |
| 590 | : interfaces(interfaces), aliasAllocator(aliasAllocator), |
| 591 | aliasOS(aliasBuffer) {} |
| 592 | |
| 593 | void initialize(Operation *op, const OpPrintingFlags &printerFlags, |
| 594 | llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias); |
| 595 | |
| 596 | /// Visit the given attribute to see if it has an alias. `canBeDeferred` is |
| 597 | /// set to true if the originator of this attribute can resolve the alias |
| 598 | /// after parsing has completed (e.g. in the case of operation locations). |
| 599 | /// `elideType` indicates if the type of the attribute should be skipped when |
| 600 | /// looking for nested aliases. Returns the maximum alias depth of the |
| 601 | /// attribute, and the alias index of this attribute. |
| 602 | std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false, |
| 603 | bool elideType = false) { |
| 604 | return visitImpl(value: attr, aliases, canBeDeferred, printArgs&: elideType); |
| 605 | } |
| 606 | |
| 607 | /// Visit the given type to see if it has an alias. `canBeDeferred` is |
| 608 | /// set to true if the originator of this attribute can resolve the alias |
| 609 | /// after parsing has completed. Returns the maximum alias depth of the type, |
| 610 | /// and the alias index of this type. |
| 611 | std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) { |
| 612 | return visitImpl(value: type, aliases, canBeDeferred); |
| 613 | } |
| 614 | |
| 615 | private: |
| 616 | struct InProgressAliasInfo { |
| 617 | InProgressAliasInfo() |
| 618 | : aliasDepth(0), isType(false), canBeDeferred(false) {} |
| 619 | InProgressAliasInfo(StringRef alias) |
| 620 | : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {} |
| 621 | |
| 622 | bool operator<(const InProgressAliasInfo &rhs) const { |
| 623 | // Order first by depth, then by attr/type kind, and then by name. |
| 624 | if (aliasDepth != rhs.aliasDepth) |
| 625 | return aliasDepth < rhs.aliasDepth; |
| 626 | if (isType != rhs.isType) |
| 627 | return isType; |
| 628 | return alias < rhs.alias; |
| 629 | } |
| 630 | |
| 631 | /// The alias for the attribute or type, or std::nullopt if the value has no |
| 632 | /// alias. |
| 633 | std::optional<StringRef> alias; |
| 634 | /// The alias depth of this attribute or type, i.e. an indication of the |
| 635 | /// relative ordering of when to print this alias. |
| 636 | unsigned aliasDepth : 30; |
| 637 | /// If this alias represents a type or an attribute. |
| 638 | bool isType : 1; |
| 639 | /// If this alias can be deferred or not. |
| 640 | bool canBeDeferred : 1; |
| 641 | /// Indices for child aliases. |
| 642 | SmallVector<size_t> childIndices; |
| 643 | }; |
| 644 | |
| 645 | /// Visit the given attribute or type to see if it has an alias. |
| 646 | /// `canBeDeferred` is set to true if the originator of this value can resolve |
| 647 | /// the alias after parsing has completed (e.g. in the case of operation |
| 648 | /// locations). Returns the maximum alias depth of the value, and its alias |
| 649 | /// index. |
| 650 | template <typename T, typename... PrintArgs> |
| 651 | std::pair<size_t, size_t> |
| 652 | visitImpl(T value, |
| 653 | llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
| 654 | bool canBeDeferred, PrintArgs &&...printArgs); |
| 655 | |
| 656 | /// Mark the given alias as non-deferrable. |
| 657 | void markAliasNonDeferrable(size_t aliasIndex); |
| 658 | |
| 659 | /// Try to generate an alias for the provided symbol. If an alias is |
| 660 | /// generated, the provided alias mapping and reverse mapping are updated. |
| 661 | template <typename T> |
| 662 | void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); |
| 663 | |
| 664 | /// Uniques the given alias name within the printer by generating name index |
| 665 | /// used as alias name suffix. |
| 666 | static unsigned |
| 667 | uniqueAliasNameIndex(StringRef alias, llvm::StringMap<unsigned> &nameCounts, |
| 668 | llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases); |
| 669 | |
| 670 | /// Given a collection of aliases and symbols, initialize a mapping from a |
| 671 | /// symbol to a given alias. |
| 672 | static void initializeAliases( |
| 673 | llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
| 674 | llvm::MapVector<const void *, SymbolAlias> &symbolToAlias); |
| 675 | |
| 676 | /// The set of asm interfaces within the context. |
| 677 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; |
| 678 | |
| 679 | /// An allocator used for alias names. |
| 680 | llvm::BumpPtrAllocator &aliasAllocator; |
| 681 | |
| 682 | /// The set of built aliases. |
| 683 | llvm::MapVector<const void *, InProgressAliasInfo> aliases; |
| 684 | |
| 685 | /// Storage and stream used when generating an alias. |
| 686 | SmallString<32> aliasBuffer; |
| 687 | llvm::raw_svector_ostream aliasOS; |
| 688 | }; |
| 689 | |
| 690 | /// This class implements a dummy OpAsmPrinter that doesn't print any output, |
| 691 | /// and merely collects the attributes and types that *would* be printed in a |
| 692 | /// normal print invocation so that we can generate proper aliases. This allows |
| 693 | /// for us to generate aliases only for the attributes and types that would be |
| 694 | /// in the output, and trims down unnecessary output. |
| 695 | class DummyAliasOperationPrinter : private OpAsmPrinter { |
| 696 | public: |
| 697 | explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, |
| 698 | AliasInitializer &initializer) |
| 699 | : printerFlags(printerFlags), initializer(initializer) {} |
| 700 | |
| 701 | /// Prints the entire operation with the custom assembly form, if available, |
| 702 | /// or the generic assembly form, otherwise. |
| 703 | void printCustomOrGenericOp(Operation *op) override { |
| 704 | // Visit the operation location. |
| 705 | if (printerFlags.shouldPrintDebugInfo()) |
| 706 | initializer.visit(attr: op->getLoc(), /*canBeDeferred=*/true); |
| 707 | |
| 708 | // If requested, always print the generic form. |
| 709 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 710 | op->getName().printAssembly(op, p&: *this, /*defaultDialect=*/"" ); |
| 711 | return; |
| 712 | } |
| 713 | |
| 714 | // Otherwise print with the generic assembly form. |
| 715 | printGenericOp(op); |
| 716 | } |
| 717 | |
| 718 | private: |
| 719 | /// Print the given operation in the generic form. |
| 720 | void printGenericOp(Operation *op, bool printOpName = true) override { |
| 721 | // Consider nested operations for aliases. |
| 722 | if (!printerFlags.shouldSkipRegions()) { |
| 723 | for (Region ®ion : op->getRegions()) |
| 724 | printRegion(region, /*printEntryBlockArgs=*/true, |
| 725 | /*printBlockTerminators=*/true); |
| 726 | } |
| 727 | |
| 728 | // Visit all the types used in the operation. |
| 729 | for (Type type : op->getOperandTypes()) |
| 730 | printType(type); |
| 731 | for (Type type : op->getResultTypes()) |
| 732 | printType(type); |
| 733 | |
| 734 | // Consider the attributes of the operation for aliases. |
| 735 | for (const NamedAttribute &attr : op->getAttrs()) |
| 736 | printAttribute(attr: attr.getValue()); |
| 737 | } |
| 738 | |
| 739 | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
| 740 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
| 741 | /// operation of the block is not printed. |
| 742 | void print(Block *block, bool printBlockArgs = true, |
| 743 | bool printBlockTerminator = true) { |
| 744 | // Consider the types of the block arguments for aliases if 'printBlockArgs' |
| 745 | // is set to true. |
| 746 | if (printBlockArgs) { |
| 747 | for (BlockArgument arg : block->getArguments()) { |
| 748 | printType(type: arg.getType()); |
| 749 | |
| 750 | // Visit the argument location. |
| 751 | if (printerFlags.shouldPrintDebugInfo()) |
| 752 | // TODO: Allow deferring argument locations. |
| 753 | initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false); |
| 754 | } |
| 755 | } |
| 756 | |
| 757 | // Consider the operations within this block, ignoring the terminator if |
| 758 | // requested. |
| 759 | bool hasTerminator = |
| 760 | !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
| 761 | auto range = llvm::make_range( |
| 762 | x: block->begin(), |
| 763 | y: std::prev(x: block->end(), |
| 764 | n: (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
| 765 | for (Operation &op : range) |
| 766 | printCustomOrGenericOp(op: &op); |
| 767 | } |
| 768 | |
| 769 | /// Print the given region. |
| 770 | void printRegion(Region ®ion, bool printEntryBlockArgs, |
| 771 | bool printBlockTerminators, |
| 772 | bool printEmptyBlock = false) override { |
| 773 | if (region.empty()) |
| 774 | return; |
| 775 | if (printerFlags.shouldSkipRegions()) { |
| 776 | os << "{...}" ; |
| 777 | return; |
| 778 | } |
| 779 | |
| 780 | auto *entryBlock = ®ion.front(); |
| 781 | print(block: entryBlock, printBlockArgs: printEntryBlockArgs, printBlockTerminator: printBlockTerminators); |
| 782 | for (Block &b : llvm::drop_begin(RangeOrContainer&: region, N: 1)) |
| 783 | print(block: &b); |
| 784 | } |
| 785 | |
| 786 | void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, |
| 787 | bool omitType) override { |
| 788 | printType(type: arg.getType()); |
| 789 | // Visit the argument location. |
| 790 | if (printerFlags.shouldPrintDebugInfo()) |
| 791 | // TODO: Allow deferring argument locations. |
| 792 | initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false); |
| 793 | } |
| 794 | |
| 795 | /// Consider the given type to be printed for an alias. |
| 796 | void printType(Type type) override { initializer.visit(type); } |
| 797 | |
| 798 | /// Consider the given attribute to be printed for an alias. |
| 799 | void printAttribute(Attribute attr) override { initializer.visit(attr); } |
| 800 | void printAttributeWithoutType(Attribute attr) override { |
| 801 | printAttribute(attr); |
| 802 | } |
| 803 | void printNamedAttribute(NamedAttribute attr) override { |
| 804 | printAttribute(attr: attr.getValue()); |
| 805 | } |
| 806 | |
| 807 | LogicalResult printAlias(Attribute attr) override { |
| 808 | initializer.visit(attr); |
| 809 | return success(); |
| 810 | } |
| 811 | LogicalResult printAlias(Type type) override { |
| 812 | initializer.visit(type); |
| 813 | return success(); |
| 814 | } |
| 815 | |
| 816 | /// Consider the given location to be printed for an alias. |
| 817 | void printOptionalLocationSpecifier(Location loc) override { |
| 818 | printAttribute(attr: loc); |
| 819 | } |
| 820 | |
| 821 | /// Print the given set of attributes with names not included within |
| 822 | /// 'elidedAttrs'. |
| 823 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 824 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 825 | if (attrs.empty()) |
| 826 | return; |
| 827 | if (elidedAttrs.empty()) { |
| 828 | for (const NamedAttribute &attr : attrs) |
| 829 | printAttribute(attr: attr.getValue()); |
| 830 | return; |
| 831 | } |
| 832 | llvm::SmallDenseSet<StringRef> (elidedAttrs.begin(), |
| 833 | elidedAttrs.end()); |
| 834 | for (const NamedAttribute &attr : attrs) |
| 835 | if (!elidedAttrsSet.contains(V: attr.getName().strref())) |
| 836 | printAttribute(attr: attr.getValue()); |
| 837 | } |
| 838 | void printOptionalAttrDictWithKeyword( |
| 839 | ArrayRef<NamedAttribute> attrs, |
| 840 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 841 | printOptionalAttrDict(attrs, elidedAttrs); |
| 842 | } |
| 843 | |
| 844 | /// Return a null stream as the output stream, this will ignore any data fed |
| 845 | /// to it. |
| 846 | raw_ostream &getStream() const override { return os; } |
| 847 | |
| 848 | /// The following are hooks of `OpAsmPrinter` that are not necessary for |
| 849 | /// determining potential aliases. |
| 850 | void printFloat(const APFloat &) override {} |
| 851 | void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} |
| 852 | void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} |
| 853 | void printNewline() override {} |
| 854 | void increaseIndent() override {} |
| 855 | void decreaseIndent() override {} |
| 856 | void printOperand(Value) override {} |
| 857 | void printOperand(Value, raw_ostream &os) override { |
| 858 | // Users expect the output string to have at least the prefixed % to signal |
| 859 | // a value name. To maintain this invariant, emit a name even if it is |
| 860 | // guaranteed to go unused. |
| 861 | os << "%" ; |
| 862 | } |
| 863 | void printKeywordOrString(StringRef) override {} |
| 864 | void printString(StringRef) override {} |
| 865 | void printResourceHandle(const AsmDialectResourceHandle &) override {} |
| 866 | void printSymbolName(StringRef) override {} |
| 867 | void printSuccessor(Block *) override {} |
| 868 | void printSuccessorAndUseList(Block *, ValueRange) override {} |
| 869 | void shadowRegionArgs(Region &, ValueRange) override {} |
| 870 | |
| 871 | /// The printer flags to use when determining potential aliases. |
| 872 | const OpPrintingFlags &printerFlags; |
| 873 | |
| 874 | /// The initializer to use when identifying aliases. |
| 875 | AliasInitializer &initializer; |
| 876 | |
| 877 | /// A dummy output stream. |
| 878 | mutable llvm::raw_null_ostream os; |
| 879 | }; |
| 880 | |
| 881 | class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { |
| 882 | public: |
| 883 | explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer, |
| 884 | bool canBeDeferred, |
| 885 | SmallVectorImpl<size_t> &childIndices) |
| 886 | : initializer(initializer), canBeDeferred(canBeDeferred), |
| 887 | childIndices(childIndices) {} |
| 888 | |
| 889 | /// Print the given attribute/type, visiting any nested aliases that would be |
| 890 | /// generated as part of printing. Returns the maximum alias depth found while |
| 891 | /// printing the given value. |
| 892 | template <typename T, typename... PrintArgs> |
| 893 | size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) { |
| 894 | printAndVisitNestedAliasesImpl(value, printArgs...); |
| 895 | return maxAliasDepth; |
| 896 | } |
| 897 | |
| 898 | private: |
| 899 | /// Print the given attribute/type, visiting any nested aliases that would be |
| 900 | /// generated as part of printing. |
| 901 | void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) { |
| 902 | if (!isa<BuiltinDialect>(Val: attr.getDialect())) { |
| 903 | attr.getDialect().printAttribute(attr, *this); |
| 904 | |
| 905 | // Process the builtin attributes. |
| 906 | } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr, |
| 907 | IntegerSetAttr, UnitAttr>(Val: attr)) { |
| 908 | return; |
| 909 | } else if (auto distinctAttr = dyn_cast<DistinctAttr>(Val&: attr)) { |
| 910 | printAttribute(attr: distinctAttr.getReferencedAttr()); |
| 911 | } else if (auto dictAttr = dyn_cast<DictionaryAttr>(Val&: attr)) { |
| 912 | for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { |
| 913 | printAttribute(attr: nestedAttr.getName()); |
| 914 | printAttribute(attr: nestedAttr.getValue()); |
| 915 | } |
| 916 | } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: attr)) { |
| 917 | for (Attribute nestedAttr : arrayAttr.getValue()) |
| 918 | printAttribute(attr: nestedAttr); |
| 919 | } else if (auto typeAttr = dyn_cast<TypeAttr>(Val&: attr)) { |
| 920 | printType(type: typeAttr.getValue()); |
| 921 | } else if (auto locAttr = dyn_cast<OpaqueLoc>(Val&: attr)) { |
| 922 | printAttribute(attr: locAttr.getFallbackLocation()); |
| 923 | } else if (auto locAttr = dyn_cast<NameLoc>(Val&: attr)) { |
| 924 | if (!isa<UnknownLoc>(Val: locAttr.getChildLoc())) |
| 925 | printAttribute(attr: locAttr.getChildLoc()); |
| 926 | } else if (auto locAttr = dyn_cast<CallSiteLoc>(Val&: attr)) { |
| 927 | printAttribute(attr: locAttr.getCallee()); |
| 928 | printAttribute(attr: locAttr.getCaller()); |
| 929 | } else if (auto locAttr = dyn_cast<FusedLoc>(Val&: attr)) { |
| 930 | if (Attribute metadata = locAttr.getMetadata()) |
| 931 | printAttribute(attr: metadata); |
| 932 | for (Location nestedLoc : locAttr.getLocations()) |
| 933 | printAttribute(attr: nestedLoc); |
| 934 | } |
| 935 | |
| 936 | // Don't print the type if we must elide it, or if it is a None type. |
| 937 | if (!elideType) { |
| 938 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(Val&: attr)) { |
| 939 | Type attrType = typedAttr.getType(); |
| 940 | if (!llvm::isa<NoneType>(Val: attrType)) |
| 941 | printType(type: attrType); |
| 942 | } |
| 943 | } |
| 944 | } |
| 945 | void printAndVisitNestedAliasesImpl(Type type) { |
| 946 | if (!isa<BuiltinDialect>(Val: type.getDialect())) |
| 947 | return type.getDialect().printType(type, *this); |
| 948 | |
| 949 | // Only visit the layout of memref if it isn't the identity. |
| 950 | if (auto memrefTy = llvm::dyn_cast<MemRefType>(Val&: type)) { |
| 951 | printType(type: memrefTy.getElementType()); |
| 952 | MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
| 953 | if (!llvm::isa<AffineMapAttr>(Val: layout) || !layout.isIdentity()) |
| 954 | printAttribute(attr: memrefTy.getLayout()); |
| 955 | if (memrefTy.getMemorySpace()) |
| 956 | printAttribute(attr: memrefTy.getMemorySpace()); |
| 957 | return; |
| 958 | } |
| 959 | |
| 960 | // For most builtin types, we can simply walk the sub elements. |
| 961 | auto visitFn = [&](auto element) { |
| 962 | if (element) |
| 963 | (void)printAlias(element); |
| 964 | }; |
| 965 | type.walkImmediateSubElements(walkAttrsFn: visitFn, walkTypesFn: visitFn); |
| 966 | } |
| 967 | |
| 968 | /// Consider the given type to be printed for an alias. |
| 969 | void printType(Type type) override { |
| 970 | recordAliasResult(aliasDepthAndIndex: initializer.visit(type, canBeDeferred)); |
| 971 | } |
| 972 | |
| 973 | /// Consider the given attribute to be printed for an alias. |
| 974 | void printAttribute(Attribute attr) override { |
| 975 | recordAliasResult(aliasDepthAndIndex: initializer.visit(attr, canBeDeferred)); |
| 976 | } |
| 977 | void printAttributeWithoutType(Attribute attr) override { |
| 978 | recordAliasResult( |
| 979 | aliasDepthAndIndex: initializer.visit(attr, canBeDeferred, /*elideType=*/true)); |
| 980 | } |
| 981 | void printNamedAttribute(NamedAttribute attr) override { |
| 982 | printAttribute(attr: attr.getValue()); |
| 983 | } |
| 984 | |
| 985 | LogicalResult printAlias(Attribute attr) override { |
| 986 | printAttribute(attr); |
| 987 | return success(); |
| 988 | } |
| 989 | LogicalResult printAlias(Type type) override { |
| 990 | printType(type); |
| 991 | return success(); |
| 992 | } |
| 993 | |
| 994 | /// Record the alias result of a child element. |
| 995 | void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) { |
| 996 | childIndices.push_back(Elt: aliasDepthAndIndex.second); |
| 997 | if (aliasDepthAndIndex.first > maxAliasDepth) |
| 998 | maxAliasDepth = aliasDepthAndIndex.first; |
| 999 | } |
| 1000 | |
| 1001 | /// Return a null stream as the output stream, this will ignore any data fed |
| 1002 | /// to it. |
| 1003 | raw_ostream &getStream() const override { return os; } |
| 1004 | |
| 1005 | /// The following are hooks of `DialectAsmPrinter` that are not necessary for |
| 1006 | /// determining potential aliases. |
| 1007 | void printFloat(const APFloat &) override {} |
| 1008 | void printKeywordOrString(StringRef) override {} |
| 1009 | void printString(StringRef) override {} |
| 1010 | void printSymbolName(StringRef) override {} |
| 1011 | void printResourceHandle(const AsmDialectResourceHandle &) override {} |
| 1012 | |
| 1013 | LogicalResult pushCyclicPrinting(const void *opaquePointer) override { |
| 1014 | return success(IsSuccess: cyclicPrintingStack.insert(X: opaquePointer)); |
| 1015 | } |
| 1016 | |
| 1017 | void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } |
| 1018 | |
| 1019 | /// Stack of potentially cyclic mutable attributes or type currently being |
| 1020 | /// printed. |
| 1021 | SetVector<const void *> cyclicPrintingStack; |
| 1022 | |
| 1023 | /// The initializer to use when identifying aliases. |
| 1024 | AliasInitializer &initializer; |
| 1025 | |
| 1026 | /// If the aliases visited by this printer can be deferred. |
| 1027 | bool canBeDeferred; |
| 1028 | |
| 1029 | /// The indices of child aliases. |
| 1030 | SmallVectorImpl<size_t> &childIndices; |
| 1031 | |
| 1032 | /// The maximum alias depth found by the printer. |
| 1033 | size_t maxAliasDepth = 0; |
| 1034 | |
| 1035 | /// A dummy output stream. |
| 1036 | mutable llvm::raw_null_ostream os; |
| 1037 | }; |
| 1038 | } // namespace |
| 1039 | |
| 1040 | /// Sanitize the given name such that it can be used as a valid identifier. If |
| 1041 | /// the string needs to be modified in any way, the provided buffer is used to |
| 1042 | /// store the new copy, |
| 1043 | static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, |
| 1044 | StringRef allowedPunctChars = "$._-" ) { |
| 1045 | assert(!name.empty() && "Shouldn't have an empty name here" ); |
| 1046 | |
| 1047 | auto validChar = [&](char ch) { |
| 1048 | return llvm::isAlnum(C: ch) || allowedPunctChars.contains(C: ch); |
| 1049 | }; |
| 1050 | |
| 1051 | auto copyNameToBuffer = [&] { |
| 1052 | for (char ch : name) { |
| 1053 | if (validChar(ch)) |
| 1054 | buffer.push_back(Elt: ch); |
| 1055 | else if (ch == ' ') |
| 1056 | buffer.push_back(Elt: '_'); |
| 1057 | else |
| 1058 | buffer.append(RHS: llvm::utohexstr(X: (unsigned char)ch)); |
| 1059 | } |
| 1060 | }; |
| 1061 | |
| 1062 | // Check to see if this name is valid. If it starts with a digit, then it |
| 1063 | // could conflict with the autogenerated numeric ID's, so add an underscore |
| 1064 | // prefix to avoid problems. |
| 1065 | if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { |
| 1066 | buffer.push_back(Elt: '_'); |
| 1067 | copyNameToBuffer(); |
| 1068 | return buffer; |
| 1069 | } |
| 1070 | |
| 1071 | // Check to see that the name consists of only valid identifier characters. |
| 1072 | for (char ch : name) { |
| 1073 | if (!validChar(ch)) { |
| 1074 | copyNameToBuffer(); |
| 1075 | return buffer; |
| 1076 | } |
| 1077 | } |
| 1078 | |
| 1079 | // If there are no invalid characters, return the original name. |
| 1080 | return name; |
| 1081 | } |
| 1082 | |
| 1083 | unsigned AliasInitializer::uniqueAliasNameIndex( |
| 1084 | StringRef alias, llvm::StringMap<unsigned> &nameCounts, |
| 1085 | llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases) { |
| 1086 | if (!usedAliases.count(Key: alias)) { |
| 1087 | usedAliases.insert(key: alias); |
| 1088 | // 0 is not printed in SymbolAlias. |
| 1089 | return 0; |
| 1090 | } |
| 1091 | // Otherwise, we had a conflict - probe until we find a unique name. |
| 1092 | SmallString<64> probeAlias(alias); |
| 1093 | // alias with trailing digit will be printed as _N |
| 1094 | if (isdigit(alias.back())) |
| 1095 | probeAlias.push_back(Elt: '_'); |
| 1096 | // nameCounts start from 1 because 0 is not printed in SymbolAlias. |
| 1097 | if (nameCounts[probeAlias] == 0) |
| 1098 | nameCounts[probeAlias] = 1; |
| 1099 | // This is guaranteed to terminate (and usually in a single iteration) |
| 1100 | // because it generates new names by incrementing nameCounts. |
| 1101 | while (true) { |
| 1102 | unsigned nameIndex = nameCounts[probeAlias]++; |
| 1103 | probeAlias += llvm::utostr(X: nameIndex); |
| 1104 | if (!usedAliases.count(Key: probeAlias)) { |
| 1105 | usedAliases.insert(key: probeAlias); |
| 1106 | return nameIndex; |
| 1107 | } |
| 1108 | // Reset probeAlias to the original alias for the next iteration. |
| 1109 | probeAlias.resize(N: alias.size() + isdigit(alias.back()) ? 1 : 0); |
| 1110 | } |
| 1111 | } |
| 1112 | |
| 1113 | /// Given a collection of aliases and symbols, initialize a mapping from a |
| 1114 | /// symbol to a given alias. |
| 1115 | void AliasInitializer::initializeAliases( |
| 1116 | llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
| 1117 | llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) { |
| 1118 | SmallVector<std::pair<const void *, InProgressAliasInfo>, 0> |
| 1119 | unprocessedAliases = visitedSymbols.takeVector(); |
| 1120 | llvm::stable_sort(Range&: unprocessedAliases, C: llvm::less_second()); |
| 1121 | |
| 1122 | // This keeps track of all of the non-numeric names that are in flight, |
| 1123 | // allowing us to check for duplicates. |
| 1124 | llvm::BumpPtrAllocator usedAliasAllocator; |
| 1125 | llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator); |
| 1126 | |
| 1127 | llvm::StringMap<unsigned> nameCounts; |
| 1128 | for (auto &[symbol, aliasInfo] : unprocessedAliases) { |
| 1129 | if (!aliasInfo.alias) |
| 1130 | continue; |
| 1131 | StringRef alias = *aliasInfo.alias; |
| 1132 | unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases); |
| 1133 | symbolToAlias.insert( |
| 1134 | KV: {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, |
| 1135 | aliasInfo.canBeDeferred)}); |
| 1136 | } |
| 1137 | } |
| 1138 | |
| 1139 | void AliasInitializer::initialize( |
| 1140 | Operation *op, const OpPrintingFlags &printerFlags, |
| 1141 | llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) { |
| 1142 | // Use a dummy printer when walking the IR so that we can collect the |
| 1143 | // attributes/types that will actually be used during printing when |
| 1144 | // considering aliases. |
| 1145 | DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); |
| 1146 | aliasPrinter.printCustomOrGenericOp(op); |
| 1147 | |
| 1148 | // Initialize the aliases. |
| 1149 | initializeAliases(visitedSymbols&: aliases, symbolToAlias&: attrTypeToAlias); |
| 1150 | } |
| 1151 | |
| 1152 | template <typename T, typename... PrintArgs> |
| 1153 | std::pair<size_t, size_t> AliasInitializer::visitImpl( |
| 1154 | T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
| 1155 | bool canBeDeferred, PrintArgs &&...printArgs) { |
| 1156 | auto [it, inserted] = aliases.try_emplace(value.getAsOpaquePointer()); |
| 1157 | size_t aliasIndex = std::distance(aliases.begin(), it); |
| 1158 | if (!inserted) { |
| 1159 | // Make sure that the alias isn't deferred if we don't permit it. |
| 1160 | if (!canBeDeferred) |
| 1161 | markAliasNonDeferrable(aliasIndex); |
| 1162 | return {static_cast<size_t>(it->second.aliasDepth), aliasIndex}; |
| 1163 | } |
| 1164 | |
| 1165 | // Try to generate an alias for this value. |
| 1166 | generateAlias(value, it->second, canBeDeferred); |
| 1167 | it->second.isType = std::is_base_of_v<Type, T>; |
| 1168 | it->second.canBeDeferred = canBeDeferred; |
| 1169 | |
| 1170 | // Print the value, capturing any nested elements that require aliases. |
| 1171 | SmallVector<size_t> childAliases; |
| 1172 | DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases); |
| 1173 | size_t maxAliasDepth = |
| 1174 | printer.printAndVisitNestedAliases(value, printArgs...); |
| 1175 | |
| 1176 | // Make sure to recompute `it` in case the map was reallocated. |
| 1177 | it = std::next(x: aliases.begin(), n: aliasIndex); |
| 1178 | |
| 1179 | // If we had sub elements, update to account for the depth. |
| 1180 | it->second.childIndices = std::move(childAliases); |
| 1181 | if (maxAliasDepth) |
| 1182 | it->second.aliasDepth = maxAliasDepth + 1; |
| 1183 | |
| 1184 | // Propagate the alias depth of the value. |
| 1185 | return {(size_t)it->second.aliasDepth, aliasIndex}; |
| 1186 | } |
| 1187 | |
| 1188 | void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { |
| 1189 | auto *it = std::next(x: aliases.begin(), n: aliasIndex); |
| 1190 | |
| 1191 | // If already marked non-deferrable stop the recursion. |
| 1192 | // All children should already be marked non-deferrable as well. |
| 1193 | if (!it->second.canBeDeferred) |
| 1194 | return; |
| 1195 | |
| 1196 | it->second.canBeDeferred = false; |
| 1197 | |
| 1198 | // Propagate the non-deferrable flag to any child aliases. |
| 1199 | for (size_t childIndex : it->second.childIndices) |
| 1200 | markAliasNonDeferrable(aliasIndex: childIndex); |
| 1201 | } |
| 1202 | |
| 1203 | template <typename T> |
| 1204 | void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, |
| 1205 | bool canBeDeferred) { |
| 1206 | SmallString<32> nameBuffer; |
| 1207 | |
| 1208 | OpAsmDialectInterface::AliasResult symbolInterfaceResult = |
| 1209 | OpAsmDialectInterface::AliasResult::NoAlias; |
| 1210 | using InterfaceT = std::conditional_t<std::is_base_of_v<Attribute, T>, |
| 1211 | OpAsmAttrInterface, OpAsmTypeInterface>; |
| 1212 | if (auto symbolInterface = dyn_cast<InterfaceT>(symbol)) { |
| 1213 | symbolInterfaceResult = symbolInterface.getAlias(aliasOS); |
| 1214 | if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::NoAlias) { |
| 1215 | nameBuffer = std::move(aliasBuffer); |
| 1216 | assert(!nameBuffer.empty() && "expected valid alias name" ); |
| 1217 | } |
| 1218 | } |
| 1219 | |
| 1220 | if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) { |
| 1221 | for (const auto &interface : interfaces) { |
| 1222 | OpAsmDialectInterface::AliasResult result = |
| 1223 | interface.getAlias(symbol, aliasOS); |
| 1224 | if (result == OpAsmDialectInterface::AliasResult::NoAlias) |
| 1225 | continue; |
| 1226 | nameBuffer = std::move(aliasBuffer); |
| 1227 | assert(!nameBuffer.empty() && "expected valid alias name" ); |
| 1228 | if (result == OpAsmDialectInterface::AliasResult::FinalAlias) |
| 1229 | break; |
| 1230 | } |
| 1231 | } |
| 1232 | |
| 1233 | if (nameBuffer.empty()) |
| 1234 | return; |
| 1235 | |
| 1236 | SmallString<16> tempBuffer; |
| 1237 | StringRef name = |
| 1238 | sanitizeIdentifier(name: nameBuffer, buffer&: tempBuffer, /*allowedPunctChars=*/"$_-" ); |
| 1239 | name = name.copy(A&: aliasAllocator); |
| 1240 | alias = InProgressAliasInfo(name); |
| 1241 | } |
| 1242 | |
| 1243 | //===----------------------------------------------------------------------===// |
| 1244 | // AliasState |
| 1245 | //===----------------------------------------------------------------------===// |
| 1246 | |
| 1247 | namespace { |
| 1248 | /// This class manages the state for type and attribute aliases. |
| 1249 | class AliasState { |
| 1250 | public: |
| 1251 | // Initialize the internal aliases. |
| 1252 | void |
| 1253 | initialize(Operation *op, const OpPrintingFlags &printerFlags, |
| 1254 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
| 1255 | |
| 1256 | /// Get an alias for the given attribute if it has one and print it in `os`. |
| 1257 | /// Returns success if an alias was printed, failure otherwise. |
| 1258 | LogicalResult getAlias(Attribute attr, raw_ostream &os) const; |
| 1259 | |
| 1260 | /// Get an alias for the given type if it has one and print it in `os`. |
| 1261 | /// Returns success if an alias was printed, failure otherwise. |
| 1262 | LogicalResult getAlias(Type ty, raw_ostream &os) const; |
| 1263 | |
| 1264 | /// Print all of the referenced aliases that can not be resolved in a deferred |
| 1265 | /// manner. |
| 1266 | void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
| 1267 | printAliases(p, newLine, /*isDeferred=*/false); |
| 1268 | } |
| 1269 | |
| 1270 | /// Print all of the referenced aliases that support deferred resolution. |
| 1271 | void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
| 1272 | printAliases(p, newLine, /*isDeferred=*/true); |
| 1273 | } |
| 1274 | |
| 1275 | private: |
| 1276 | /// Print all of the referenced aliases that support the provided resolution |
| 1277 | /// behavior. |
| 1278 | void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
| 1279 | bool isDeferred); |
| 1280 | |
| 1281 | /// Mapping between attribute/type and alias. |
| 1282 | llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias; |
| 1283 | |
| 1284 | /// An allocator used for alias names. |
| 1285 | llvm::BumpPtrAllocator aliasAllocator; |
| 1286 | }; |
| 1287 | } // namespace |
| 1288 | |
| 1289 | void AliasState::initialize( |
| 1290 | Operation *op, const OpPrintingFlags &printerFlags, |
| 1291 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
| 1292 | AliasInitializer initializer(interfaces, aliasAllocator); |
| 1293 | initializer.initialize(op, printerFlags, attrTypeToAlias); |
| 1294 | } |
| 1295 | |
| 1296 | LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { |
| 1297 | const auto *it = attrTypeToAlias.find(Key: attr.getAsOpaquePointer()); |
| 1298 | if (it == attrTypeToAlias.end()) |
| 1299 | return failure(); |
| 1300 | it->second.print(os); |
| 1301 | return success(); |
| 1302 | } |
| 1303 | |
| 1304 | LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { |
| 1305 | const auto *it = attrTypeToAlias.find(Key: ty.getAsOpaquePointer()); |
| 1306 | if (it == attrTypeToAlias.end()) |
| 1307 | return failure(); |
| 1308 | if (!it->second.isPrinted) |
| 1309 | return failure(); |
| 1310 | |
| 1311 | it->second.print(os); |
| 1312 | return success(); |
| 1313 | } |
| 1314 | |
| 1315 | void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
| 1316 | bool isDeferred) { |
| 1317 | auto filterFn = [=](const auto &aliasIt) { |
| 1318 | return aliasIt.second.canBeDeferred() == isDeferred; |
| 1319 | }; |
| 1320 | for (auto &[opaqueSymbol, alias] : |
| 1321 | llvm::make_filter_range(Range&: attrTypeToAlias, Pred: filterFn)) { |
| 1322 | alias.print(os&: p.getStream()); |
| 1323 | p.getStream() << " = " ; |
| 1324 | |
| 1325 | if (alias.isTypeAlias()) { |
| 1326 | Type type = Type::getFromOpaquePointer(pointer: opaqueSymbol); |
| 1327 | p.printTypeImpl(type); |
| 1328 | alias.isPrinted = true; |
| 1329 | } else { |
| 1330 | // TODO: Support nested aliases in mutable attributes. |
| 1331 | Attribute attr = Attribute::getFromOpaquePointer(ptr: opaqueSymbol); |
| 1332 | if (attr.hasTrait<AttributeTrait::IsMutable>()) |
| 1333 | p.getStream() << attr; |
| 1334 | else |
| 1335 | p.printAttributeImpl(attr); |
| 1336 | } |
| 1337 | |
| 1338 | p.getStream() << newLine; |
| 1339 | } |
| 1340 | } |
| 1341 | |
| 1342 | //===----------------------------------------------------------------------===// |
| 1343 | // SSANameState |
| 1344 | //===----------------------------------------------------------------------===// |
| 1345 | |
| 1346 | namespace { |
| 1347 | /// Info about block printing: a number which is its position in the visitation |
| 1348 | /// order, and a name that is used to print reference to it, e.g. ^bb42. |
| 1349 | struct BlockInfo { |
| 1350 | int ordering; |
| 1351 | StringRef name; |
| 1352 | }; |
| 1353 | |
| 1354 | /// This class manages the state of SSA value names. |
| 1355 | class SSANameState { |
| 1356 | public: |
| 1357 | /// A sentinel value used for values with names set. |
| 1358 | enum : unsigned { NameSentinel = ~0U }; |
| 1359 | |
| 1360 | SSANameState(Operation *op, const OpPrintingFlags &printerFlags); |
| 1361 | SSANameState() = default; |
| 1362 | |
| 1363 | /// Print the SSA identifier for the given value to 'stream'. If |
| 1364 | /// 'printResultNo' is true, it also presents the result number ('#' number) |
| 1365 | /// of this value. |
| 1366 | void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; |
| 1367 | |
| 1368 | /// Print the operation identifier. |
| 1369 | void printOperationID(Operation *op, raw_ostream &stream) const; |
| 1370 | |
| 1371 | /// Return the result indices for each of the result groups registered by this |
| 1372 | /// operation, or empty if none exist. |
| 1373 | ArrayRef<int> getOpResultGroups(Operation *op); |
| 1374 | |
| 1375 | /// Get the info for the given block. |
| 1376 | BlockInfo getBlockInfo(Block *block); |
| 1377 | |
| 1378 | /// Renumber the arguments for the specified region to the same names as the |
| 1379 | /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for |
| 1380 | /// details. |
| 1381 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse); |
| 1382 | |
| 1383 | private: |
| 1384 | /// Number the SSA values within the given IR unit. |
| 1385 | void numberValuesInRegion(Region ®ion); |
| 1386 | void numberValuesInBlock(Block &block); |
| 1387 | void numberValuesInOp(Operation &op); |
| 1388 | |
| 1389 | /// Given a result of an operation 'result', find the result group head |
| 1390 | /// 'lookupValue' and the result of 'result' within that group in |
| 1391 | /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group |
| 1392 | /// has more than 1 result. |
| 1393 | void getResultIDAndNumber(OpResult result, Value &lookupValue, |
| 1394 | std::optional<int> &lookupResultNo) const; |
| 1395 | |
| 1396 | /// Set a special value name for the given value. |
| 1397 | void setValueName(Value value, StringRef name); |
| 1398 | |
| 1399 | /// Uniques the given value name within the printer. If the given name |
| 1400 | /// conflicts, it is automatically renamed. |
| 1401 | StringRef uniqueValueName(StringRef name); |
| 1402 | |
| 1403 | /// This is the value ID for each SSA value. If this returns NameSentinel, |
| 1404 | /// then the valueID has an entry in valueNames. |
| 1405 | DenseMap<Value, unsigned> valueIDs; |
| 1406 | DenseMap<Value, StringRef> valueNames; |
| 1407 | |
| 1408 | /// When printing users of values, an operation without a result might |
| 1409 | /// be the user. This map holds ids for such operations. |
| 1410 | DenseMap<Operation *, unsigned> operationIDs; |
| 1411 | |
| 1412 | /// This is a map of operations that contain multiple named result groups, |
| 1413 | /// i.e. there may be multiple names for the results of the operation. The |
| 1414 | /// value of this map are the result numbers that start a result group. |
| 1415 | DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; |
| 1416 | |
| 1417 | /// This maps blocks to there visitation number in the current region as well |
| 1418 | /// as the string representing their name. |
| 1419 | DenseMap<Block *, BlockInfo> blockNames; |
| 1420 | |
| 1421 | /// This keeps track of all of the non-numeric names that are in flight, |
| 1422 | /// allowing us to check for duplicates. |
| 1423 | /// Note: the value of the map is unused. |
| 1424 | llvm::ScopedHashTable<StringRef, char> usedNames; |
| 1425 | llvm::BumpPtrAllocator usedNameAllocator; |
| 1426 | |
| 1427 | /// This is the next value ID to assign in numbering. |
| 1428 | unsigned nextValueID = 0; |
| 1429 | /// This is the next ID to assign to a region entry block argument. |
| 1430 | unsigned nextArgumentID = 0; |
| 1431 | /// This is the next ID to assign when a name conflict is detected. |
| 1432 | unsigned nextConflictID = 0; |
| 1433 | |
| 1434 | /// These are the printing flags. They control, eg., whether to print in |
| 1435 | /// generic form. |
| 1436 | OpPrintingFlags printerFlags; |
| 1437 | }; |
| 1438 | } // namespace |
| 1439 | |
| 1440 | SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) |
| 1441 | : printerFlags(printerFlags) { |
| 1442 | llvm::SaveAndRestore valueIDSaver(nextValueID); |
| 1443 | llvm::SaveAndRestore argumentIDSaver(nextArgumentID); |
| 1444 | llvm::SaveAndRestore conflictIDSaver(nextConflictID); |
| 1445 | |
| 1446 | // The naming context includes `nextValueID`, `nextArgumentID`, |
| 1447 | // `nextConflictID` and `usedNames` scoped HashTable. This information is |
| 1448 | // carried from the parent region. |
| 1449 | using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; |
| 1450 | using NamingContext = |
| 1451 | std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; |
| 1452 | |
| 1453 | // Allocator for UsedNamesScopeTy |
| 1454 | llvm::BumpPtrAllocator allocator; |
| 1455 | |
| 1456 | // Add a scope for the top level operation. |
| 1457 | auto *topLevelNamesScope = |
| 1458 | new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); |
| 1459 | |
| 1460 | SmallVector<NamingContext, 8> nameContext; |
| 1461 | for (Region ®ion : op->getRegions()) |
| 1462 | nameContext.push_back(Elt: std::make_tuple(args: ®ion, args&: nextValueID, args&: nextArgumentID, |
| 1463 | args&: nextConflictID, args&: topLevelNamesScope)); |
| 1464 | |
| 1465 | numberValuesInOp(op&: *op); |
| 1466 | |
| 1467 | while (!nameContext.empty()) { |
| 1468 | Region *region; |
| 1469 | UsedNamesScopeTy *parentScope; |
| 1470 | |
| 1471 | if (printerFlags.shouldPrintUniqueSSAIDs()) |
| 1472 | // To print unique SSA IDs, ignore saved ID counts from parent regions |
| 1473 | std::tie(args&: region, args: std::ignore, args: std::ignore, args: std::ignore, args&: parentScope) = |
| 1474 | nameContext.pop_back_val(); |
| 1475 | else |
| 1476 | std::tie(args&: region, args&: nextValueID, args&: nextArgumentID, args&: nextConflictID, |
| 1477 | args&: parentScope) = nameContext.pop_back_val(); |
| 1478 | |
| 1479 | // When we switch from one subtree to another, pop the scopes(needless) |
| 1480 | // until the parent scope. |
| 1481 | while (usedNames.getCurScope() != parentScope) { |
| 1482 | usedNames.getCurScope()->~UsedNamesScopeTy(); |
| 1483 | assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && |
| 1484 | "top level parentScope must be a nullptr" ); |
| 1485 | } |
| 1486 | |
| 1487 | // Add a scope for the current region. |
| 1488 | auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) |
| 1489 | UsedNamesScopeTy(usedNames); |
| 1490 | |
| 1491 | numberValuesInRegion(region&: *region); |
| 1492 | |
| 1493 | for (Operation &op : region->getOps()) |
| 1494 | for (Region ®ion : op.getRegions()) |
| 1495 | nameContext.push_back(Elt: std::make_tuple(args: ®ion, args&: nextValueID, |
| 1496 | args&: nextArgumentID, args&: nextConflictID, |
| 1497 | args&: curNamesScope)); |
| 1498 | } |
| 1499 | |
| 1500 | // Manually remove all the scopes. |
| 1501 | while (usedNames.getCurScope() != nullptr) |
| 1502 | usedNames.getCurScope()->~UsedNamesScopeTy(); |
| 1503 | } |
| 1504 | |
| 1505 | void SSANameState::printValueID(Value value, bool printResultNo, |
| 1506 | raw_ostream &stream) const { |
| 1507 | if (!value) { |
| 1508 | stream << "<<NULL VALUE>>" ; |
| 1509 | return; |
| 1510 | } |
| 1511 | |
| 1512 | std::optional<int> resultNo; |
| 1513 | auto lookupValue = value; |
| 1514 | |
| 1515 | // If this is an operation result, collect the head lookup value of the result |
| 1516 | // group and the result number of 'result' within that group. |
| 1517 | if (OpResult result = dyn_cast<OpResult>(Val&: value)) |
| 1518 | getResultIDAndNumber(result, lookupValue, lookupResultNo&: resultNo); |
| 1519 | |
| 1520 | auto it = valueIDs.find(Val: lookupValue); |
| 1521 | if (it == valueIDs.end()) { |
| 1522 | stream << "<<UNKNOWN SSA VALUE>>" ; |
| 1523 | return; |
| 1524 | } |
| 1525 | |
| 1526 | stream << '%'; |
| 1527 | if (it->second != NameSentinel) { |
| 1528 | stream << it->second; |
| 1529 | } else { |
| 1530 | auto nameIt = valueNames.find(Val: lookupValue); |
| 1531 | assert(nameIt != valueNames.end() && "Didn't have a name entry?" ); |
| 1532 | stream << nameIt->second; |
| 1533 | } |
| 1534 | |
| 1535 | if (resultNo && printResultNo) |
| 1536 | stream << '#' << *resultNo; |
| 1537 | } |
| 1538 | |
| 1539 | void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { |
| 1540 | auto it = operationIDs.find(Val: op); |
| 1541 | if (it == operationIDs.end()) { |
| 1542 | stream << "<<UNKNOWN OPERATION>>" ; |
| 1543 | } else { |
| 1544 | stream << '%' << it->second; |
| 1545 | } |
| 1546 | } |
| 1547 | |
| 1548 | ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { |
| 1549 | auto it = opResultGroups.find(Val: op); |
| 1550 | return it == opResultGroups.end() ? ArrayRef<int>() : it->second; |
| 1551 | } |
| 1552 | |
| 1553 | BlockInfo SSANameState::getBlockInfo(Block *block) { |
| 1554 | auto it = blockNames.find(Val: block); |
| 1555 | BlockInfo invalidBlock{.ordering: -1, .name: "INVALIDBLOCK" }; |
| 1556 | return it != blockNames.end() ? it->second : invalidBlock; |
| 1557 | } |
| 1558 | |
| 1559 | void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { |
| 1560 | assert(!region.empty() && "cannot shadow arguments of an empty region" ); |
| 1561 | assert(region.getNumArguments() == namesToUse.size() && |
| 1562 | "incorrect number of names passed in" ); |
| 1563 | assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && |
| 1564 | "only KnownIsolatedFromAbove ops can shadow names" ); |
| 1565 | |
| 1566 | SmallVector<char, 16> nameStr; |
| 1567 | for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { |
| 1568 | auto nameToUse = namesToUse[i]; |
| 1569 | if (nameToUse == nullptr) |
| 1570 | continue; |
| 1571 | auto nameToReplace = region.getArgument(i); |
| 1572 | |
| 1573 | nameStr.clear(); |
| 1574 | llvm::raw_svector_ostream nameStream(nameStr); |
| 1575 | printValueID(value: nameToUse, /*printResultNo=*/true, stream&: nameStream); |
| 1576 | |
| 1577 | // Entry block arguments should already have a pretty "arg" name. |
| 1578 | assert(valueIDs[nameToReplace] == NameSentinel); |
| 1579 | |
| 1580 | // Use the name without the leading %. |
| 1581 | auto name = StringRef(nameStream.str()).drop_front(); |
| 1582 | |
| 1583 | // Overwrite the name. |
| 1584 | valueNames[nameToReplace] = name.copy(A&: usedNameAllocator); |
| 1585 | } |
| 1586 | } |
| 1587 | |
| 1588 | namespace { |
| 1589 | /// Try to get value name from value's location, fallback to `name`. |
| 1590 | StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { |
| 1591 | if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>()) |
| 1592 | return maybeNameLoc.getName(); |
| 1593 | return name; |
| 1594 | } |
| 1595 | } // namespace |
| 1596 | |
| 1597 | void SSANameState::numberValuesInRegion(Region ®ion) { |
| 1598 | // Indicates whether OpAsmOpInterface set a name. |
| 1599 | bool opAsmOpInterfaceUsed = false; |
| 1600 | auto setBlockArgNameFn = [&](Value arg, StringRef name) { |
| 1601 | assert(!valueIDs.count(arg) && "arg numbered multiple times" ); |
| 1602 | assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion && |
| 1603 | "arg not defined in current region" ); |
| 1604 | opAsmOpInterfaceUsed = true; |
| 1605 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
| 1606 | name = maybeGetValueNameFromLoc(value: arg, name); |
| 1607 | setValueName(value: arg, name); |
| 1608 | }; |
| 1609 | |
| 1610 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 1611 | if (Operation *op = region.getParentOp()) { |
| 1612 | if (auto asmInterface = dyn_cast<OpAsmOpInterface>(Val: op)) |
| 1613 | asmInterface.getAsmBlockArgumentNames(region, setNameFn: setBlockArgNameFn); |
| 1614 | // If the OpAsmOpInterface didn't set a name, get name from the type. |
| 1615 | if (!opAsmOpInterfaceUsed) { |
| 1616 | for (BlockArgument arg : region.getArguments()) { |
| 1617 | if (auto interface = dyn_cast<OpAsmTypeInterface>(Val: arg.getType())) { |
| 1618 | interface.getAsmName( |
| 1619 | setNameFn: [&](StringRef name) { setBlockArgNameFn(arg, name); }); |
| 1620 | } |
| 1621 | } |
| 1622 | } |
| 1623 | } |
| 1624 | } |
| 1625 | |
| 1626 | // Number the values within this region in a breadth-first order. |
| 1627 | unsigned nextBlockID = 0; |
| 1628 | for (auto &block : region) { |
| 1629 | // Each block gets a unique ID, and all of the operations within it get |
| 1630 | // numbered as well. |
| 1631 | auto blockInfoIt = blockNames.insert(KV: {&block, {.ordering: -1, .name: "" }}); |
| 1632 | if (blockInfoIt.second) { |
| 1633 | // This block hasn't been named through `getAsmBlockArgumentNames`, use |
| 1634 | // default `^bbNNN` format. |
| 1635 | std::string name; |
| 1636 | llvm::raw_string_ostream(name) << "^bb" << nextBlockID; |
| 1637 | blockInfoIt.first->second.name = StringRef(name).copy(A&: usedNameAllocator); |
| 1638 | } |
| 1639 | blockInfoIt.first->second.ordering = nextBlockID++; |
| 1640 | |
| 1641 | numberValuesInBlock(block); |
| 1642 | } |
| 1643 | } |
| 1644 | |
| 1645 | void SSANameState::numberValuesInBlock(Block &block) { |
| 1646 | // Number the block arguments. We give entry block arguments a special name |
| 1647 | // 'arg'. |
| 1648 | bool isEntryBlock = block.isEntryBlock(); |
| 1649 | SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "" ); |
| 1650 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
| 1651 | for (auto arg : block.getArguments()) { |
| 1652 | if (valueIDs.count(Val: arg)) |
| 1653 | continue; |
| 1654 | if (isEntryBlock) { |
| 1655 | specialNameBuffer.resize(N: strlen(s: "arg" )); |
| 1656 | specialName << nextArgumentID++; |
| 1657 | } |
| 1658 | StringRef specialNameStr = specialName.str(); |
| 1659 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
| 1660 | specialNameStr = maybeGetValueNameFromLoc(value: arg, name: specialNameStr); |
| 1661 | setValueName(value: arg, name: specialNameStr); |
| 1662 | } |
| 1663 | |
| 1664 | // Number the operations in this block. |
| 1665 | for (auto &op : block) |
| 1666 | numberValuesInOp(op); |
| 1667 | } |
| 1668 | |
| 1669 | void SSANameState::numberValuesInOp(Operation &op) { |
| 1670 | // Function used to set the special result names for the operation. |
| 1671 | SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); |
| 1672 | // Indicates whether OpAsmOpInterface set a name. |
| 1673 | bool opAsmOpInterfaceUsed = false; |
| 1674 | auto setResultNameFn = [&](Value result, StringRef name) { |
| 1675 | assert(!valueIDs.count(result) && "result numbered multiple times" ); |
| 1676 | assert(result.getDefiningOp() == &op && "result not defined by 'op'" ); |
| 1677 | opAsmOpInterfaceUsed = true; |
| 1678 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
| 1679 | name = maybeGetValueNameFromLoc(value: result, name); |
| 1680 | setValueName(value: result, name); |
| 1681 | |
| 1682 | // Record the result number for groups not anchored at 0. |
| 1683 | if (int resultNo = llvm::cast<OpResult>(Val&: result).getResultNumber()) |
| 1684 | resultGroups.push_back(Elt: resultNo); |
| 1685 | }; |
| 1686 | // Operations can customize the printing of block names in OpAsmOpInterface. |
| 1687 | auto setBlockNameFn = [&](Block *block, StringRef name) { |
| 1688 | assert(block->getParentOp() == &op && |
| 1689 | "getAsmBlockArgumentNames callback invoked on a block not directly " |
| 1690 | "nested under the current operation" ); |
| 1691 | assert(!blockNames.count(block) && "block numbered multiple times" ); |
| 1692 | SmallString<16> tmpBuffer{"^" }; |
| 1693 | name = sanitizeIdentifier(name, buffer&: tmpBuffer); |
| 1694 | if (name.data() != tmpBuffer.data()) { |
| 1695 | tmpBuffer.append(RHS: name); |
| 1696 | name = tmpBuffer.str(); |
| 1697 | } |
| 1698 | name = name.copy(A&: usedNameAllocator); |
| 1699 | blockNames[block] = {.ordering: -1, .name: name}; |
| 1700 | }; |
| 1701 | |
| 1702 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 1703 | if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(Val: &op)) { |
| 1704 | asmInterface.getAsmBlockNames(setNameFn: setBlockNameFn); |
| 1705 | asmInterface.getAsmResultNames(setNameFn: setResultNameFn); |
| 1706 | } |
| 1707 | if (!opAsmOpInterfaceUsed) { |
| 1708 | // If the OpAsmOpInterface didn't set a name, and all results have |
| 1709 | // OpAsmTypeInterface, get names from types. |
| 1710 | bool allHaveOpAsmTypeInterface = |
| 1711 | llvm::all_of(Range: op.getResultTypes(), P: [&](Type type) { |
| 1712 | return isa<OpAsmTypeInterface>(Val: type); |
| 1713 | }); |
| 1714 | if (allHaveOpAsmTypeInterface) { |
| 1715 | for (OpResult result : op.getResults()) { |
| 1716 | auto interface = cast<OpAsmTypeInterface>(Val: result.getType()); |
| 1717 | interface.getAsmName( |
| 1718 | setNameFn: [&](StringRef name) { setResultNameFn(result, name); }); |
| 1719 | } |
| 1720 | } |
| 1721 | } |
| 1722 | } |
| 1723 | |
| 1724 | unsigned numResults = op.getNumResults(); |
| 1725 | if (numResults == 0) { |
| 1726 | // If value users should be printed, operations with no result need an id. |
| 1727 | if (printerFlags.shouldPrintValueUsers()) { |
| 1728 | if (operationIDs.try_emplace(Key: &op, Args&: nextValueID).second) |
| 1729 | ++nextValueID; |
| 1730 | } |
| 1731 | return; |
| 1732 | } |
| 1733 | Value resultBegin = op.getResult(idx: 0); |
| 1734 | |
| 1735 | if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(Val: resultBegin)) { |
| 1736 | if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) { |
| 1737 | setValueName(value: resultBegin, name: nameLoc.getName()); |
| 1738 | } |
| 1739 | } |
| 1740 | |
| 1741 | // If the first result wasn't numbered, give it a default number. |
| 1742 | if (valueIDs.try_emplace(Key: resultBegin, Args&: nextValueID).second) |
| 1743 | ++nextValueID; |
| 1744 | |
| 1745 | // If this operation has multiple result groups, mark it. |
| 1746 | if (resultGroups.size() != 1) { |
| 1747 | llvm::array_pod_sort(Start: resultGroups.begin(), End: resultGroups.end()); |
| 1748 | opResultGroups.try_emplace(Key: &op, Args: std::move(resultGroups)); |
| 1749 | } |
| 1750 | } |
| 1751 | |
| 1752 | void SSANameState::getResultIDAndNumber( |
| 1753 | OpResult result, Value &lookupValue, |
| 1754 | std::optional<int> &lookupResultNo) const { |
| 1755 | Operation *owner = result.getOwner(); |
| 1756 | if (owner->getNumResults() == 1) |
| 1757 | return; |
| 1758 | int resultNo = result.getResultNumber(); |
| 1759 | |
| 1760 | // If this operation has multiple result groups, we will need to find the |
| 1761 | // one corresponding to this result. |
| 1762 | auto resultGroupIt = opResultGroups.find(Val: owner); |
| 1763 | if (resultGroupIt == opResultGroups.end()) { |
| 1764 | // If not, just use the first result. |
| 1765 | lookupResultNo = resultNo; |
| 1766 | lookupValue = owner->getResult(idx: 0); |
| 1767 | return; |
| 1768 | } |
| 1769 | |
| 1770 | // Find the correct index using a binary search, as the groups are ordered. |
| 1771 | ArrayRef<int> resultGroups = resultGroupIt->second; |
| 1772 | const auto *it = llvm::upper_bound(Range&: resultGroups, Value&: resultNo); |
| 1773 | int groupResultNo = 0, groupSize = 0; |
| 1774 | |
| 1775 | // If there are no smaller elements, the last result group is the lookup. |
| 1776 | if (it == resultGroups.end()) { |
| 1777 | groupResultNo = resultGroups.back(); |
| 1778 | groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); |
| 1779 | } else { |
| 1780 | // Otherwise, the previous element is the lookup. |
| 1781 | groupResultNo = *std::prev(x: it); |
| 1782 | groupSize = *it - groupResultNo; |
| 1783 | } |
| 1784 | |
| 1785 | // We only record the result number for a group of size greater than 1. |
| 1786 | if (groupSize != 1) |
| 1787 | lookupResultNo = resultNo - groupResultNo; |
| 1788 | lookupValue = owner->getResult(idx: groupResultNo); |
| 1789 | } |
| 1790 | |
| 1791 | void SSANameState::setValueName(Value value, StringRef name) { |
| 1792 | // If the name is empty, the value uses the default numbering. |
| 1793 | if (name.empty()) { |
| 1794 | valueIDs[value] = nextValueID++; |
| 1795 | return; |
| 1796 | } |
| 1797 | |
| 1798 | valueIDs[value] = NameSentinel; |
| 1799 | valueNames[value] = uniqueValueName(name); |
| 1800 | } |
| 1801 | |
| 1802 | StringRef SSANameState::uniqueValueName(StringRef name) { |
| 1803 | SmallString<16> tmpBuffer; |
| 1804 | name = sanitizeIdentifier(name, buffer&: tmpBuffer); |
| 1805 | |
| 1806 | // Check to see if this name is already unique. |
| 1807 | if (!usedNames.count(Key: name)) { |
| 1808 | name = name.copy(A&: usedNameAllocator); |
| 1809 | } else { |
| 1810 | // Otherwise, we had a conflict - probe until we find a unique name. This |
| 1811 | // is guaranteed to terminate (and usually in a single iteration) because it |
| 1812 | // generates new names by incrementing nextConflictID. |
| 1813 | SmallString<64> probeName(name); |
| 1814 | probeName.push_back(Elt: '_'); |
| 1815 | while (true) { |
| 1816 | probeName += llvm::utostr(X: nextConflictID++); |
| 1817 | if (!usedNames.count(Key: probeName)) { |
| 1818 | name = probeName.str().copy(A&: usedNameAllocator); |
| 1819 | break; |
| 1820 | } |
| 1821 | probeName.resize(N: name.size() + 1); |
| 1822 | } |
| 1823 | } |
| 1824 | |
| 1825 | usedNames.insert(Key: name, Val: char()); |
| 1826 | return name; |
| 1827 | } |
| 1828 | |
| 1829 | //===----------------------------------------------------------------------===// |
| 1830 | // DistinctState |
| 1831 | //===----------------------------------------------------------------------===// |
| 1832 | |
| 1833 | namespace { |
| 1834 | /// This class manages the state for distinct attributes. |
| 1835 | class DistinctState { |
| 1836 | public: |
| 1837 | /// Returns a unique identifier for the given distinct attribute. |
| 1838 | uint64_t getId(DistinctAttr distinctAttr); |
| 1839 | |
| 1840 | private: |
| 1841 | uint64_t distinctCounter = 0; |
| 1842 | DenseMap<DistinctAttr, uint64_t> distinctAttrMap; |
| 1843 | }; |
| 1844 | } // namespace |
| 1845 | |
| 1846 | uint64_t DistinctState::getId(DistinctAttr distinctAttr) { |
| 1847 | auto [it, inserted] = |
| 1848 | distinctAttrMap.try_emplace(Key: distinctAttr, Args&: distinctCounter); |
| 1849 | if (inserted) |
| 1850 | distinctCounter++; |
| 1851 | return it->getSecond(); |
| 1852 | } |
| 1853 | |
| 1854 | //===----------------------------------------------------------------------===// |
| 1855 | // Resources |
| 1856 | //===----------------------------------------------------------------------===// |
| 1857 | |
// Out-of-line defaulted destructors for the resource parsing/printing
// interfaces. NOTE(review): presumably defined here to anchor their vtables in
// this translation unit; confirm against the declarations.
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
| 1862 | |
| 1863 | StringRef mlir::toString(AsmResourceEntryKind kind) { |
| 1864 | switch (kind) { |
| 1865 | case AsmResourceEntryKind::Blob: |
| 1866 | return "blob" ; |
| 1867 | case AsmResourceEntryKind::Bool: |
| 1868 | return "bool" ; |
| 1869 | case AsmResourceEntryKind::String: |
| 1870 | return "string" ; |
| 1871 | } |
| 1872 | llvm_unreachable("unknown AsmResourceEntryKind" ); |
| 1873 | } |
| 1874 | |
| 1875 | AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { |
| 1876 | std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()]; |
| 1877 | if (!collection) |
| 1878 | collection = std::make_unique<ResourceCollection>(args&: key); |
| 1879 | return *collection; |
| 1880 | } |
| 1881 | |
| 1882 | std::vector<std::unique_ptr<AsmResourcePrinter>> |
| 1883 | FallbackAsmResourceMap::getPrinters() { |
| 1884 | std::vector<std::unique_ptr<AsmResourcePrinter>> printers; |
| 1885 | for (auto &it : keyToResources) { |
| 1886 | ResourceCollection *collection = it.second.get(); |
| 1887 | auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { |
| 1888 | return collection->buildResources(op, builder); |
| 1889 | }; |
| 1890 | printers.emplace_back( |
| 1891 | args: AsmResourcePrinter::fromCallable(name: collection->getName(), printFn&: buildValues)); |
| 1892 | } |
| 1893 | return printers; |
| 1894 | } |
| 1895 | |
| 1896 | LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( |
| 1897 | AsmParsedResourceEntry &entry) { |
| 1898 | switch (entry.getKind()) { |
| 1899 | case AsmResourceEntryKind::Blob: { |
| 1900 | FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(); |
| 1901 | if (failed(Result: blob)) |
| 1902 | return failure(); |
| 1903 | resources.emplace_back(Args: entry.getKey(), Args: std::move(*blob)); |
| 1904 | return success(); |
| 1905 | } |
| 1906 | case AsmResourceEntryKind::Bool: { |
| 1907 | FailureOr<bool> value = entry.parseAsBool(); |
| 1908 | if (failed(Result: value)) |
| 1909 | return failure(); |
| 1910 | resources.emplace_back(Args: entry.getKey(), Args&: *value); |
| 1911 | break; |
| 1912 | } |
| 1913 | case AsmResourceEntryKind::String: { |
| 1914 | FailureOr<std::string> str = entry.parseAsString(); |
| 1915 | if (failed(Result: str)) |
| 1916 | return failure(); |
| 1917 | resources.emplace_back(Args: entry.getKey(), Args: std::move(*str)); |
| 1918 | break; |
| 1919 | } |
| 1920 | } |
| 1921 | return success(); |
| 1922 | } |
| 1923 | |
| 1924 | void FallbackAsmResourceMap::ResourceCollection::buildResources( |
| 1925 | Operation *op, AsmResourceBuilder &builder) const { |
| 1926 | for (const auto &entry : resources) { |
| 1927 | if (const auto *value = std::get_if<AsmResourceBlob>(ptr: &entry.value)) |
| 1928 | builder.buildBlob(key: entry.key, blob: *value); |
| 1929 | else if (const auto *value = std::get_if<bool>(ptr: &entry.value)) |
| 1930 | builder.buildBool(key: entry.key, data: *value); |
| 1931 | else if (const auto *value = std::get_if<std::string>(ptr: &entry.value)) |
| 1932 | builder.buildString(key: entry.key, data: *value); |
| 1933 | else |
| 1934 | llvm_unreachable("unknown AsmResourceEntryKind" ); |
| 1935 | } |
| 1936 | } |
| 1937 | |
| 1938 | //===----------------------------------------------------------------------===// |
| 1939 | // AsmState |
| 1940 | //===----------------------------------------------------------------------===// |
| 1941 | |
| 1942 | namespace mlir { |
| 1943 | namespace detail { |
| 1944 | class AsmStateImpl { |
| 1945 | public: |
| 1946 | explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, |
| 1947 | AsmState::LocationMap *locationMap) |
| 1948 | : interfaces(op->getContext()), nameState(op, printerFlags), |
| 1949 | printerFlags(printerFlags), locationMap(locationMap) {} |
| 1950 | explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
| 1951 | AsmState::LocationMap *locationMap) |
| 1952 | : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} |
| 1953 | |
| 1954 | /// Initialize the alias state to enable the printing of aliases. |
| 1955 | void initializeAliases(Operation *op) { |
| 1956 | aliasState.initialize(op, printerFlags, interfaces); |
| 1957 | } |
| 1958 | |
| 1959 | /// Get the state used for aliases. |
| 1960 | AliasState &getAliasState() { return aliasState; } |
| 1961 | |
| 1962 | /// Get the state used for SSA names. |
| 1963 | SSANameState &getSSANameState() { return nameState; } |
| 1964 | |
| 1965 | /// Get the state used for distinct attribute identifiers. |
| 1966 | DistinctState &getDistinctState() { return distinctState; } |
| 1967 | |
| 1968 | /// Return the dialects within the context that implement |
| 1969 | /// OpAsmDialectInterface. |
| 1970 | DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() { |
| 1971 | return interfaces; |
| 1972 | } |
| 1973 | |
| 1974 | /// Return the non-dialect resource printers. |
| 1975 | auto getResourcePrinters() { |
| 1976 | return llvm::make_pointee_range(Range&: externalResourcePrinters); |
| 1977 | } |
| 1978 | |
| 1979 | /// Get the printer flags. |
| 1980 | const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } |
| 1981 | |
| 1982 | /// Register the location, line and column, within the buffer that the given |
| 1983 | /// operation was printed at. |
| 1984 | void registerOperationLocation(Operation *op, unsigned line, unsigned col) { |
| 1985 | if (locationMap) |
| 1986 | (*locationMap)[op] = std::make_pair(x&: line, y&: col); |
| 1987 | } |
| 1988 | |
| 1989 | /// Return the referenced dialect resources within the printer. |
| 1990 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
| 1991 | getDialectResources() { |
| 1992 | return dialectResources; |
| 1993 | } |
| 1994 | |
| 1995 | LogicalResult pushCyclicPrinting(const void *opaquePointer) { |
| 1996 | return success(IsSuccess: cyclicPrintingStack.insert(X: opaquePointer)); |
| 1997 | } |
| 1998 | |
| 1999 | void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } |
| 2000 | |
| 2001 | private: |
| 2002 | /// Collection of OpAsm interfaces implemented in the context. |
| 2003 | DialectInterfaceCollection<OpAsmDialectInterface> interfaces; |
| 2004 | |
| 2005 | /// A collection of non-dialect resource printers. |
| 2006 | SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; |
| 2007 | |
| 2008 | /// A set of dialect resources that were referenced during printing. |
| 2009 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources; |
| 2010 | |
| 2011 | /// The state used for attribute and type aliases. |
| 2012 | AliasState aliasState; |
| 2013 | |
| 2014 | /// The state used for SSA value names. |
| 2015 | SSANameState nameState; |
| 2016 | |
| 2017 | /// The state used for distinct attribute identifiers. |
| 2018 | DistinctState distinctState; |
| 2019 | |
| 2020 | /// Flags that control op output. |
| 2021 | OpPrintingFlags printerFlags; |
| 2022 | |
| 2023 | /// An optional location map to be populated. |
| 2024 | AsmState::LocationMap *locationMap; |
| 2025 | |
| 2026 | /// Stack of potentially cyclic mutable attributes or type currently being |
| 2027 | /// printed. |
| 2028 | SetVector<const void *> cyclicPrintingStack; |
| 2029 | |
| 2030 | // Allow direct access to the impl fields. |
| 2031 | friend AsmState; |
| 2032 | }; |
| 2033 | |
| 2034 | template <typename Range> |
| 2035 | void printDimensionList(raw_ostream &stream, Range &&shape) { |
| 2036 | llvm::interleave( |
| 2037 | shape, stream, |
| 2038 | [&stream](const auto &dimSize) { |
| 2039 | if (ShapedType::isDynamic(dValue: dimSize)) |
| 2040 | stream << "?" ; |
| 2041 | else |
| 2042 | stream << dimSize; |
| 2043 | }, |
| 2044 | "x" ); |
| 2045 | } |
| 2046 | |
| 2047 | } // namespace detail |
| 2048 | } // namespace mlir |
| 2049 | |
| 2050 | /// Verifies the operation and switches to generic op printing if verification |
| 2051 | /// fails. We need to do this because custom print functions may fail for |
| 2052 | /// invalid ops. |
| 2053 | static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, |
| 2054 | OpPrintingFlags printerFlags) { |
| 2055 | if (printerFlags.shouldPrintGenericOpForm() || |
| 2056 | printerFlags.shouldAssumeVerified()) |
| 2057 | return printerFlags; |
| 2058 | |
| 2059 | // Ignore errors emitted by the verifier. We check the thread id to avoid |
| 2060 | // consuming other threads' errors. |
| 2061 | auto parentThreadId = llvm::get_threadid(); |
| 2062 | ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { |
| 2063 | if (parentThreadId == llvm::get_threadid()) { |
| 2064 | LLVM_DEBUG({ |
| 2065 | diag.print(llvm::dbgs()); |
| 2066 | llvm::dbgs() << "\n" ; |
| 2067 | }); |
| 2068 | return success(); |
| 2069 | } |
| 2070 | return failure(); |
| 2071 | }); |
| 2072 | if (failed(Result: verify(op))) { |
| 2073 | LLVM_DEBUG(llvm::dbgs() |
| 2074 | << DEBUG_TYPE << ": '" << op->getName() |
| 2075 | << "' failed to verify and will be printed in generic form\n" ); |
| 2076 | printerFlags.printGenericOpForm(); |
| 2077 | } |
| 2078 | |
| 2079 | return printerFlags; |
| 2080 | } |
| 2081 | |
| 2082 | AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, |
| 2083 | LocationMap *locationMap, FallbackAsmResourceMap *map) |
| 2084 | : impl(std::make_unique<AsmStateImpl>( |
| 2085 | args&: op, args: verifyOpAndAdjustFlags(op, printerFlags), args&: locationMap)) { |
| 2086 | if (map) |
| 2087 | attachFallbackResourcePrinter(map&: *map); |
| 2088 | } |
| 2089 | AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
| 2090 | LocationMap *locationMap, FallbackAsmResourceMap *map) |
| 2091 | : impl(std::make_unique<AsmStateImpl>(args&: ctx, args: printerFlags, args&: locationMap)) { |
| 2092 | if (map) |
| 2093 | attachFallbackResourcePrinter(map&: *map); |
| 2094 | } |
| 2095 | AsmState::~AsmState() = default; |
| 2096 | |
| 2097 | const OpPrintingFlags &AsmState::getPrinterFlags() const { |
| 2098 | return impl->getPrinterFlags(); |
| 2099 | } |
| 2100 | |
| 2101 | void AsmState::attachResourcePrinter( |
| 2102 | std::unique_ptr<AsmResourcePrinter> printer) { |
| 2103 | impl->externalResourcePrinters.emplace_back(Args: std::move(printer)); |
| 2104 | } |
| 2105 | |
| 2106 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
| 2107 | AsmState::getDialectResources() const { |
| 2108 | return impl->getDialectResources(); |
| 2109 | } |
| 2110 | |
| 2111 | //===----------------------------------------------------------------------===// |
| 2112 | // AsmPrinter::Impl |
| 2113 | //===----------------------------------------------------------------------===// |
| 2114 | |
/// Binds the printer to output stream `os` and shared printing `state`;
/// `printerFlags` is initialized from the state's flags.
AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
    : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
| 2117 | |
| 2118 | void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { |
| 2119 | // Check to see if we are printing debug information. |
| 2120 | if (!printerFlags.shouldPrintDebugInfo()) |
| 2121 | return; |
| 2122 | |
| 2123 | os << " " ; |
| 2124 | printLocation(loc, /*allowAlias=*/allowAlias); |
| 2125 | } |
| 2126 | |
| 2127 | void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, |
| 2128 | bool isTopLevel) { |
| 2129 | // If this isn't a top-level location, check for an alias. |
| 2130 | if (!isTopLevel && succeeded(Result: state.getAliasState().getAlias(attr: loc, os))) |
| 2131 | return; |
| 2132 | |
| 2133 | TypeSwitch<LocationAttr>(loc) |
| 2134 | .Case<OpaqueLoc>(caseFn: [&](OpaqueLoc loc) { |
| 2135 | printLocationInternal(loc: loc.getFallbackLocation(), pretty); |
| 2136 | }) |
| 2137 | .Case<UnknownLoc>(caseFn: [&](UnknownLoc loc) { |
| 2138 | if (pretty) |
| 2139 | os << "[unknown]" ; |
| 2140 | else |
| 2141 | os << "unknown" ; |
| 2142 | }) |
| 2143 | .Case<FileLineColRange>(caseFn: [&](FileLineColRange loc) { |
| 2144 | if (pretty) |
| 2145 | os << loc.getFilename().getValue(); |
| 2146 | else |
| 2147 | printEscapedString(str: loc.getFilename()); |
| 2148 | if (loc.getEndColumn() == loc.getStartColumn() && |
| 2149 | loc.getStartLine() == loc.getEndLine()) { |
| 2150 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn(); |
| 2151 | return; |
| 2152 | } |
| 2153 | if (loc.getStartLine() == loc.getEndLine()) { |
| 2154 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() |
| 2155 | << " to :" << loc.getEndColumn(); |
| 2156 | return; |
| 2157 | } |
| 2158 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to " |
| 2159 | << loc.getEndLine() << ':' << loc.getEndColumn(); |
| 2160 | }) |
| 2161 | .Case<NameLoc>(caseFn: [&](NameLoc loc) { |
| 2162 | printEscapedString(str: loc.getName()); |
| 2163 | |
| 2164 | // Print the child if it isn't unknown. |
| 2165 | auto childLoc = loc.getChildLoc(); |
| 2166 | if (!llvm::isa<UnknownLoc>(Val: childLoc)) { |
| 2167 | os << '('; |
| 2168 | printLocationInternal(loc: childLoc, pretty); |
| 2169 | os << ')'; |
| 2170 | } |
| 2171 | }) |
| 2172 | .Case<CallSiteLoc>(caseFn: [&](CallSiteLoc loc) { |
| 2173 | Location caller = loc.getCaller(); |
| 2174 | Location callee = loc.getCallee(); |
| 2175 | if (!pretty) |
| 2176 | os << "callsite(" ; |
| 2177 | printLocationInternal(loc: callee, pretty); |
| 2178 | if (pretty) { |
| 2179 | if (llvm::isa<NameLoc>(Val: callee)) { |
| 2180 | if (llvm::isa<FileLineColLoc>(Val: caller)) { |
| 2181 | os << " at " ; |
| 2182 | } else { |
| 2183 | os << newLine << " at " ; |
| 2184 | } |
| 2185 | } else { |
| 2186 | os << newLine << " at " ; |
| 2187 | } |
| 2188 | } else { |
| 2189 | os << " at " ; |
| 2190 | } |
| 2191 | printLocationInternal(loc: caller, pretty); |
| 2192 | if (!pretty) |
| 2193 | os << ")" ; |
| 2194 | }) |
| 2195 | .Case<FusedLoc>(caseFn: [&](FusedLoc loc) { |
| 2196 | if (!pretty) |
| 2197 | os << "fused" ; |
| 2198 | if (Attribute metadata = loc.getMetadata()) { |
| 2199 | os << '<'; |
| 2200 | printAttribute(attr: metadata); |
| 2201 | os << '>'; |
| 2202 | } |
| 2203 | os << '['; |
| 2204 | interleave( |
| 2205 | c: loc.getLocations(), |
| 2206 | each_fn: [&](Location loc) { printLocationInternal(loc, pretty); }, |
| 2207 | between_fn: [&]() { os << ", " ; }); |
| 2208 | os << ']'; |
| 2209 | }) |
| 2210 | .Default(defaultFn: [&](LocationAttr loc) { |
| 2211 | // Assumes that this is a dialect-specific attribute and prints it |
| 2212 | // directly. |
| 2213 | printAttribute(attr: loc); |
| 2214 | }); |
| 2215 | } |
| 2216 | |
| 2217 | /// Print a floating point value in a way that the parser will be able to |
| 2218 | /// round-trip losslessly. |
| 2219 | static void printFloatValue(const APFloat &apValue, raw_ostream &os, |
| 2220 | bool *printedHex = nullptr) { |
| 2221 | // We would like to output the FP constant value in exponential notation, |
| 2222 | // but we cannot do this if doing so will lose precision. Check here to |
| 2223 | // make sure that we only output it in exponential format if we can parse |
| 2224 | // the value back and get the same value. |
| 2225 | bool isInf = apValue.isInfinity(); |
| 2226 | bool isNaN = apValue.isNaN(); |
| 2227 | if (!isInf && !isNaN) { |
| 2228 | SmallString<128> strValue; |
| 2229 | apValue.toString(Str&: strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, |
| 2230 | /*TruncateZero=*/false); |
| 2231 | |
| 2232 | // Check to make sure that the stringized number is not some string like |
| 2233 | // "Inf" or NaN, that atof will accept, but the lexer will not. Check |
| 2234 | // that the string matches the "[-+]?[0-9]" regex. |
| 2235 | assert(((strValue[0] >= '0' && strValue[0] <= '9') || |
| 2236 | ((strValue[0] == '-' || strValue[0] == '+') && |
| 2237 | (strValue[1] >= '0' && strValue[1] <= '9'))) && |
| 2238 | "[-+]?[0-9] regex does not match!" ); |
| 2239 | |
| 2240 | // Parse back the stringized version and check that the value is equal |
| 2241 | // (i.e., there is no precision loss). |
| 2242 | if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(RHS: apValue)) { |
| 2243 | os << strValue; |
| 2244 | return; |
| 2245 | } |
| 2246 | |
| 2247 | // If it is not, use the default format of APFloat instead of the |
| 2248 | // exponential notation. |
| 2249 | strValue.clear(); |
| 2250 | apValue.toString(Str&: strValue); |
| 2251 | |
| 2252 | // Make sure that we can parse the default form as a float. |
| 2253 | if (strValue.str().contains(C: '.')) { |
| 2254 | os << strValue; |
| 2255 | return; |
| 2256 | } |
| 2257 | } |
| 2258 | |
| 2259 | // Print special values in hexadecimal format. The sign bit should be included |
| 2260 | // in the literal. |
| 2261 | if (printedHex) |
| 2262 | *printedHex = true; |
| 2263 | SmallVector<char, 16> str; |
| 2264 | APInt apInt = apValue.bitcastToAPInt(); |
| 2265 | apInt.toString(Str&: str, /*Radix=*/16, /*Signed=*/false, |
| 2266 | /*formatAsCLiteral=*/true); |
| 2267 | os << str; |
| 2268 | } |
| 2269 | |
| 2270 | void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { |
| 2271 | if (printerFlags.shouldPrintDebugInfoPrettyForm()) |
| 2272 | return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true); |
| 2273 | |
| 2274 | os << "loc(" ; |
| 2275 | if (!allowAlias || failed(Result: printAlias(attr: loc))) |
| 2276 | printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true); |
| 2277 | os << ')'; |
| 2278 | } |
| 2279 | |
| 2280 | /// Returns true if the given dialect symbol data is simple enough to print in |
| 2281 | /// the pretty form. This is essentially when the symbol takes the form: |
| 2282 | /// identifier (`<` body `>`)? |
| 2283 | static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { |
| 2284 | // The name must start with an identifier. |
| 2285 | if (symName.empty() || !isalpha(symName.front())) |
| 2286 | return false; |
| 2287 | |
| 2288 | // Ignore all the characters that are valid in an identifier in the symbol |
| 2289 | // name. |
| 2290 | symName = symName.drop_while( |
| 2291 | F: [](char c) { return llvm::isAlnum(C: c) || c == '.' || c == '_'; }); |
| 2292 | if (symName.empty()) |
| 2293 | return true; |
| 2294 | |
| 2295 | // If we got to an unexpected character, then it must be a <>. Check that the |
| 2296 | // rest of the symbol is wrapped within <>. |
| 2297 | return symName.front() == '<' && symName.back() == '>'; |
| 2298 | } |
| 2299 | |
| 2300 | /// Print the given dialect symbol to the stream. |
| 2301 | static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, |
| 2302 | StringRef dialectName, StringRef symString) { |
| 2303 | os << symPrefix << dialectName; |
| 2304 | |
| 2305 | // If this symbol name is simple enough, print it directly in pretty form, |
| 2306 | // otherwise, we print it as an escaped string. |
| 2307 | if (isDialectSymbolSimpleEnoughForPrettyForm(symName: symString)) { |
| 2308 | os << '.' << symString; |
| 2309 | return; |
| 2310 | } |
| 2311 | |
| 2312 | os << '<' << symString << '>'; |
| 2313 | } |
| 2314 | |
| 2315 | /// Returns true if the given string can be represented as a bare identifier. |
| 2316 | static bool isBareIdentifier(StringRef name) { |
| 2317 | // By making this unsigned, the value passed in to isalnum will always be |
| 2318 | // in the range 0-255. This is important when building with MSVC because |
| 2319 | // its implementation will assert. This situation can arise when dealing |
| 2320 | // with UTF-8 multibyte characters. |
| 2321 | if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) |
| 2322 | return false; |
| 2323 | return llvm::all_of(Range: name.drop_front(), P: [](unsigned char c) { |
| 2324 | return isalnum(c) || c == '_' || c == '$' || c == '.'; |
| 2325 | }); |
| 2326 | } |
| 2327 | |
| 2328 | /// Print the given string as a keyword, or a quoted and escaped string if it |
| 2329 | /// has any special or non-printable characters in it. |
| 2330 | static void printKeywordOrString(StringRef keyword, raw_ostream &os) { |
| 2331 | // If it can be represented as a bare identifier, write it directly. |
| 2332 | if (isBareIdentifier(name: keyword)) { |
| 2333 | os << keyword; |
| 2334 | return; |
| 2335 | } |
| 2336 | |
| 2337 | // Otherwise, output the keyword wrapped in quotes with proper escaping. |
| 2338 | os << "\"" ; |
| 2339 | printEscapedString(Name: keyword, Out&: os); |
| 2340 | os << '"'; |
| 2341 | } |
| 2342 | |
| 2343 | /// Print the given string as a symbol reference. A symbol reference is |
| 2344 | /// represented as a string prefixed with '@'. The reference is surrounded with |
| 2345 | /// ""'s and escaped if it has any special or non-printable characters in it. |
| 2346 | static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { |
| 2347 | if (symbolRef.empty()) { |
| 2348 | os << "@<<INVALID EMPTY SYMBOL>>" ; |
| 2349 | return; |
| 2350 | } |
| 2351 | os << '@'; |
| 2352 | printKeywordOrString(keyword: symbolRef, os); |
| 2353 | } |
| 2354 | |
| 2355 | // Print out a valid ElementsAttr that is succinct and can represent any |
| 2356 | // potential shape/type, for use when eliding a large ElementsAttr. |
| 2357 | // |
| 2358 | // We choose to use a dense resource ElementsAttr literal with conspicuous |
| 2359 | // content to hopefully alert readers to the fact that this has been elided. |
| 2360 | static void printElidedElementsAttr(raw_ostream &os) { |
| 2361 | os << R"(dense_resource<__elided__>)" ; |
| 2362 | } |
| 2363 | |
| 2364 | void AsmPrinter::Impl::printResourceHandle( |
| 2365 | const AsmDialectResourceHandle &resource) { |
| 2366 | auto *interface = cast<OpAsmDialectInterface>(Val: resource.getDialect()); |
| 2367 | ::printKeywordOrString(keyword: interface->getResourceKey(handle: resource), os); |
| 2368 | state.getDialectResources()[resource.getDialect()].insert(X: resource); |
| 2369 | } |
| 2370 | |
| 2371 | LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { |
| 2372 | return state.getAliasState().getAlias(attr, os); |
| 2373 | } |
| 2374 | |
| 2375 | LogicalResult AsmPrinter::Impl::printAlias(Type type) { |
| 2376 | return state.getAliasState().getAlias(ty: type, os); |
| 2377 | } |
| 2378 | |
| 2379 | void AsmPrinter::Impl::printAttribute(Attribute attr, |
| 2380 | AttrTypeElision typeElision) { |
| 2381 | if (!attr) { |
| 2382 | os << "<<NULL ATTRIBUTE>>" ; |
| 2383 | return; |
| 2384 | } |
| 2385 | |
| 2386 | // Try to print an alias for this attribute. |
| 2387 | if (succeeded(Result: printAlias(attr))) |
| 2388 | return; |
| 2389 | return printAttributeImpl(attr, typeElision); |
| 2390 | } |
| 2391 | void AsmPrinter::Impl::printAttributeImpl(Attribute attr, |
| 2392 | AttrTypeElision typeElision) { |
| 2393 | if (!isa<BuiltinDialect>(Val: attr.getDialect())) { |
| 2394 | printDialectAttribute(attr); |
| 2395 | } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(Val&: attr)) { |
| 2396 | printDialectSymbol(os, symPrefix: "#" , dialectName: opaqueAttr.getDialectNamespace(), |
| 2397 | symString: opaqueAttr.getAttrData()); |
| 2398 | } else if (llvm::isa<UnitAttr>(Val: attr)) { |
| 2399 | os << "unit" ; |
| 2400 | return; |
| 2401 | } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(Val&: attr)) { |
| 2402 | os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<" ; |
| 2403 | if (!llvm::isa<UnitAttr>(Val: distinctAttr.getReferencedAttr())) { |
| 2404 | printAttribute(attr: distinctAttr.getReferencedAttr()); |
| 2405 | } |
| 2406 | os << '>'; |
| 2407 | return; |
| 2408 | } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(Val&: attr)) { |
| 2409 | os << '{'; |
| 2410 | interleaveComma(c: dictAttr.getValue(), |
| 2411 | eachFn: [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
| 2412 | os << '}'; |
| 2413 | |
| 2414 | } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(Val&: attr)) { |
| 2415 | Type intType = intAttr.getType(); |
| 2416 | if (intType.isSignlessInteger(width: 1)) { |
| 2417 | os << (intAttr.getValue().getBoolValue() ? "true" : "false" ); |
| 2418 | |
| 2419 | // Boolean integer attributes always elides the type. |
| 2420 | return; |
| 2421 | } |
| 2422 | |
| 2423 | // Only print attributes as unsigned if they are explicitly unsigned or are |
| 2424 | // signless 1-bit values. Indexes, signed values, and multi-bit signless |
| 2425 | // values print as signed. |
| 2426 | bool isUnsigned = |
| 2427 | intType.isUnsignedInteger() || intType.isSignlessInteger(width: 1); |
| 2428 | intAttr.getValue().print(OS&: os, isSigned: !isUnsigned); |
| 2429 | |
| 2430 | // IntegerAttr elides the type if I64. |
| 2431 | if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(width: 64)) |
| 2432 | return; |
| 2433 | |
| 2434 | } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(Val&: attr)) { |
| 2435 | bool printedHex = false; |
| 2436 | printFloatValue(apValue: floatAttr.getValue(), os, printedHex: &printedHex); |
| 2437 | |
| 2438 | // FloatAttr elides the type if F64. |
| 2439 | if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() && |
| 2440 | !printedHex) |
| 2441 | return; |
| 2442 | |
| 2443 | } else if (auto strAttr = llvm::dyn_cast<StringAttr>(Val&: attr)) { |
| 2444 | printEscapedString(str: strAttr.getValue()); |
| 2445 | |
| 2446 | } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(Val&: attr)) { |
| 2447 | os << '['; |
| 2448 | interleaveComma(c: arrayAttr.getValue(), eachFn: [&](Attribute attr) { |
| 2449 | printAttribute(attr, typeElision: AttrTypeElision::May); |
| 2450 | }); |
| 2451 | os << ']'; |
| 2452 | |
| 2453 | } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(Val&: attr)) { |
| 2454 | os << "affine_map<" ; |
| 2455 | affineMapAttr.getValue().print(os); |
| 2456 | os << '>'; |
| 2457 | |
| 2458 | // AffineMap always elides the type. |
| 2459 | return; |
| 2460 | |
| 2461 | } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(Val&: attr)) { |
| 2462 | os << "affine_set<" ; |
| 2463 | integerSetAttr.getValue().print(os); |
| 2464 | os << '>'; |
| 2465 | |
| 2466 | // IntegerSet always elides the type. |
| 2467 | return; |
| 2468 | |
| 2469 | } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(Val&: attr)) { |
| 2470 | printType(type: typeAttr.getValue()); |
| 2471 | |
| 2472 | } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(Val&: attr)) { |
| 2473 | printSymbolReference(symbolRef: refAttr.getRootReference().getValue(), os); |
| 2474 | for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { |
| 2475 | os << "::" ; |
| 2476 | printSymbolReference(symbolRef: nestedRef.getValue(), os); |
| 2477 | } |
| 2478 | |
| 2479 | } else if (auto intOrFpEltAttr = |
| 2480 | llvm::dyn_cast<DenseIntOrFPElementsAttr>(Val&: attr)) { |
| 2481 | if (printerFlags.shouldElideElementsAttr(attr: intOrFpEltAttr)) { |
| 2482 | printElidedElementsAttr(os); |
| 2483 | } else { |
| 2484 | os << "dense<" ; |
| 2485 | printDenseIntOrFPElementsAttr(attr: intOrFpEltAttr, /*allowHex=*/true); |
| 2486 | os << '>'; |
| 2487 | } |
| 2488 | |
| 2489 | } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(Val&: attr)) { |
| 2490 | if (printerFlags.shouldElideElementsAttr(attr: strEltAttr)) { |
| 2491 | printElidedElementsAttr(os); |
| 2492 | } else { |
| 2493 | os << "dense<" ; |
| 2494 | printDenseStringElementsAttr(attr: strEltAttr); |
| 2495 | os << '>'; |
| 2496 | } |
| 2497 | |
| 2498 | } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(Val&: attr)) { |
| 2499 | if (printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getIndices()) || |
| 2500 | printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getValues())) { |
| 2501 | printElidedElementsAttr(os); |
| 2502 | } else { |
| 2503 | os << "sparse<" ; |
| 2504 | DenseIntElementsAttr indices = sparseEltAttr.getIndices(); |
| 2505 | if (indices.getNumElements() != 0) { |
| 2506 | printDenseIntOrFPElementsAttr(attr: indices, /*allowHex=*/false); |
| 2507 | os << ", " ; |
| 2508 | printDenseElementsAttr(attr: sparseEltAttr.getValues(), /*allowHex=*/true); |
| 2509 | } |
| 2510 | os << '>'; |
| 2511 | } |
| 2512 | } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(Val&: attr)) { |
| 2513 | stridedLayoutAttr.print(os); |
| 2514 | } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(Val&: attr)) { |
| 2515 | os << "array<" ; |
| 2516 | printType(type: denseArrayAttr.getElementType()); |
| 2517 | if (!denseArrayAttr.empty()) { |
| 2518 | os << ": " ; |
| 2519 | printDenseArrayAttr(attr: denseArrayAttr); |
| 2520 | } |
| 2521 | os << ">" ; |
| 2522 | return; |
| 2523 | } else if (auto resourceAttr = |
| 2524 | llvm::dyn_cast<DenseResourceElementsAttr>(Val&: attr)) { |
| 2525 | os << "dense_resource<" ; |
| 2526 | printResourceHandle(resource: resourceAttr.getRawHandle()); |
| 2527 | os << ">" ; |
| 2528 | } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(Val&: attr)) { |
| 2529 | printLocation(loc: locAttr); |
| 2530 | } else { |
| 2531 | llvm::report_fatal_error(reason: "Unknown builtin attribute" ); |
| 2532 | } |
| 2533 | // Don't print the type if we must elide it, or if it is a None type. |
| 2534 | if (typeElision != AttrTypeElision::Must) { |
| 2535 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(Val&: attr)) { |
| 2536 | Type attrType = typedAttr.getType(); |
| 2537 | if (!llvm::isa<NoneType>(Val: attrType)) { |
| 2538 | os << " : " ; |
| 2539 | printType(type: attrType); |
| 2540 | } |
| 2541 | } |
| 2542 | } |
| 2543 | } |
| 2544 | |
| 2545 | /// Print the integer element of a DenseElementsAttr. |
| 2546 | static void printDenseIntElement(const APInt &value, raw_ostream &os, |
| 2547 | Type type) { |
| 2548 | if (type.isInteger(width: 1)) |
| 2549 | os << (value.getBoolValue() ? "true" : "false" ); |
| 2550 | else |
| 2551 | value.print(OS&: os, isSigned: !type.isUnsignedInteger()); |
| 2552 | } |
| 2553 | |
| 2554 | static void |
| 2555 | printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, |
| 2556 | function_ref<void(unsigned)> printEltFn) { |
| 2557 | // Special case for 0-d and splat tensors. |
| 2558 | if (isSplat) |
| 2559 | return printEltFn(0); |
| 2560 | |
| 2561 | // Special case for degenerate tensors. |
| 2562 | auto numElements = type.getNumElements(); |
| 2563 | if (numElements == 0) |
| 2564 | return; |
| 2565 | |
| 2566 | // We use a mixed-radix counter to iterate through the shape. When we bump a |
| 2567 | // non-least-significant digit, we emit a close bracket. When we next emit an |
| 2568 | // element we re-open all closed brackets. |
| 2569 | |
| 2570 | // The mixed-radix counter, with radices in 'shape'. |
| 2571 | int64_t rank = type.getRank(); |
| 2572 | SmallVector<unsigned, 4> counter(rank, 0); |
| 2573 | // The number of brackets that have been opened and not closed. |
| 2574 | unsigned openBrackets = 0; |
| 2575 | |
| 2576 | auto shape = type.getShape(); |
| 2577 | auto bumpCounter = [&] { |
| 2578 | // Bump the least significant digit. |
| 2579 | ++counter[rank - 1]; |
| 2580 | // Iterate backwards bubbling back the increment. |
| 2581 | for (unsigned i = rank - 1; i > 0; --i) |
| 2582 | if (counter[i] >= shape[i]) { |
| 2583 | // Index 'i' is rolled over. Bump (i-1) and close a bracket. |
| 2584 | counter[i] = 0; |
| 2585 | ++counter[i - 1]; |
| 2586 | --openBrackets; |
| 2587 | os << ']'; |
| 2588 | } |
| 2589 | }; |
| 2590 | |
| 2591 | for (unsigned idx = 0, e = numElements; idx != e; ++idx) { |
| 2592 | if (idx != 0) |
| 2593 | os << ", " ; |
| 2594 | while (openBrackets++ < rank) |
| 2595 | os << '['; |
| 2596 | openBrackets = rank; |
| 2597 | printEltFn(idx); |
| 2598 | bumpCounter(); |
| 2599 | } |
| 2600 | while (openBrackets-- > 0) |
| 2601 | os << ']'; |
| 2602 | } |
| 2603 | |
| 2604 | void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, |
| 2605 | bool allowHex) { |
| 2606 | if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(Val&: attr)) |
| 2607 | return printDenseStringElementsAttr(attr: stringAttr); |
| 2608 | |
| 2609 | printDenseIntOrFPElementsAttr(attr: llvm::cast<DenseIntOrFPElementsAttr>(Val&: attr), |
| 2610 | allowHex); |
| 2611 | } |
| 2612 | |
| 2613 | void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( |
| 2614 | DenseIntOrFPElementsAttr attr, bool allowHex) { |
| 2615 | auto type = attr.getType(); |
| 2616 | auto elementType = type.getElementType(); |
| 2617 | |
| 2618 | // Check to see if we should format this attribute as a hex string. |
| 2619 | if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) { |
| 2620 | ArrayRef<char> rawData = attr.getRawData(); |
| 2621 | if (llvm::endianness::native == llvm::endianness::big) { |
| 2622 | // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE |
| 2623 | // machines. It is converted here to print in LE format. |
| 2624 | SmallVector<char, 64> outDataVec(rawData.size()); |
| 2625 | MutableArrayRef<char> convRawData(outDataVec); |
| 2626 | DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
| 2627 | inRawData: rawData, outRawData: convRawData, type); |
| 2628 | printHexString(data: convRawData); |
| 2629 | } else { |
| 2630 | printHexString(data: rawData); |
| 2631 | } |
| 2632 | |
| 2633 | return; |
| 2634 | } |
| 2635 | |
| 2636 | if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(Val&: elementType)) { |
| 2637 | Type complexElementType = complexTy.getElementType(); |
| 2638 | // Note: The if and else below had a common lambda function which invoked |
| 2639 | // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 |
| 2640 | // and hence was replaced. |
| 2641 | if (llvm::isa<IntegerType>(Val: complexElementType)) { |
| 2642 | auto valueIt = attr.value_begin<std::complex<APInt>>(); |
| 2643 | printDenseElementsAttrImpl(isSplat: attr.isSplat(), type, os, printEltFn: [&](unsigned index) { |
| 2644 | auto complexValue = *(valueIt + index); |
| 2645 | os << "(" ; |
| 2646 | printDenseIntElement(value: complexValue.real(), os, type: complexElementType); |
| 2647 | os << "," ; |
| 2648 | printDenseIntElement(value: complexValue.imag(), os, type: complexElementType); |
| 2649 | os << ")" ; |
| 2650 | }); |
| 2651 | } else { |
| 2652 | auto valueIt = attr.value_begin<std::complex<APFloat>>(); |
| 2653 | printDenseElementsAttrImpl(isSplat: attr.isSplat(), type, os, printEltFn: [&](unsigned index) { |
| 2654 | auto complexValue = *(valueIt + index); |
| 2655 | os << "(" ; |
| 2656 | printFloatValue(apValue: complexValue.real(), os); |
| 2657 | os << "," ; |
| 2658 | printFloatValue(apValue: complexValue.imag(), os); |
| 2659 | os << ")" ; |
| 2660 | }); |
| 2661 | } |
| 2662 | } else if (elementType.isIntOrIndex()) { |
| 2663 | auto valueIt = attr.value_begin<APInt>(); |
| 2664 | printDenseElementsAttrImpl(isSplat: attr.isSplat(), type, os, printEltFn: [&](unsigned index) { |
| 2665 | printDenseIntElement(value: *(valueIt + index), os, type: elementType); |
| 2666 | }); |
| 2667 | } else { |
| 2668 | assert(llvm::isa<FloatType>(elementType) && "unexpected element type" ); |
| 2669 | auto valueIt = attr.value_begin<APFloat>(); |
| 2670 | printDenseElementsAttrImpl(isSplat: attr.isSplat(), type, os, printEltFn: [&](unsigned index) { |
| 2671 | printFloatValue(apValue: *(valueIt + index), os); |
| 2672 | }); |
| 2673 | } |
| 2674 | } |
| 2675 | |
| 2676 | void AsmPrinter::Impl::printDenseStringElementsAttr( |
| 2677 | DenseStringElementsAttr attr) { |
| 2678 | ArrayRef<StringRef> data = attr.getRawStringData(); |
| 2679 | auto printFn = [&](unsigned index) { printEscapedString(str: data[index]); }; |
| 2680 | printDenseElementsAttrImpl(isSplat: attr.isSplat(), type: attr.getType(), os, printEltFn: printFn); |
| 2681 | } |
| 2682 | |
| 2683 | void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { |
| 2684 | Type type = attr.getElementType(); |
| 2685 | unsigned bitwidth = type.isInteger(width: 1) ? 8 : type.getIntOrFloatBitWidth(); |
| 2686 | unsigned byteSize = bitwidth / 8; |
| 2687 | ArrayRef<char> data = attr.getRawData(); |
| 2688 | |
| 2689 | auto printElementAt = [&](unsigned i) { |
| 2690 | APInt value(bitwidth, 0); |
| 2691 | if (bitwidth) { |
| 2692 | llvm::LoadIntFromMemory( |
| 2693 | IntVal&: value, Src: reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i), |
| 2694 | LoadBytes: byteSize); |
| 2695 | } |
| 2696 | // Print the data as-is or as a float. |
| 2697 | if (type.isIntOrIndex()) { |
| 2698 | printDenseIntElement(value, os&: getStream(), type); |
| 2699 | } else { |
| 2700 | APFloat fltVal(llvm::cast<FloatType>(Val&: type).getFloatSemantics(), value); |
| 2701 | printFloatValue(apValue: fltVal, os&: getStream()); |
| 2702 | } |
| 2703 | }; |
| 2704 | llvm::interleaveComma(c: llvm::seq<unsigned>(Begin: 0, End: attr.size()), os&: getStream(), |
| 2705 | each_fn: printElementAt); |
| 2706 | } |
| 2707 | |
| 2708 | void AsmPrinter::Impl::printType(Type type) { |
| 2709 | if (!type) { |
| 2710 | os << "<<NULL TYPE>>" ; |
| 2711 | return; |
| 2712 | } |
| 2713 | |
| 2714 | // Try to print an alias for this type. |
| 2715 | if (succeeded(Result: printAlias(type))) |
| 2716 | return; |
| 2717 | return printTypeImpl(type); |
| 2718 | } |
| 2719 | |
| 2720 | void AsmPrinter::Impl::printTypeImpl(Type type) { |
| 2721 | TypeSwitch<Type>(type) |
| 2722 | .Case<OpaqueType>(caseFn: [&](OpaqueType opaqueTy) { |
| 2723 | printDialectSymbol(os, symPrefix: "!" , dialectName: opaqueTy.getDialectNamespace(), |
| 2724 | symString: opaqueTy.getTypeData()); |
| 2725 | }) |
| 2726 | .Case<IndexType>(caseFn: [&](Type) { os << "index" ; }) |
| 2727 | .Case<Float4E2M1FNType>(caseFn: [&](Type) { os << "f4E2M1FN" ; }) |
| 2728 | .Case<Float6E2M3FNType>(caseFn: [&](Type) { os << "f6E2M3FN" ; }) |
| 2729 | .Case<Float6E3M2FNType>(caseFn: [&](Type) { os << "f6E3M2FN" ; }) |
| 2730 | .Case<Float8E5M2Type>(caseFn: [&](Type) { os << "f8E5M2" ; }) |
| 2731 | .Case<Float8E4M3Type>(caseFn: [&](Type) { os << "f8E4M3" ; }) |
| 2732 | .Case<Float8E4M3FNType>(caseFn: [&](Type) { os << "f8E4M3FN" ; }) |
| 2733 | .Case<Float8E5M2FNUZType>(caseFn: [&](Type) { os << "f8E5M2FNUZ" ; }) |
| 2734 | .Case<Float8E4M3FNUZType>(caseFn: [&](Type) { os << "f8E4M3FNUZ" ; }) |
| 2735 | .Case<Float8E4M3B11FNUZType>(caseFn: [&](Type) { os << "f8E4M3B11FNUZ" ; }) |
| 2736 | .Case<Float8E3M4Type>(caseFn: [&](Type) { os << "f8E3M4" ; }) |
| 2737 | .Case<Float8E8M0FNUType>(caseFn: [&](Type) { os << "f8E8M0FNU" ; }) |
| 2738 | .Case<BFloat16Type>(caseFn: [&](Type) { os << "bf16" ; }) |
| 2739 | .Case<Float16Type>(caseFn: [&](Type) { os << "f16" ; }) |
| 2740 | .Case<FloatTF32Type>(caseFn: [&](Type) { os << "tf32" ; }) |
| 2741 | .Case<Float32Type>(caseFn: [&](Type) { os << "f32" ; }) |
| 2742 | .Case<Float64Type>(caseFn: [&](Type) { os << "f64" ; }) |
| 2743 | .Case<Float80Type>(caseFn: [&](Type) { os << "f80" ; }) |
| 2744 | .Case<Float128Type>(caseFn: [&](Type) { os << "f128" ; }) |
| 2745 | .Case<IntegerType>(caseFn: [&](IntegerType integerTy) { |
| 2746 | if (integerTy.isSigned()) |
| 2747 | os << 's'; |
| 2748 | else if (integerTy.isUnsigned()) |
| 2749 | os << 'u'; |
| 2750 | os << 'i' << integerTy.getWidth(); |
| 2751 | }) |
| 2752 | .Case<FunctionType>(caseFn: [&](FunctionType funcTy) { |
| 2753 | os << '('; |
| 2754 | interleaveComma(c: funcTy.getInputs(), eachFn: [&](Type ty) { printType(type: ty); }); |
| 2755 | os << ") -> " ; |
| 2756 | ArrayRef<Type> results = funcTy.getResults(); |
| 2757 | if (results.size() == 1 && !llvm::isa<FunctionType>(Val: results[0])) { |
| 2758 | printType(type: results[0]); |
| 2759 | } else { |
| 2760 | os << '('; |
| 2761 | interleaveComma(c: results, eachFn: [&](Type ty) { printType(type: ty); }); |
| 2762 | os << ')'; |
| 2763 | } |
| 2764 | }) |
| 2765 | .Case<VectorType>(caseFn: [&](VectorType vectorTy) { |
| 2766 | auto scalableDims = vectorTy.getScalableDims(); |
| 2767 | os << "vector<" ; |
| 2768 | auto vShape = vectorTy.getShape(); |
| 2769 | unsigned lastDim = vShape.size(); |
| 2770 | unsigned dimIdx = 0; |
| 2771 | for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { |
| 2772 | if (!scalableDims.empty() && scalableDims[dimIdx]) |
| 2773 | os << '['; |
| 2774 | os << vShape[dimIdx]; |
| 2775 | if (!scalableDims.empty() && scalableDims[dimIdx]) |
| 2776 | os << ']'; |
| 2777 | os << 'x'; |
| 2778 | } |
| 2779 | printType(type: vectorTy.getElementType()); |
| 2780 | os << '>'; |
| 2781 | }) |
| 2782 | .Case<RankedTensorType>(caseFn: [&](RankedTensorType tensorTy) { |
| 2783 | os << "tensor<" ; |
| 2784 | printDimensionList(shape: tensorTy.getShape()); |
| 2785 | if (!tensorTy.getShape().empty()) |
| 2786 | os << 'x'; |
| 2787 | printType(type: tensorTy.getElementType()); |
| 2788 | // Only print the encoding attribute value if set. |
| 2789 | if (tensorTy.getEncoding()) { |
| 2790 | os << ", " ; |
| 2791 | printAttribute(attr: tensorTy.getEncoding()); |
| 2792 | } |
| 2793 | os << '>'; |
| 2794 | }) |
| 2795 | .Case<UnrankedTensorType>(caseFn: [&](UnrankedTensorType tensorTy) { |
| 2796 | os << "tensor<*x" ; |
| 2797 | printType(type: tensorTy.getElementType()); |
| 2798 | os << '>'; |
| 2799 | }) |
| 2800 | .Case<MemRefType>(caseFn: [&](MemRefType memrefTy) { |
| 2801 | os << "memref<" ; |
| 2802 | printDimensionList(shape: memrefTy.getShape()); |
| 2803 | if (!memrefTy.getShape().empty()) |
| 2804 | os << 'x'; |
| 2805 | printType(type: memrefTy.getElementType()); |
| 2806 | MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
| 2807 | if (!llvm::isa<AffineMapAttr>(Val: layout) || !layout.isIdentity()) { |
| 2808 | os << ", " ; |
| 2809 | printAttribute(attr: memrefTy.getLayout(), typeElision: AttrTypeElision::May); |
| 2810 | } |
| 2811 | // Only print the memory space if it is the non-default one. |
| 2812 | if (memrefTy.getMemorySpace()) { |
| 2813 | os << ", " ; |
| 2814 | printAttribute(attr: memrefTy.getMemorySpace(), typeElision: AttrTypeElision::May); |
| 2815 | } |
| 2816 | os << '>'; |
| 2817 | }) |
| 2818 | .Case<UnrankedMemRefType>(caseFn: [&](UnrankedMemRefType memrefTy) { |
| 2819 | os << "memref<*x" ; |
| 2820 | printType(type: memrefTy.getElementType()); |
| 2821 | // Only print the memory space if it is the non-default one. |
| 2822 | if (memrefTy.getMemorySpace()) { |
| 2823 | os << ", " ; |
| 2824 | printAttribute(attr: memrefTy.getMemorySpace(), typeElision: AttrTypeElision::May); |
| 2825 | } |
| 2826 | os << '>'; |
| 2827 | }) |
| 2828 | .Case<ComplexType>(caseFn: [&](ComplexType complexTy) { |
| 2829 | os << "complex<" ; |
| 2830 | printType(type: complexTy.getElementType()); |
| 2831 | os << '>'; |
| 2832 | }) |
| 2833 | .Case<TupleType>(caseFn: [&](TupleType tupleTy) { |
| 2834 | os << "tuple<" ; |
| 2835 | interleaveComma(c: tupleTy.getTypes(), |
| 2836 | eachFn: [&](Type type) { printType(type); }); |
| 2837 | os << '>'; |
| 2838 | }) |
| 2839 | .Case<NoneType>(caseFn: [&](Type) { os << "none" ; }) |
| 2840 | .Default(defaultFn: [&](Type type) { return printDialectType(type); }); |
| 2841 | } |
| 2842 | |
| 2843 | void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 2844 | ArrayRef<StringRef> elidedAttrs, |
| 2845 | bool withKeyword) { |
| 2846 | // If there are no attributes, then there is nothing to be done. |
| 2847 | if (attrs.empty()) |
| 2848 | return; |
| 2849 | |
| 2850 | // Functor used to print a filtered attribute list. |
| 2851 | auto printFilteredAttributesFn = [&](auto filteredAttrs) { |
| 2852 | // Print the 'attributes' keyword if necessary. |
| 2853 | if (withKeyword) |
| 2854 | os << " attributes" ; |
| 2855 | |
| 2856 | // Otherwise, print them all out in braces. |
| 2857 | os << " {" ; |
| 2858 | interleaveComma(filteredAttrs, |
| 2859 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
| 2860 | os << '}'; |
| 2861 | }; |
| 2862 | |
| 2863 | // If no attributes are elided, we can directly print with no filtering. |
| 2864 | if (elidedAttrs.empty()) |
| 2865 | return printFilteredAttributesFn(attrs); |
| 2866 | |
| 2867 | // Otherwise, filter out any attributes that shouldn't be included. |
| 2868 | llvm::SmallDenseSet<StringRef> (elidedAttrs.begin(), |
| 2869 | elidedAttrs.end()); |
| 2870 | auto filteredAttrs = llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) { |
| 2871 | return !elidedAttrsSet.contains(V: attr.getName().strref()); |
| 2872 | }); |
| 2873 | if (!filteredAttrs.empty()) |
| 2874 | printFilteredAttributesFn(filteredAttrs); |
| 2875 | } |
| 2876 | void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { |
| 2877 | // Print the name without quotes if possible. |
| 2878 | ::printKeywordOrString(keyword: attr.getName().strref(), os); |
| 2879 | |
| 2880 | // Pretty printing elides the attribute value for unit attributes. |
| 2881 | if (llvm::isa<UnitAttr>(Val: attr.getValue())) |
| 2882 | return; |
| 2883 | |
| 2884 | os << " = " ; |
| 2885 | printAttribute(attr: attr.getValue()); |
| 2886 | } |
| 2887 | |
| 2888 | void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { |
| 2889 | auto &dialect = attr.getDialect(); |
| 2890 | |
| 2891 | // Ask the dialect to serialize the attribute to a string. |
| 2892 | std::string attrName; |
| 2893 | { |
| 2894 | llvm::raw_string_ostream attrNameStr(attrName); |
| 2895 | Impl subPrinter(attrNameStr, state); |
| 2896 | DialectAsmPrinter printer(subPrinter); |
| 2897 | dialect.printAttribute(attr, printer); |
| 2898 | } |
| 2899 | printDialectSymbol(os, symPrefix: "#" , dialectName: dialect.getNamespace(), symString: attrName); |
| 2900 | } |
| 2901 | |
| 2902 | void AsmPrinter::Impl::printDialectType(Type type) { |
| 2903 | auto &dialect = type.getDialect(); |
| 2904 | |
| 2905 | // Ask the dialect to serialize the type to a string. |
| 2906 | std::string typeName; |
| 2907 | { |
| 2908 | llvm::raw_string_ostream typeNameStr(typeName); |
| 2909 | Impl subPrinter(typeNameStr, state); |
| 2910 | DialectAsmPrinter printer(subPrinter); |
| 2911 | dialect.printType(type, printer); |
| 2912 | } |
| 2913 | printDialectSymbol(os, symPrefix: "!" , dialectName: dialect.getNamespace(), symString: typeName); |
| 2914 | } |
| 2915 | |
| 2916 | void AsmPrinter::Impl::printEscapedString(StringRef str) { |
| 2917 | os << "\"" ; |
| 2918 | llvm::printEscapedString(Name: str, Out&: os); |
| 2919 | os << "\"" ; |
| 2920 | } |
| 2921 | |
| 2922 | void AsmPrinter::Impl::printHexString(StringRef str) { |
| 2923 | os << "\"0x" << llvm::toHex(Input: str) << "\"" ; |
| 2924 | } |
| 2925 | void AsmPrinter::Impl::printHexString(ArrayRef<char> data) { |
| 2926 | printHexString(str: StringRef(data.data(), data.size())); |
| 2927 | } |
| 2928 | |
| 2929 | LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { |
| 2930 | return state.pushCyclicPrinting(opaquePointer); |
| 2931 | } |
| 2932 | |
| 2933 | void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } |
| 2934 | |
| 2935 | void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) { |
| 2936 | detail::printDimensionList(stream&: os, shape); |
| 2937 | } |
| 2938 | |
| 2939 | //===--------------------------------------------------------------------===// |
| 2940 | // AsmPrinter |
| 2941 | //===--------------------------------------------------------------------===// |
| 2942 | |
| 2943 | AsmPrinter::~AsmPrinter() = default; |
| 2944 | |
| 2945 | raw_ostream &AsmPrinter::getStream() const { |
| 2946 | assert(impl && "expected AsmPrinter::getStream to be overriden" ); |
| 2947 | return impl->getStream(); |
| 2948 | } |
| 2949 | |
| 2950 | /// Print the given floating point value in a stablized form. |
| 2951 | void AsmPrinter::printFloat(const APFloat &value) { |
| 2952 | assert(impl && "expected AsmPrinter::printFloat to be overriden" ); |
| 2953 | printFloatValue(apValue: value, os&: impl->getStream()); |
| 2954 | } |
| 2955 | |
| 2956 | void AsmPrinter::printType(Type type) { |
| 2957 | assert(impl && "expected AsmPrinter::printType to be overriden" ); |
| 2958 | impl->printType(type); |
| 2959 | } |
| 2960 | |
| 2961 | void AsmPrinter::printAttribute(Attribute attr) { |
| 2962 | assert(impl && "expected AsmPrinter::printAttribute to be overriden" ); |
| 2963 | impl->printAttribute(attr); |
| 2964 | } |
| 2965 | |
| 2966 | LogicalResult AsmPrinter::printAlias(Attribute attr) { |
| 2967 | assert(impl && "expected AsmPrinter::printAlias to be overriden" ); |
| 2968 | return impl->printAlias(attr); |
| 2969 | } |
| 2970 | |
| 2971 | LogicalResult AsmPrinter::printAlias(Type type) { |
| 2972 | assert(impl && "expected AsmPrinter::printAlias to be overriden" ); |
| 2973 | return impl->printAlias(type); |
| 2974 | } |
| 2975 | |
| 2976 | void AsmPrinter::printAttributeWithoutType(Attribute attr) { |
| 2977 | assert(impl && |
| 2978 | "expected AsmPrinter::printAttributeWithoutType to be overriden" ); |
| 2979 | impl->printAttribute(attr, typeElision: Impl::AttrTypeElision::Must); |
| 2980 | } |
| 2981 | |
| 2982 | void AsmPrinter::printNamedAttribute(NamedAttribute attr) { |
| 2983 | assert(impl && "expected AsmPrinter::printNamedAttribute to be overriden" ); |
| 2984 | impl->printNamedAttribute(attr); |
| 2985 | } |
| 2986 | |
| 2987 | void AsmPrinter::printKeywordOrString(StringRef keyword) { |
| 2988 | assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden" ); |
| 2989 | ::printKeywordOrString(keyword, os&: impl->getStream()); |
| 2990 | } |
| 2991 | |
| 2992 | void AsmPrinter::printString(StringRef keyword) { |
| 2993 | assert(impl && "expected AsmPrinter::printString to be overriden" ); |
| 2994 | *this << '"'; |
| 2995 | printEscapedString(Name: keyword, Out&: getStream()); |
| 2996 | *this << '"'; |
| 2997 | } |
| 2998 | |
| 2999 | void AsmPrinter::printSymbolName(StringRef symbolRef) { |
| 3000 | assert(impl && "expected AsmPrinter::printSymbolName to be overriden" ); |
| 3001 | ::printSymbolReference(symbolRef, os&: impl->getStream()); |
| 3002 | } |
| 3003 | |
| 3004 | void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { |
| 3005 | assert(impl && "expected AsmPrinter::printResourceHandle to be overriden" ); |
| 3006 | impl->printResourceHandle(resource); |
| 3007 | } |
| 3008 | |
| 3009 | void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) { |
| 3010 | detail::printDimensionList(stream&: getStream(), shape); |
| 3011 | } |
| 3012 | |
| 3013 | LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { |
| 3014 | return impl->pushCyclicPrinting(opaquePointer); |
| 3015 | } |
| 3016 | |
| 3017 | void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } |
| 3018 | |
| 3019 | //===----------------------------------------------------------------------===// |
| 3020 | // Affine expressions and maps |
| 3021 | //===----------------------------------------------------------------------===// |
| 3022 | |
| 3023 | void AsmPrinter::Impl::printAffineExpr( |
| 3024 | AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { |
| 3025 | printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3026 | } |
| 3027 | |
| 3028 | void AsmPrinter::Impl::printAffineExprInternal( |
| 3029 | AffineExpr expr, BindingStrength enclosingTightness, |
| 3030 | function_ref<void(unsigned, bool)> printValueName) { |
| 3031 | const char *binopSpelling = nullptr; |
| 3032 | switch (expr.getKind()) { |
| 3033 | case AffineExprKind::SymbolId: { |
| 3034 | unsigned pos = cast<AffineSymbolExpr>(Val&: expr).getPosition(); |
| 3035 | if (printValueName) |
| 3036 | printValueName(pos, /*isSymbol=*/true); |
| 3037 | else |
| 3038 | os << 's' << pos; |
| 3039 | return; |
| 3040 | } |
| 3041 | case AffineExprKind::DimId: { |
| 3042 | unsigned pos = cast<AffineDimExpr>(Val&: expr).getPosition(); |
| 3043 | if (printValueName) |
| 3044 | printValueName(pos, /*isSymbol=*/false); |
| 3045 | else |
| 3046 | os << 'd' << pos; |
| 3047 | return; |
| 3048 | } |
| 3049 | case AffineExprKind::Constant: |
| 3050 | os << cast<AffineConstantExpr>(Val&: expr).getValue(); |
| 3051 | return; |
| 3052 | case AffineExprKind::Add: |
| 3053 | binopSpelling = " + " ; |
| 3054 | break; |
| 3055 | case AffineExprKind::Mul: |
| 3056 | binopSpelling = " * " ; |
| 3057 | break; |
| 3058 | case AffineExprKind::FloorDiv: |
| 3059 | binopSpelling = " floordiv " ; |
| 3060 | break; |
| 3061 | case AffineExprKind::CeilDiv: |
| 3062 | binopSpelling = " ceildiv " ; |
| 3063 | break; |
| 3064 | case AffineExprKind::Mod: |
| 3065 | binopSpelling = " mod " ; |
| 3066 | break; |
| 3067 | } |
| 3068 | |
| 3069 | auto binOp = cast<AffineBinaryOpExpr>(Val&: expr); |
| 3070 | AffineExpr lhsExpr = binOp.getLHS(); |
| 3071 | AffineExpr rhsExpr = binOp.getRHS(); |
| 3072 | |
| 3073 | // Handle tightly binding binary operators. |
| 3074 | if (binOp.getKind() != AffineExprKind::Add) { |
| 3075 | if (enclosingTightness == BindingStrength::Strong) |
| 3076 | os << '('; |
| 3077 | |
| 3078 | // Pretty print multiplication with -1. |
| 3079 | auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr); |
| 3080 | if (rhsConst && binOp.getKind() == AffineExprKind::Mul && |
| 3081 | rhsConst.getValue() == -1) { |
| 3082 | os << "-" ; |
| 3083 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
| 3084 | if (enclosingTightness == BindingStrength::Strong) |
| 3085 | os << ')'; |
| 3086 | return; |
| 3087 | } |
| 3088 | |
| 3089 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
| 3090 | |
| 3091 | os << binopSpelling; |
| 3092 | printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
| 3093 | |
| 3094 | if (enclosingTightness == BindingStrength::Strong) |
| 3095 | os << ')'; |
| 3096 | return; |
| 3097 | } |
| 3098 | |
| 3099 | // Print out special "pretty" forms for add. |
| 3100 | if (enclosingTightness == BindingStrength::Strong) |
| 3101 | os << '('; |
| 3102 | |
| 3103 | // Pretty print addition to a product that has a negative operand as a |
| 3104 | // subtraction. |
| 3105 | if (auto rhs = dyn_cast<AffineBinaryOpExpr>(Val&: rhsExpr)) { |
| 3106 | if (rhs.getKind() == AffineExprKind::Mul) { |
| 3107 | AffineExpr rrhsExpr = rhs.getRHS(); |
| 3108 | if (auto rrhs = dyn_cast<AffineConstantExpr>(Val&: rrhsExpr)) { |
| 3109 | if (rrhs.getValue() == -1) { |
| 3110 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, |
| 3111 | printValueName); |
| 3112 | os << " - " ; |
| 3113 | if (rhs.getLHS().getKind() == AffineExprKind::Add) { |
| 3114 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong, |
| 3115 | printValueName); |
| 3116 | } else { |
| 3117 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Weak, |
| 3118 | printValueName); |
| 3119 | } |
| 3120 | |
| 3121 | if (enclosingTightness == BindingStrength::Strong) |
| 3122 | os << ')'; |
| 3123 | return; |
| 3124 | } |
| 3125 | |
| 3126 | if (rrhs.getValue() < -1) { |
| 3127 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, |
| 3128 | printValueName); |
| 3129 | os << " - " ; |
| 3130 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong, |
| 3131 | printValueName); |
| 3132 | os << " * " << -rrhs.getValue(); |
| 3133 | if (enclosingTightness == BindingStrength::Strong) |
| 3134 | os << ')'; |
| 3135 | return; |
| 3136 | } |
| 3137 | } |
| 3138 | } |
| 3139 | } |
| 3140 | |
| 3141 | // Pretty print addition to a negative number as a subtraction. |
| 3142 | if (auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr)) { |
| 3143 | if (rhsConst.getValue() < 0) { |
| 3144 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3145 | os << " - " << -rhsConst.getValue(); |
| 3146 | if (enclosingTightness == BindingStrength::Strong) |
| 3147 | os << ')'; |
| 3148 | return; |
| 3149 | } |
| 3150 | } |
| 3151 | |
| 3152 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3153 | |
| 3154 | os << " + " ; |
| 3155 | printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3156 | |
| 3157 | if (enclosingTightness == BindingStrength::Strong) |
| 3158 | os << ')'; |
| 3159 | } |
| 3160 | |
| 3161 | void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { |
| 3162 | printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak); |
| 3163 | isEq ? os << " == 0" : os << " >= 0" ; |
| 3164 | } |
| 3165 | |
| 3166 | void AsmPrinter::Impl::printAffineMap(AffineMap map) { |
| 3167 | // Dimension identifiers. |
| 3168 | os << '('; |
| 3169 | for (int i = 0; i < (int)map.getNumDims() - 1; ++i) |
| 3170 | os << 'd' << i << ", " ; |
| 3171 | if (map.getNumDims() >= 1) |
| 3172 | os << 'd' << map.getNumDims() - 1; |
| 3173 | os << ')'; |
| 3174 | |
| 3175 | // Symbolic identifiers. |
| 3176 | if (map.getNumSymbols() != 0) { |
| 3177 | os << '['; |
| 3178 | for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) |
| 3179 | os << 's' << i << ", " ; |
| 3180 | if (map.getNumSymbols() >= 1) |
| 3181 | os << 's' << map.getNumSymbols() - 1; |
| 3182 | os << ']'; |
| 3183 | } |
| 3184 | |
| 3185 | // Result affine expressions. |
| 3186 | os << " -> (" ; |
| 3187 | interleaveComma(c: map.getResults(), |
| 3188 | eachFn: [&](AffineExpr expr) { printAffineExpr(expr); }); |
| 3189 | os << ')'; |
| 3190 | } |
| 3191 | |
| 3192 | void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { |
| 3193 | // Dimension identifiers. |
| 3194 | os << '('; |
| 3195 | for (unsigned i = 1; i < set.getNumDims(); ++i) |
| 3196 | os << 'd' << i - 1 << ", " ; |
| 3197 | if (set.getNumDims() >= 1) |
| 3198 | os << 'd' << set.getNumDims() - 1; |
| 3199 | os << ')'; |
| 3200 | |
| 3201 | // Symbolic identifiers. |
| 3202 | if (set.getNumSymbols() != 0) { |
| 3203 | os << '['; |
| 3204 | for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) |
| 3205 | os << 's' << i << ", " ; |
| 3206 | if (set.getNumSymbols() >= 1) |
| 3207 | os << 's' << set.getNumSymbols() - 1; |
| 3208 | os << ']'; |
| 3209 | } |
| 3210 | |
| 3211 | // Print constraints. |
| 3212 | os << " : (" ; |
| 3213 | int numConstraints = set.getNumConstraints(); |
| 3214 | for (int i = 1; i < numConstraints; ++i) { |
| 3215 | printAffineConstraint(expr: set.getConstraint(idx: i - 1), isEq: set.isEq(idx: i - 1)); |
| 3216 | os << ", " ; |
| 3217 | } |
| 3218 | if (numConstraints >= 1) |
| 3219 | printAffineConstraint(expr: set.getConstraint(idx: numConstraints - 1), |
| 3220 | isEq: set.isEq(idx: numConstraints - 1)); |
| 3221 | os << ')'; |
| 3222 | } |
| 3223 | |
| 3224 | //===----------------------------------------------------------------------===// |
| 3225 | // OperationPrinter |
| 3226 | //===----------------------------------------------------------------------===// |
| 3227 | |
| 3228 | namespace { |
| 3229 | /// This class contains the logic for printing operations, regions, and blocks. |
| 3230 | class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { |
| 3231 | public: |
| 3232 | using Impl = AsmPrinter::Impl; |
| 3233 | using Impl::printType; |
| 3234 | |
| 3235 | explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) |
| 3236 | : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {} |
| 3237 | |
| 3238 | /// Print the given top-level operation. |
| 3239 | void printTopLevelOperation(Operation *op); |
| 3240 | |
| 3241 | /// Print the given operation, including its left-hand side and its right-hand |
| 3242 | /// side, with its indent and location. |
| 3243 | void printFullOpWithIndentAndLoc(Operation *op); |
| 3244 | /// Print the given operation, including its left-hand side and its right-hand |
| 3245 | /// side, but not including indentation and location. |
| 3246 | void printFullOp(Operation *op); |
| 3247 | /// Print the right-hand size of the given operation in the custom or generic |
| 3248 | /// form. |
| 3249 | void printCustomOrGenericOp(Operation *op) override; |
| 3250 | /// Print the right-hand side of the given operation in the generic form. |
| 3251 | void printGenericOp(Operation *op, bool printOpName) override; |
| 3252 | |
| 3253 | /// Print the name of the given block. |
| 3254 | void printBlockName(Block *block); |
| 3255 | |
| 3256 | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
| 3257 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
| 3258 | /// operation of the block is not printed. |
| 3259 | void print(Block *block, bool printBlockArgs = true, |
| 3260 | bool printBlockTerminator = true); |
| 3261 | |
| 3262 | /// Print the ID of the given value, optionally with its result number. |
| 3263 | void printValueID(Value value, bool printResultNo = true, |
| 3264 | raw_ostream *streamOverride = nullptr) const; |
| 3265 | |
| 3266 | /// Print the ID of the given operation. |
| 3267 | void printOperationID(Operation *op, |
| 3268 | raw_ostream *streamOverride = nullptr) const; |
| 3269 | |
| 3270 | //===--------------------------------------------------------------------===// |
| 3271 | // OpAsmPrinter methods |
| 3272 | //===--------------------------------------------------------------------===// |
| 3273 | |
| 3274 | /// Print a loc(...) specifier if printing debug info is enabled. Locations |
| 3275 | /// may be deferred with an alias. |
| 3276 | void printOptionalLocationSpecifier(Location loc) override { |
| 3277 | printTrailingLocation(loc); |
| 3278 | } |
| 3279 | |
| 3280 | /// Print a newline and indent the printer to the start of the current |
| 3281 | /// operation. |
| 3282 | void printNewline() override { |
| 3283 | os << newLine; |
| 3284 | os.indent(NumSpaces: currentIndent); |
| 3285 | } |
| 3286 | |
| 3287 | /// Increase indentation. |
| 3288 | void increaseIndent() override { currentIndent += indentWidth; } |
| 3289 | |
| 3290 | /// Decrease indentation. |
| 3291 | void decreaseIndent() override { currentIndent -= indentWidth; } |
| 3292 | |
| 3293 | /// Print a block argument in the usual format of: |
| 3294 | /// %ssaName : type {attr1=42} loc("here") |
| 3295 | /// where location printing is controlled by the standard internal option. |
| 3296 | /// You may pass omitType=true to not print a type, and pass an empty |
| 3297 | /// attribute list if you don't care for attributes. |
| 3298 | void printRegionArgument(BlockArgument arg, |
| 3299 | ArrayRef<NamedAttribute> argAttrs = {}, |
| 3300 | bool omitType = false) override; |
| 3301 | |
| 3302 | /// Print the ID for the given value. |
| 3303 | void printOperand(Value value) override { printValueID(value); } |
| 3304 | void printOperand(Value value, raw_ostream &os) override { |
| 3305 | printValueID(value, /*printResultNo=*/true, streamOverride: &os); |
| 3306 | } |
| 3307 | |
| 3308 | /// Print an optional attribute dictionary with a given set of elided values. |
| 3309 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 3310 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 3311 | Impl::printOptionalAttrDict(attrs, elidedAttrs); |
| 3312 | } |
| 3313 | void printOptionalAttrDictWithKeyword( |
| 3314 | ArrayRef<NamedAttribute> attrs, |
| 3315 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 3316 | Impl::printOptionalAttrDict(attrs, elidedAttrs, |
| 3317 | /*withKeyword=*/true); |
| 3318 | } |
| 3319 | |
| 3320 | /// Print the given successor. |
| 3321 | void printSuccessor(Block *successor) override; |
| 3322 | |
| 3323 | /// Print an operation successor with the operands used for the block |
| 3324 | /// arguments. |
| 3325 | void printSuccessorAndUseList(Block *successor, |
| 3326 | ValueRange succOperands) override; |
| 3327 | |
| 3328 | /// Print the given region. |
| 3329 | void printRegion(Region ®ion, bool printEntryBlockArgs, |
| 3330 | bool printBlockTerminators, bool printEmptyBlock) override; |
| 3331 | |
| 3332 | /// Renumber the arguments for the specified region to the same names as the |
| 3333 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
| 3334 | /// operations. If any entry in namesToUse is null, the corresponding |
| 3335 | /// argument name is left alone. |
| 3336 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { |
| 3337 | state.getSSANameState().shadowRegionArgs(region, namesToUse); |
| 3338 | } |
| 3339 | |
| 3340 | /// Print the given affine map with the symbol and dimension operands printed |
| 3341 | /// inline with the map. |
| 3342 | void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
| 3343 | ValueRange operands) override; |
| 3344 | |
| 3345 | /// Print the given affine expression with the symbol and dimension operands |
| 3346 | /// printed inline with the expression. |
| 3347 | void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
| 3348 | ValueRange symOperands) override; |
| 3349 | |
| 3350 | /// Print users of this operation or id of this operation if it has no result. |
| 3351 | void printUsersComment(Operation *op); |
| 3352 | |
| 3353 | /// Print users of this block arg. |
| 3354 | void printUsersComment(BlockArgument arg); |
| 3355 | |
| 3356 | /// Print the users of a value. |
| 3357 | void printValueUsers(Value value); |
| 3358 | |
| 3359 | /// Print either the ids of the result values or the id of the operation if |
| 3360 | /// the operation has no results. |
| 3361 | void printUserIDs(Operation *user, bool prefixComma = false); |
| 3362 | |
| 3363 | private: |
| 3364 | /// This class represents a resource builder implementation for the MLIR |
| 3365 | /// textual assembly format. |
| 3366 | class ResourceBuilder : public AsmResourceBuilder { |
| 3367 | public: |
| 3368 | using ValueFn = function_ref<void(raw_ostream &)>; |
| 3369 | using PrintFn = function_ref<void(StringRef, ValueFn)>; |
| 3370 | |
| 3371 | ResourceBuilder(PrintFn printFn) : printFn(printFn) {} |
| 3372 | ~ResourceBuilder() override = default; |
| 3373 | |
| 3374 | void buildBool(StringRef key, bool data) final { |
| 3375 | printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false" ); }); |
| 3376 | } |
| 3377 | |
| 3378 | void buildString(StringRef key, StringRef data) final { |
| 3379 | printFn(key, [&](raw_ostream &os) { |
| 3380 | os << "\"" ; |
| 3381 | llvm::printEscapedString(Name: data, Out&: os); |
| 3382 | os << "\"" ; |
| 3383 | }); |
| 3384 | } |
| 3385 | |
| 3386 | void buildBlob(StringRef key, ArrayRef<char> data, |
| 3387 | uint32_t dataAlignment) final { |
| 3388 | printFn(key, [&](raw_ostream &os) { |
| 3389 | // Store the blob in a hex string containing the alignment and the data. |
| 3390 | llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); |
| 3391 | os << "\"0x" |
| 3392 | << llvm::toHex(Input: StringRef(reinterpret_cast<char *>(&dataAlignmentLE), |
| 3393 | sizeof(dataAlignment))) |
| 3394 | << llvm::toHex(Input: StringRef(data.data(), data.size())) << "\"" ; |
| 3395 | }); |
| 3396 | } |
| 3397 | |
| 3398 | private: |
| 3399 | PrintFn printFn; |
| 3400 | }; |
| 3401 | |
| 3402 | /// Print the metadata dictionary for the file, eliding it if it is empty. |
| 3403 | void printFileMetadataDictionary(Operation *op); |
| 3404 | |
| 3405 | /// Print the resource sections for the file metadata dictionary. |
| 3406 | /// `checkAddMetadataDict` is used to indicate that metadata is going to be |
| 3407 | /// added, and the file metadata dictionary should be started if it hasn't |
| 3408 | /// yet. |
| 3409 | void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict, |
| 3410 | Operation *op); |
| 3411 | |
| 3412 | // Contains the stack of default dialects to use when printing regions. |
| 3413 | // A new dialect is pushed to the stack before parsing regions nested under an |
| 3414 | // operation implementing `OpAsmOpInterface`, and popped when done. At the |
| 3415 | // top-level we start with "builtin" as the default, so that the top-level |
| 3416 | // `module` operation prints as-is. |
| 3417 | SmallVector<StringRef> defaultDialectStack{"builtin" }; |
| 3418 | |
| 3419 | /// The number of spaces used for indenting nested operations. |
| 3420 | const static unsigned indentWidth = 2; |
| 3421 | |
| 3422 | // This is the current indentation level for nested structures. |
| 3423 | unsigned currentIndent = 0; |
| 3424 | }; |
| 3425 | } // namespace |
| 3426 | |
| 3427 | void OperationPrinter::printTopLevelOperation(Operation *op) { |
| 3428 | // Output the aliases at the top level that can't be deferred. |
| 3429 | state.getAliasState().printNonDeferredAliases(p&: *this, newLine); |
| 3430 | |
| 3431 | // Print the module. |
| 3432 | printFullOpWithIndentAndLoc(op); |
| 3433 | os << newLine; |
| 3434 | |
| 3435 | // Output the aliases at the top level that can be deferred. |
| 3436 | state.getAliasState().printDeferredAliases(p&: *this, newLine); |
| 3437 | |
| 3438 | // Output any file level metadata. |
| 3439 | printFileMetadataDictionary(op); |
| 3440 | } |
| 3441 | |
| 3442 | void OperationPrinter::printFileMetadataDictionary(Operation *op) { |
| 3443 | bool sawMetadataEntry = false; |
| 3444 | auto checkAddMetadataDict = [&] { |
| 3445 | if (!std::exchange(obj&: sawMetadataEntry, new_val: true)) |
| 3446 | os << newLine << "{-#" << newLine; |
| 3447 | }; |
| 3448 | |
| 3449 | // Add the various types of metadata. |
| 3450 | printResourceFileMetadata(checkAddMetadataDict, op); |
| 3451 | |
| 3452 | // If the file dictionary exists, close it. |
| 3453 | if (sawMetadataEntry) |
| 3454 | os << newLine << "#-}" << newLine; |
| 3455 | } |
| 3456 | |
/// Print the `dialect_resources` section, followed by the
/// `external_resources` section, of the file metadata dictionary.
/// `checkAddMetadataDict` is invoked for every resource entry a provider
/// builds, before that entry is printed (or dropped by the size limit).
void OperationPrinter::printResourceFileMetadata(
    function_ref<void()> checkAddMetadataDict, Operation *op) {
  // State shared by all providers of the section currently being printed:
  //  - hadResource: the `<dictName>_resources: {` header has been emitted.
  //  - needResourceComma: an earlier section was emitted, so this section's
  //    header must be preceded by a comma.
  //  - needEntryComma: an earlier provider in this section emitted its
  //    `<name>: {` group, so the next group must be preceded by a comma.
  bool hadResource = false;
  bool needResourceComma = false;
  bool needEntryComma = false;
  // Functor used to add data entries to the file metadata dictionary. It
  // processes one provider (a dialect interface or an external resource
  // printer); `providerArgs` are forwarded to its `buildResources` call.
  auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
                             auto &&...providerArgs) {
    // Whether this provider's `<name>: {` group has been opened.
    bool hadEntry = false;
    // Called by the ResourceBuilder once per resource key.
    auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
      // NOTE(review): this runs before the size-limit checks below, which may
      // drop the entry. If the callback opens an enclosing dictionary and every
      // entry is dropped, that dictionary may end up empty -- confirm intended.
      checkAddMetadataDict();

      std::string resourceStr;
      auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; };
      std::optional<uint64_t> charLimit =
          printerFlags.getLargeResourceStringLimit();
      if (charLimit.has_value()) {
        // Don't compute resourceStr when charLimit is 0.
        if (charLimit.value() == 0)
          return;

        // Render the value once so its length can be checked.
        llvm::raw_string_ostream ss(resourceStr);
        valueFn(ss);

        // Only print entry if its string is small enough.
        if (resourceStr.size() > charLimit.value())
          return;

        // Don't recompute resourceStr when valueFn is called below.
        valueFn = printResourceStr;
      }

      // Emit the top-level resource entry if we haven't yet.
      if (!std::exchange(obj&: hadResource, new_val: true)) {
        if (needResourceComma)
          os << "," << newLine;
        os << " " << dictName << "_resources: {" << newLine;
      }
      // Emit the parent resource entry if we haven't yet; otherwise separate
      // this key from the previous one.
      if (!std::exchange(obj&: hadEntry, new_val: true)) {
        if (needEntryComma)
          os << "," << newLine;
        os << " " << name << ": {" << newLine;
      } else {
        os << "," << newLine;
      }
      os << " ";
      ::printKeywordOrString(keyword: key, os);
      os << ": ";
      // Call printResourceStr or original valueFn, depending on charLimit.
      valueFn(os);
    };
    ResourceBuilder entryBuilder(printFn);
    provider.buildResources(op, providerArgs..., entryBuilder);

    // Close this provider's group if it printed anything.
    needEntryComma |= hadEntry;
    if (hadEntry)
      os << newLine << " }";
  };

  // Print the `dialect_resources` section if we have any dialects with
  // resources. Dialects without referenced resources are still asked to build
  // resources, with an empty handle set.
  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
    auto &dialectResources = state.getDialectResources();
    StringRef name = interface.getDialect()->getNamespace();
    auto it = dialectResources.find(Val: interface.getDialect());
    if (it != dialectResources.end())
      processProvider("dialect", name, interface, it->second);
    else
      processProvider("dialect", name, interface,
                      SetVector<AsmDialectResourceHandle>());
  }
  if (hadResource)
    os << newLine << " }";

  // Print the `external_resources` section if we have any external clients with
  // resources. Reset the per-section state; a comma is needed before this
  // section only if the dialect section was emitted.
  needEntryComma = false;
  needResourceComma = hadResource;
  hadResource = false;
  for (const auto &printer : state.getResourcePrinters())
    processProvider("external", printer.getName(), printer);
  if (hadResource)
    os << newLine << " }";
}
| 3542 | |
| 3543 | /// Print a block argument in the usual format of: |
| 3544 | /// %ssaName : type {attr1=42} loc("here") |
| 3545 | /// where location printing is controlled by the standard internal option. |
| 3546 | /// You may pass omitType=true to not print a type, and pass an empty |
| 3547 | /// attribute list if you don't care for attributes. |
| 3548 | void OperationPrinter::printRegionArgument(BlockArgument arg, |
| 3549 | ArrayRef<NamedAttribute> argAttrs, |
| 3550 | bool omitType) { |
| 3551 | printOperand(value: arg); |
| 3552 | if (!omitType) { |
| 3553 | os << ": " ; |
| 3554 | printType(type: arg.getType()); |
| 3555 | } |
| 3556 | printOptionalAttrDict(attrs: argAttrs); |
| 3557 | // TODO: We should allow location aliases on block arguments. |
| 3558 | printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false); |
| 3559 | } |
| 3560 | |
| 3561 | void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { |
| 3562 | // Track the location of this operation. |
| 3563 | state.registerOperationLocation(op, line: newLine.curLine, col: currentIndent); |
| 3564 | |
| 3565 | os.indent(NumSpaces: currentIndent); |
| 3566 | printFullOp(op); |
| 3567 | printTrailingLocation(loc: op->getLoc()); |
| 3568 | if (printerFlags.shouldPrintValueUsers()) |
| 3569 | printUsersComment(op); |
| 3570 | } |
| 3571 | |
| 3572 | void OperationPrinter::printFullOp(Operation *op) { |
| 3573 | if (size_t numResults = op->getNumResults()) { |
| 3574 | auto printResultGroup = [&](size_t resultNo, size_t resultCount) { |
| 3575 | printValueID(value: op->getResult(idx: resultNo), /*printResultNo=*/false); |
| 3576 | if (resultCount > 1) |
| 3577 | os << ':' << resultCount; |
| 3578 | }; |
| 3579 | |
| 3580 | // Check to see if this operation has multiple result groups. |
| 3581 | ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op); |
| 3582 | if (!resultGroups.empty()) { |
| 3583 | // Interleave the groups excluding the last one, this one will be handled |
| 3584 | // separately. |
| 3585 | interleaveComma(c: llvm::seq<int>(Begin: 0, End: resultGroups.size() - 1), eachFn: [&](int i) { |
| 3586 | printResultGroup(resultGroups[i], |
| 3587 | resultGroups[i + 1] - resultGroups[i]); |
| 3588 | }); |
| 3589 | os << ", " ; |
| 3590 | printResultGroup(resultGroups.back(), numResults - resultGroups.back()); |
| 3591 | |
| 3592 | } else { |
| 3593 | printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); |
| 3594 | } |
| 3595 | |
| 3596 | os << " = " ; |
| 3597 | } |
| 3598 | |
| 3599 | printCustomOrGenericOp(op); |
| 3600 | } |
| 3601 | |
| 3602 | void OperationPrinter::(Operation *op) { |
| 3603 | unsigned numResults = op->getNumResults(); |
| 3604 | if (!numResults && op->getNumOperands()) { |
| 3605 | os << " // id: " ; |
| 3606 | printOperationID(op); |
| 3607 | } else if (numResults && op->use_empty()) { |
| 3608 | os << " // unused" ; |
| 3609 | } else if (numResults && !op->use_empty()) { |
| 3610 | // Print "user" if the operation has one result used to compute one other |
| 3611 | // result, or is used in one operation with no result. |
| 3612 | unsigned usedInNResults = 0; |
| 3613 | unsigned usedInNOperations = 0; |
| 3614 | SmallPtrSet<Operation *, 1> userSet; |
| 3615 | for (Operation *user : op->getUsers()) { |
| 3616 | if (userSet.insert(Ptr: user).second) { |
| 3617 | ++usedInNOperations; |
| 3618 | usedInNResults += user->getNumResults(); |
| 3619 | } |
| 3620 | } |
| 3621 | |
| 3622 | // We already know that users is not empty. |
| 3623 | bool exactlyOneUniqueUse = |
| 3624 | usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; |
| 3625 | os << " // " << (exactlyOneUniqueUse ? "user" : "users" ) << ": " ; |
| 3626 | bool shouldPrintBrackets = numResults > 1; |
| 3627 | auto printOpResult = [&](OpResult opResult) { |
| 3628 | if (shouldPrintBrackets) |
| 3629 | os << "(" ; |
| 3630 | printValueUsers(value: opResult); |
| 3631 | if (shouldPrintBrackets) |
| 3632 | os << ")" ; |
| 3633 | }; |
| 3634 | |
| 3635 | interleaveComma(c: op->getResults(), eachFn: printOpResult); |
| 3636 | } |
| 3637 | } |
| 3638 | |
| 3639 | void OperationPrinter::(BlockArgument arg) { |
| 3640 | os << "// " ; |
| 3641 | printValueID(value: arg); |
| 3642 | if (arg.use_empty()) { |
| 3643 | os << " is unused" ; |
| 3644 | } else { |
| 3645 | os << " is used by " ; |
| 3646 | printValueUsers(value: arg); |
| 3647 | } |
| 3648 | os << newLine; |
| 3649 | } |
| 3650 | |
| 3651 | void OperationPrinter::printValueUsers(Value value) { |
| 3652 | if (value.use_empty()) |
| 3653 | os << "unused" ; |
| 3654 | |
| 3655 | // One value might be used as the operand of an operation more than once. |
| 3656 | // Only print the operations results once in that case. |
| 3657 | SmallPtrSet<Operation *, 1> userSet; |
| 3658 | for (auto [index, user] : enumerate(First: value.getUsers())) { |
| 3659 | if (userSet.insert(Ptr: user).second) |
| 3660 | printUserIDs(user, prefixComma: index); |
| 3661 | } |
| 3662 | } |
| 3663 | |
| 3664 | void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { |
| 3665 | if (prefixComma) |
| 3666 | os << ", " ; |
| 3667 | |
| 3668 | if (!user->getNumResults()) { |
| 3669 | printOperationID(op: user); |
| 3670 | } else { |
| 3671 | interleaveComma(c: user->getResults(), |
| 3672 | eachFn: [this](Value result) { printValueID(value: result); }); |
| 3673 | } |
| 3674 | } |
| 3675 | |
| 3676 | void OperationPrinter::printCustomOrGenericOp(Operation *op) { |
| 3677 | // If requested, always print the generic form. |
| 3678 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 3679 | // Check to see if this is a known operation. If so, use the registered |
| 3680 | // custom printer hook. |
| 3681 | if (auto opInfo = op->getRegisteredInfo()) { |
| 3682 | opInfo->printAssembly(op, p&: *this, defaultDialect: defaultDialectStack.back()); |
| 3683 | return; |
| 3684 | } |
| 3685 | // Otherwise try to dispatch to the dialect, if available. |
| 3686 | if (Dialect *dialect = op->getDialect()) { |
| 3687 | if (auto opPrinter = dialect->getOperationPrinter(op)) { |
| 3688 | // Print the op name first. |
| 3689 | StringRef name = op->getName().getStringRef(); |
| 3690 | // Only drop the default dialect prefix when it cannot lead to |
| 3691 | // ambiguities. |
| 3692 | if (name.count(C: '.') == 1) |
| 3693 | name.consume_front(Prefix: (defaultDialectStack.back() + "." ).str()); |
| 3694 | os << name; |
| 3695 | |
| 3696 | // Print the rest of the op now. |
| 3697 | opPrinter(op, *this); |
| 3698 | return; |
| 3699 | } |
| 3700 | } |
| 3701 | } |
| 3702 | |
| 3703 | // Otherwise print with the generic assembly form. |
| 3704 | printGenericOp(op, /*printOpName=*/true); |
| 3705 | } |
| 3706 | |
| 3707 | void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { |
| 3708 | if (printOpName) |
| 3709 | printEscapedString(str: op->getName().getStringRef()); |
| 3710 | os << '('; |
| 3711 | interleaveComma(c: op->getOperands(), eachFn: [&](Value value) { printValueID(value); }); |
| 3712 | os << ')'; |
| 3713 | |
| 3714 | // For terminators, print the list of successors and their operands. |
| 3715 | if (op->getNumSuccessors() != 0) { |
| 3716 | os << '['; |
| 3717 | interleaveComma(c: op->getSuccessors(), |
| 3718 | eachFn: [&](Block *successor) { printBlockName(block: successor); }); |
| 3719 | os << ']'; |
| 3720 | } |
| 3721 | |
| 3722 | // Print the properties. |
| 3723 | if (Attribute prop = op->getPropertiesAsAttribute()) { |
| 3724 | os << " <" ; |
| 3725 | Impl::printAttribute(attr: prop); |
| 3726 | os << '>'; |
| 3727 | } |
| 3728 | |
| 3729 | // Print regions. |
| 3730 | if (op->getNumRegions() != 0) { |
| 3731 | os << " (" ; |
| 3732 | interleaveComma(c: op->getRegions(), eachFn: [&](Region ®ion) { |
| 3733 | printRegion(region, /*printEntryBlockArgs=*/true, |
| 3734 | /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); |
| 3735 | }); |
| 3736 | os << ')'; |
| 3737 | } |
| 3738 | |
| 3739 | printOptionalAttrDict(attrs: op->getPropertiesStorage() |
| 3740 | ? llvm::to_vector(Range: op->getDiscardableAttrs()) |
| 3741 | : op->getAttrs()); |
| 3742 | |
| 3743 | // Print the type signature of the operation. |
| 3744 | os << " : " ; |
| 3745 | printFunctionalType(op); |
| 3746 | } |
| 3747 | |
| 3748 | void OperationPrinter::printBlockName(Block *block) { |
| 3749 | os << state.getSSANameState().getBlockInfo(block).name; |
| 3750 | } |
| 3751 | |
| 3752 | void OperationPrinter::print(Block *block, bool printBlockArgs, |
| 3753 | bool printBlockTerminator) { |
| 3754 | // Print the block label and argument list if requested. |
| 3755 | if (printBlockArgs) { |
| 3756 | os.indent(NumSpaces: currentIndent); |
| 3757 | printBlockName(block); |
| 3758 | |
| 3759 | // Print the argument list if non-empty. |
| 3760 | if (!block->args_empty()) { |
| 3761 | os << '('; |
| 3762 | interleaveComma(c: block->getArguments(), eachFn: [&](BlockArgument arg) { |
| 3763 | printValueID(value: arg); |
| 3764 | os << ": " ; |
| 3765 | printType(type: arg.getType()); |
| 3766 | // TODO: We should allow location aliases on block arguments. |
| 3767 | printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false); |
| 3768 | }); |
| 3769 | os << ')'; |
| 3770 | } |
| 3771 | os << ':'; |
| 3772 | |
| 3773 | // Print out some context information about the predecessors of this block. |
| 3774 | if (!block->getParent()) { |
| 3775 | os << " // block is not in a region!" ; |
| 3776 | } else if (block->hasNoPredecessors()) { |
| 3777 | if (!block->isEntryBlock()) |
| 3778 | os << " // no predecessors" ; |
| 3779 | } else if (auto *pred = block->getSinglePredecessor()) { |
| 3780 | os << " // pred: " ; |
| 3781 | printBlockName(block: pred); |
| 3782 | } else { |
| 3783 | // We want to print the predecessors in a stable order, not in |
| 3784 | // whatever order the use-list is in, so gather and sort them. |
| 3785 | SmallVector<BlockInfo, 4> predIDs; |
| 3786 | for (auto *pred : block->getPredecessors()) |
| 3787 | predIDs.push_back(Elt: state.getSSANameState().getBlockInfo(block: pred)); |
| 3788 | llvm::sort(C&: predIDs, Comp: [](BlockInfo lhs, BlockInfo rhs) { |
| 3789 | return lhs.ordering < rhs.ordering; |
| 3790 | }); |
| 3791 | |
| 3792 | os << " // " << predIDs.size() << " preds: " ; |
| 3793 | |
| 3794 | interleaveComma(c: predIDs, eachFn: [&](BlockInfo pred) { os << pred.name; }); |
| 3795 | } |
| 3796 | os << newLine; |
| 3797 | } |
| 3798 | |
| 3799 | currentIndent += indentWidth; |
| 3800 | |
| 3801 | if (printerFlags.shouldPrintValueUsers()) { |
| 3802 | for (BlockArgument arg : block->getArguments()) { |
| 3803 | os.indent(NumSpaces: currentIndent); |
| 3804 | printUsersComment(arg); |
| 3805 | } |
| 3806 | } |
| 3807 | |
| 3808 | bool hasTerminator = |
| 3809 | !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
| 3810 | auto range = llvm::make_range( |
| 3811 | x: block->begin(), |
| 3812 | y: std::prev(x: block->end(), |
| 3813 | n: (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
| 3814 | for (auto &op : range) { |
| 3815 | printFullOpWithIndentAndLoc(op: &op); |
| 3816 | os << newLine; |
| 3817 | } |
| 3818 | currentIndent -= indentWidth; |
| 3819 | } |
| 3820 | |
| 3821 | void OperationPrinter::printValueID(Value value, bool printResultNo, |
| 3822 | raw_ostream *streamOverride) const { |
| 3823 | state.getSSANameState().printValueID(value, printResultNo, |
| 3824 | stream&: streamOverride ? *streamOverride : os); |
| 3825 | } |
| 3826 | |
| 3827 | void OperationPrinter::printOperationID(Operation *op, |
| 3828 | raw_ostream *streamOverride) const { |
| 3829 | state.getSSANameState().printOperationID(op, stream&: streamOverride ? *streamOverride |
| 3830 | : os); |
| 3831 | } |
| 3832 | |
| 3833 | void OperationPrinter::printSuccessor(Block *successor) { |
| 3834 | printBlockName(block: successor); |
| 3835 | } |
| 3836 | |
| 3837 | void OperationPrinter::printSuccessorAndUseList(Block *successor, |
| 3838 | ValueRange succOperands) { |
| 3839 | printBlockName(block: successor); |
| 3840 | if (succOperands.empty()) |
| 3841 | return; |
| 3842 | |
| 3843 | os << '('; |
| 3844 | interleaveComma(c: succOperands, |
| 3845 | eachFn: [this](Value operand) { printValueID(value: operand); }); |
| 3846 | os << " : " ; |
| 3847 | interleaveComma(c: succOperands, |
| 3848 | eachFn: [this](Value operand) { printType(type: operand.getType()); }); |
| 3849 | os << ')'; |
| 3850 | } |
| 3851 | |
| 3852 | void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, |
| 3853 | bool printBlockTerminators, |
| 3854 | bool printEmptyBlock) { |
| 3855 | if (printerFlags.shouldSkipRegions()) { |
| 3856 | os << "{...}" ; |
| 3857 | return; |
| 3858 | } |
| 3859 | os << "{" << newLine; |
| 3860 | if (!region.empty()) { |
| 3861 | auto restoreDefaultDialect = |
| 3862 | llvm::make_scope_exit(F: [&]() { defaultDialectStack.pop_back(); }); |
| 3863 | if (auto iface = dyn_cast<OpAsmOpInterface>(Val: region.getParentOp())) |
| 3864 | defaultDialectStack.push_back(Elt: iface.getDefaultDialect()); |
| 3865 | else |
| 3866 | defaultDialectStack.push_back(Elt: "" ); |
| 3867 | |
| 3868 | auto *entryBlock = ®ion.front(); |
| 3869 | // Force printing the block header if printEmptyBlock is set and the block |
| 3870 | // is empty or if printEntryBlockArgs is set and there are arguments to |
| 3871 | // print. |
| 3872 | bool = |
| 3873 | (printEmptyBlock && entryBlock->empty()) || |
| 3874 | (printEntryBlockArgs && entryBlock->getNumArguments() != 0); |
| 3875 | print(block: entryBlock, printBlockArgs: shouldAlwaysPrintBlockHeader, printBlockTerminator: printBlockTerminators); |
| 3876 | for (auto &b : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1)) |
| 3877 | print(block: &b); |
| 3878 | } |
| 3879 | os.indent(NumSpaces: currentIndent) << "}" ; |
| 3880 | } |
| 3881 | |
| 3882 | void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
| 3883 | ValueRange operands) { |
| 3884 | if (!mapAttr) { |
| 3885 | os << "<<NULL AFFINE MAP>>" ; |
| 3886 | return; |
| 3887 | } |
| 3888 | AffineMap map = mapAttr.getValue(); |
| 3889 | unsigned numDims = map.getNumDims(); |
| 3890 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
| 3891 | unsigned index = isSymbol ? numDims + pos : pos; |
| 3892 | assert(index < operands.size()); |
| 3893 | if (isSymbol) |
| 3894 | os << "symbol(" ; |
| 3895 | printValueID(value: operands[index]); |
| 3896 | if (isSymbol) |
| 3897 | os << ')'; |
| 3898 | }; |
| 3899 | |
| 3900 | interleaveComma(c: map.getResults(), eachFn: [&](AffineExpr expr) { |
| 3901 | printAffineExpr(expr, printValueName); |
| 3902 | }); |
| 3903 | } |
| 3904 | |
| 3905 | void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, |
| 3906 | ValueRange dimOperands, |
| 3907 | ValueRange symOperands) { |
| 3908 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
| 3909 | if (!isSymbol) |
| 3910 | return printValueID(value: dimOperands[pos]); |
| 3911 | os << "symbol(" ; |
| 3912 | printValueID(value: symOperands[pos]); |
| 3913 | os << ')'; |
| 3914 | }; |
| 3915 | printAffineExpr(expr, printValueName); |
| 3916 | } |
| 3917 | |
| 3918 | //===----------------------------------------------------------------------===// |
| 3919 | // print and dump methods |
| 3920 | //===----------------------------------------------------------------------===// |
| 3921 | |
| 3922 | void Attribute::print(raw_ostream &os, bool elideType) const { |
| 3923 | if (!*this) { |
| 3924 | os << "<<NULL ATTRIBUTE>>" ; |
| 3925 | return; |
| 3926 | } |
| 3927 | |
| 3928 | AsmState state(getContext()); |
| 3929 | print(os, state, elideType); |
| 3930 | } |
| 3931 | void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const { |
| 3932 | using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision; |
| 3933 | AsmPrinter::Impl(os, state.getImpl()) |
| 3934 | .printAttribute(attr: *this, typeElision: elideType ? AttrTypeElision::Must |
| 3935 | : AttrTypeElision::Never); |
| 3936 | } |
| 3937 | |
| 3938 | void Attribute::dump() const { |
| 3939 | print(os&: llvm::errs()); |
| 3940 | llvm::errs() << "\n" ; |
| 3941 | } |
| 3942 | |
| 3943 | void Attribute::printStripped(raw_ostream &os, AsmState &state) const { |
| 3944 | if (!*this) { |
| 3945 | os << "<<NULL ATTRIBUTE>>" ; |
| 3946 | return; |
| 3947 | } |
| 3948 | |
| 3949 | AsmPrinter::Impl subPrinter(os, state.getImpl()); |
| 3950 | if (succeeded(Result: subPrinter.printAlias(attr: *this))) |
| 3951 | return; |
| 3952 | |
| 3953 | auto &dialect = this->getDialect(); |
| 3954 | uint64_t posPrior = os.tell(); |
| 3955 | DialectAsmPrinter printer(subPrinter); |
| 3956 | dialect.printAttribute(*this, printer); |
| 3957 | if (posPrior != os.tell()) |
| 3958 | return; |
| 3959 | |
| 3960 | // Fallback to printing with prefix if the above failed to write anything |
| 3961 | // to the output stream. |
| 3962 | print(os, state); |
| 3963 | } |
| 3964 | void Attribute::printStripped(raw_ostream &os) const { |
| 3965 | if (!*this) { |
| 3966 | os << "<<NULL ATTRIBUTE>>" ; |
| 3967 | return; |
| 3968 | } |
| 3969 | |
| 3970 | AsmState state(getContext()); |
| 3971 | printStripped(os, state); |
| 3972 | } |
| 3973 | |
| 3974 | void Type::print(raw_ostream &os) const { |
| 3975 | if (!*this) { |
| 3976 | os << "<<NULL TYPE>>" ; |
| 3977 | return; |
| 3978 | } |
| 3979 | |
| 3980 | AsmState state(getContext()); |
| 3981 | print(os, state); |
| 3982 | } |
| 3983 | void Type::print(raw_ostream &os, AsmState &state) const { |
| 3984 | AsmPrinter::Impl(os, state.getImpl()).printType(type: *this); |
| 3985 | } |
| 3986 | |
| 3987 | void Type::dump() const { |
| 3988 | print(os&: llvm::errs()); |
| 3989 | llvm::errs() << "\n" ; |
| 3990 | } |
| 3991 | |
| 3992 | void AffineMap::dump() const { |
| 3993 | print(os&: llvm::errs()); |
| 3994 | llvm::errs() << "\n" ; |
| 3995 | } |
| 3996 | |
| 3997 | void IntegerSet::dump() const { |
| 3998 | print(os&: llvm::errs()); |
| 3999 | llvm::errs() << "\n" ; |
| 4000 | } |
| 4001 | |
| 4002 | void AffineExpr::print(raw_ostream &os) const { |
| 4003 | if (!expr) { |
| 4004 | os << "<<NULL AFFINE EXPR>>" ; |
| 4005 | return; |
| 4006 | } |
| 4007 | AsmState state(getContext()); |
| 4008 | AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(expr: *this); |
| 4009 | } |
| 4010 | |
| 4011 | void AffineExpr::dump() const { |
| 4012 | print(os&: llvm::errs()); |
| 4013 | llvm::errs() << "\n" ; |
| 4014 | } |
| 4015 | |
| 4016 | void AffineMap::print(raw_ostream &os) const { |
| 4017 | if (!map) { |
| 4018 | os << "<<NULL AFFINE MAP>>" ; |
| 4019 | return; |
| 4020 | } |
| 4021 | AsmState state(getContext()); |
| 4022 | AsmPrinter::Impl(os, state.getImpl()).printAffineMap(map: *this); |
| 4023 | } |
| 4024 | |
| 4025 | void IntegerSet::print(raw_ostream &os) const { |
| 4026 | AsmState state(getContext()); |
| 4027 | AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(set: *this); |
| 4028 | } |
| 4029 | |
| 4030 | void Value::print(raw_ostream &os) const { print(os, flags: OpPrintingFlags()); } |
| 4031 | void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const { |
| 4032 | if (!impl) { |
| 4033 | os << "<<NULL VALUE>>" ; |
| 4034 | return; |
| 4035 | } |
| 4036 | |
| 4037 | if (auto *op = getDefiningOp()) |
| 4038 | return op->print(os, flags); |
| 4039 | // TODO: Improve BlockArgument print'ing. |
| 4040 | BlockArgument arg = llvm::cast<BlockArgument>(Val: *this); |
| 4041 | os << "<block argument> of type '" << arg.getType() |
| 4042 | << "' at index: " << arg.getArgNumber(); |
| 4043 | } |
| 4044 | void Value::print(raw_ostream &os, AsmState &state) const { |
| 4045 | if (!impl) { |
| 4046 | os << "<<NULL VALUE>>" ; |
| 4047 | return; |
| 4048 | } |
| 4049 | |
| 4050 | if (auto *op = getDefiningOp()) |
| 4051 | return op->print(os, state); |
| 4052 | |
| 4053 | // TODO: Improve BlockArgument print'ing. |
| 4054 | BlockArgument arg = llvm::cast<BlockArgument>(Val: *this); |
| 4055 | os << "<block argument> of type '" << arg.getType() |
| 4056 | << "' at index: " << arg.getArgNumber(); |
| 4057 | } |
| 4058 | |
| 4059 | void Value::dump() const { |
| 4060 | print(os&: llvm::errs()); |
| 4061 | llvm::errs() << "\n" ; |
| 4062 | } |
| 4063 | |
| 4064 | void Value::printAsOperand(raw_ostream &os, AsmState &state) const { |
| 4065 | // TODO: This doesn't necessarily capture all potential cases. |
| 4066 | // Currently, region arguments can be shadowed when printing the main |
| 4067 | // operation. If the IR hasn't been printed, this will produce the old SSA |
| 4068 | // name and not the shadowed name. |
| 4069 | state.getImpl().getSSANameState().printValueID(value: *this, /*printResultNo=*/true, |
| 4070 | stream&: os); |
| 4071 | } |
| 4072 | |
| 4073 | static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
| 4074 | do { |
| 4075 | // If we are printing local scope, stop at the first operation that is |
| 4076 | // isolated from above. |
| 4077 | if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| 4078 | break; |
| 4079 | |
| 4080 | // Otherwise, traverse up to the next parent. |
| 4081 | Operation *parentOp = op->getParentOp(); |
| 4082 | if (!parentOp) |
| 4083 | break; |
| 4084 | op = parentOp; |
| 4085 | } while (true); |
| 4086 | return op; |
| 4087 | } |
| 4088 | |
| 4089 | void Value::printAsOperand(raw_ostream &os, |
| 4090 | const OpPrintingFlags &flags) const { |
| 4091 | Operation *op; |
| 4092 | if (auto result = llvm::dyn_cast<OpResult>(Val: *this)) { |
| 4093 | op = result.getOwner(); |
| 4094 | } else { |
| 4095 | op = llvm::cast<BlockArgument>(Val: *this).getOwner()->getParentOp(); |
| 4096 | if (!op) { |
| 4097 | os << "<<UNKNOWN SSA VALUE>>" ; |
| 4098 | return; |
| 4099 | } |
| 4100 | } |
| 4101 | op = findParent(op, shouldUseLocalScope: flags.shouldUseLocalScope()); |
| 4102 | AsmState state(op, flags); |
| 4103 | printAsOperand(os, state); |
| 4104 | } |
| 4105 | |
| 4106 | void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { |
| 4107 | // Find the operation to number from based upon the provided flags. |
| 4108 | Operation *op = findParent(op: this, shouldUseLocalScope: printerFlags.shouldUseLocalScope()); |
| 4109 | AsmState state(op, printerFlags); |
| 4110 | print(os, state); |
| 4111 | } |
| 4112 | void Operation::print(raw_ostream &os, AsmState &state) { |
| 4113 | OperationPrinter printer(os, state.getImpl()); |
| 4114 | if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { |
| 4115 | state.getImpl().initializeAliases(op: this); |
| 4116 | printer.printTopLevelOperation(op: this); |
| 4117 | } else { |
| 4118 | printer.printFullOpWithIndentAndLoc(op: this); |
| 4119 | } |
| 4120 | } |
| 4121 | |
| 4122 | void Operation::dump() { |
| 4123 | print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope()); |
| 4124 | llvm::errs() << "\n" ; |
| 4125 | } |
| 4126 | |
| 4127 | void Operation::dumpPretty() { |
| 4128 | print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope().assumeVerified()); |
| 4129 | llvm::errs() << "\n" ; |
| 4130 | } |
| 4131 | |
| 4132 | void Block::print(raw_ostream &os) { |
| 4133 | Operation *parentOp = getParentOp(); |
| 4134 | if (!parentOp) { |
| 4135 | os << "<<UNLINKED BLOCK>>\n" ; |
| 4136 | return; |
| 4137 | } |
| 4138 | // Get the top-level op. |
| 4139 | while (auto *nextOp = parentOp->getParentOp()) |
| 4140 | parentOp = nextOp; |
| 4141 | |
| 4142 | AsmState state(parentOp); |
| 4143 | print(os, state); |
| 4144 | } |
| 4145 | void Block::print(raw_ostream &os, AsmState &state) { |
| 4146 | OperationPrinter(os, state.getImpl()).print(block: this); |
| 4147 | } |
| 4148 | |
| 4149 | void Block::dump() { print(os&: llvm::errs()); } |
| 4150 | |
| 4151 | /// Print out the name of the block without printing its body. |
| 4152 | void Block::printAsOperand(raw_ostream &os, bool printType) { |
| 4153 | Operation *parentOp = getParentOp(); |
| 4154 | if (!parentOp) { |
| 4155 | os << "<<UNLINKED BLOCK>>\n" ; |
| 4156 | return; |
| 4157 | } |
| 4158 | AsmState state(parentOp); |
| 4159 | printAsOperand(os, state); |
| 4160 | } |
| 4161 | void Block::printAsOperand(raw_ostream &os, AsmState &state) { |
| 4162 | OperationPrinter printer(os, state.getImpl()); |
| 4163 | printer.printBlockName(block: this); |
| 4164 | } |
| 4165 | |
| 4166 | raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) { |
| 4167 | block.print(os); |
| 4168 | return os; |
| 4169 | } |
| 4170 | |
| 4171 | //===--------------------------------------------------------------------===// |
| 4172 | // Custom printers |
| 4173 | //===--------------------------------------------------------------------===// |
| 4174 | namespace mlir { |
| 4175 | |
| 4176 | void printDimensionList(OpAsmPrinter &printer, Operation *op, |
| 4177 | ArrayRef<int64_t> dimensions) { |
| 4178 | if (dimensions.empty()) |
| 4179 | printer << "[" ; |
| 4180 | printer.printDimensionList(shape: dimensions); |
| 4181 | if (dimensions.empty()) |
| 4182 | printer << "]" ; |
| 4183 | } |
| 4184 | |
| 4185 | ParseResult parseDimensionList(OpAsmParser &parser, |
| 4186 | DenseI64ArrayAttr &dimensions) { |
| 4187 | // Empty list case denoted by "[]". |
| 4188 | if (succeeded(Result: parser.parseOptionalLSquare())) { |
| 4189 | if (failed(Result: parser.parseRSquare())) { |
| 4190 | return parser.emitError(loc: parser.getCurrentLocation()) |
| 4191 | << "Failed parsing dimension list." ; |
| 4192 | } |
| 4193 | dimensions = |
| 4194 | DenseI64ArrayAttr::get(context: parser.getContext(), content: ArrayRef<int64_t>()); |
| 4195 | return success(); |
| 4196 | } |
| 4197 | |
| 4198 | // Non-empty list case. |
| 4199 | SmallVector<int64_t> shapeArr; |
| 4200 | if (failed(Result: parser.parseDimensionList(dimensions&: shapeArr, allowDynamic: true, withTrailingX: false))) { |
| 4201 | return parser.emitError(loc: parser.getCurrentLocation()) |
| 4202 | << "Failed parsing dimension list." ; |
| 4203 | } |
| 4204 | if (shapeArr.empty()) { |
| 4205 | return parser.emitError(loc: parser.getCurrentLocation()) |
| 4206 | << "Failed parsing dimension list. Did you mean an empty list? It " |
| 4207 | "must be denoted by \"[]\"." ; |
| 4208 | } |
| 4209 | dimensions = DenseI64ArrayAttr::get(context: parser.getContext(), content: shapeArr); |
| 4210 | return success(); |
| 4211 | } |
| 4212 | |
| 4213 | } // namespace mlir |
| 4214 | |