| 1 | //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the MLIR AsmPrinter class, which is used to implement |
| 10 | // the various print() methods on the core IR objects. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/IR/AffineExpr.h" |
| 15 | #include "mlir/IR/AffineMap.h" |
| 16 | #include "mlir/IR/AsmState.h" |
| 17 | #include "mlir/IR/Attributes.h" |
| 18 | #include "mlir/IR/Builders.h" |
| 19 | #include "mlir/IR/BuiltinAttributes.h" |
| 20 | #include "mlir/IR/BuiltinDialect.h" |
| 21 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 22 | #include "mlir/IR/BuiltinTypes.h" |
| 23 | #include "mlir/IR/Dialect.h" |
| 24 | #include "mlir/IR/DialectImplementation.h" |
| 25 | #include "mlir/IR/DialectResourceBlobManager.h" |
| 26 | #include "mlir/IR/IntegerSet.h" |
| 27 | #include "mlir/IR/MLIRContext.h" |
| 28 | #include "mlir/IR/OpImplementation.h" |
| 29 | #include "mlir/IR/Operation.h" |
| 30 | #include "mlir/IR/Verifier.h" |
| 31 | #include "llvm/ADT/APFloat.h" |
| 32 | #include "llvm/ADT/ArrayRef.h" |
| 33 | #include "llvm/ADT/DenseMap.h" |
| 34 | #include "llvm/ADT/MapVector.h" |
| 35 | #include "llvm/ADT/STLExtras.h" |
| 36 | #include "llvm/ADT/ScopeExit.h" |
| 37 | #include "llvm/ADT/ScopedHashTable.h" |
| 38 | #include "llvm/ADT/SetVector.h" |
| 39 | #include "llvm/ADT/SmallString.h" |
| 40 | #include "llvm/ADT/StringExtras.h" |
| 41 | #include "llvm/ADT/StringSet.h" |
| 42 | #include "llvm/ADT/TypeSwitch.h" |
| 43 | #include "llvm/Support/CommandLine.h" |
| 44 | #include "llvm/Support/Debug.h" |
| 45 | #include "llvm/Support/Endian.h" |
| 46 | #include "llvm/Support/ManagedStatic.h" |
| 47 | #include "llvm/Support/Regex.h" |
| 48 | #include "llvm/Support/SaveAndRestore.h" |
| 49 | #include "llvm/Support/Threading.h" |
| 50 | #include "llvm/Support/raw_ostream.h" |
| 51 | #include <type_traits> |
| 52 | |
| 53 | #include <optional> |
| 54 | #include <tuple> |
| 55 | |
| 56 | using namespace mlir; |
| 57 | using namespace mlir::detail; |
| 58 | |
| 59 | #define DEBUG_TYPE "mlir-asm-printer" |
| 60 | |
| 61 | void OperationName::print(raw_ostream &os) const { os << getStringRef(); } |
| 62 | |
| 63 | void OperationName::dump() const { print(os&: llvm::errs()); } |
| 64 | |
| 65 | //===--------------------------------------------------------------------===// |
| 66 | // AsmParser |
| 67 | //===--------------------------------------------------------------------===// |
| 68 | |
| 69 | AsmParser::~AsmParser() = default; |
| 70 | DialectAsmParser::~DialectAsmParser() = default; |
| 71 | OpAsmParser::~OpAsmParser() = default; |
| 72 | |
| 73 | MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } |
| 74 | |
| 75 | /// Parse a type list. |
| 76 | /// This is out-of-line to work-around |
| 77 | /// https://github.com/llvm/llvm-project/issues/62918 |
| 78 | ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) { |
| 79 | return parseCommaSeparatedList( |
| 80 | parseElementFn: [&]() { return parseType(result&: result.emplace_back()); }); |
| 81 | } |
| 82 | |
| 83 | //===----------------------------------------------------------------------===// |
| 84 | // DialectAsmPrinter |
| 85 | //===----------------------------------------------------------------------===// |
| 86 | |
| 87 | DialectAsmPrinter::~DialectAsmPrinter() = default; |
| 88 | |
| 89 | //===----------------------------------------------------------------------===// |
| 90 | // OpAsmPrinter |
| 91 | //===----------------------------------------------------------------------===// |
| 92 | |
| 93 | OpAsmPrinter::~OpAsmPrinter() = default; |
| 94 | |
| 95 | void OpAsmPrinter::printFunctionalType(Operation *op) { |
| 96 | auto &os = getStream(); |
| 97 | os << '('; |
| 98 | llvm::interleaveComma(c: op->getOperands(), os, each_fn: [&](Value operand) { |
| 99 | // Print the types of null values as <<NULL TYPE>>. |
| 100 | *this << (operand ? operand.getType() : Type()); |
| 101 | }); |
| 102 | os << ") -> " ; |
| 103 | |
| 104 | // Print the result list. We don't parenthesize single result types unless |
| 105 | // it is a function (avoiding a grammar ambiguity). |
| 106 | bool wrapped = op->getNumResults() != 1; |
| 107 | if (!wrapped && op->getResult(idx: 0).getType() && |
| 108 | llvm::isa<FunctionType>(Val: op->getResult(idx: 0).getType())) |
| 109 | wrapped = true; |
| 110 | |
| 111 | if (wrapped) |
| 112 | os << '('; |
| 113 | |
| 114 | llvm::interleaveComma(c: op->getResults(), os, each_fn: [&](const OpResult &result) { |
| 115 | // Print the types of null values as <<NULL TYPE>>. |
| 116 | *this << (result ? result.getType() : Type()); |
| 117 | }); |
| 118 | |
| 119 | if (wrapped) |
| 120 | os << ')'; |
| 121 | } |
| 122 | |
| 123 | //===----------------------------------------------------------------------===// |
| 124 | // Operation OpAsm interface. |
| 125 | //===----------------------------------------------------------------------===// |
| 126 | |
| 127 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
| 128 | #include "mlir/IR/OpAsmAttrInterface.cpp.inc" |
| 129 | #include "mlir/IR/OpAsmOpInterface.cpp.inc" |
| 130 | #include "mlir/IR/OpAsmTypeInterface.cpp.inc" |
| 131 | |
| 132 | LogicalResult |
| 133 | OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { |
| 134 | return entry.emitError() << "unknown 'resource' key '" << entry.getKey() |
| 135 | << "' for dialect '" << getDialect()->getNamespace() |
| 136 | << "'" ; |
| 137 | } |
| 138 | |
| 139 | //===----------------------------------------------------------------------===// |
| 140 | // OpPrintingFlags |
| 141 | //===----------------------------------------------------------------------===// |
| 142 | |
| 143 | namespace { |
| 144 | /// This struct contains command line options that can be used to initialize |
| 145 | /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need |
| 146 | /// for global command line options. |
| 147 | struct AsmPrinterOptions { |
| 148 | llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ |
| 149 | "mlir-print-elementsattrs-with-hex-if-larger" , |
| 150 | llvm::cl::desc( |
| 151 | "Print DenseElementsAttrs with a hex string that have " |
| 152 | "more elements than the given upper limit (use -1 to disable)" )}; |
| 153 | |
| 154 | llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ |
| 155 | "mlir-elide-elementsattrs-if-larger" , |
| 156 | llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " |
| 157 | "more elements than the given upper limit" )}; |
| 158 | |
| 159 | llvm::cl::opt<unsigned> elideResourceStringsIfLarger{ |
| 160 | "mlir-elide-resource-strings-if-larger" , |
| 161 | llvm::cl::desc( |
| 162 | "Elide printing value of resources if string is too long in chars." )}; |
| 163 | |
| 164 | llvm::cl::opt<bool> printDebugInfoOpt{ |
| 165 | "mlir-print-debuginfo" , llvm::cl::init(Val: false), |
| 166 | llvm::cl::desc("Print debug info in MLIR output" )}; |
| 167 | |
| 168 | llvm::cl::opt<bool> printPrettyDebugInfoOpt{ |
| 169 | "mlir-pretty-debuginfo" , llvm::cl::init(Val: false), |
| 170 | llvm::cl::desc("Print pretty debug info in MLIR output" )}; |
| 171 | |
| 172 | // Use the generic op output form in the operation printer even if the custom |
| 173 | // form is defined. |
| 174 | llvm::cl::opt<bool> printGenericOpFormOpt{ |
| 175 | "mlir-print-op-generic" , llvm::cl::init(Val: false), |
| 176 | llvm::cl::desc("Print the generic op form" ), llvm::cl::Hidden}; |
| 177 | |
| 178 | llvm::cl::opt<bool> assumeVerifiedOpt{ |
| 179 | "mlir-print-assume-verified" , llvm::cl::init(Val: false), |
| 180 | llvm::cl::desc("Skip op verification when using custom printers" ), |
| 181 | llvm::cl::Hidden}; |
| 182 | |
| 183 | llvm::cl::opt<bool> printLocalScopeOpt{ |
| 184 | "mlir-print-local-scope" , llvm::cl::init(Val: false), |
| 185 | llvm::cl::desc("Print with local scope and inline information (eliding " |
| 186 | "aliases for attributes, types, and locations)" )}; |
| 187 | |
| 188 | llvm::cl::opt<bool> skipRegionsOpt{ |
| 189 | "mlir-print-skip-regions" , llvm::cl::init(Val: false), |
| 190 | llvm::cl::desc("Skip regions when printing ops." )}; |
| 191 | |
| 192 | llvm::cl::opt<bool> printValueUsers{ |
| 193 | "mlir-print-value-users" , llvm::cl::init(Val: false), |
| 194 | llvm::cl::desc( |
| 195 | "Print users of operation results and block arguments as a comment" )}; |
| 196 | |
| 197 | llvm::cl::opt<bool> printUniqueSSAIDs{ |
| 198 | "mlir-print-unique-ssa-ids" , llvm::cl::init(Val: false), |
| 199 | llvm::cl::desc("Print unique SSA ID numbers for values, block arguments " |
| 200 | "and naming conflicts across all regions" )}; |
| 201 | |
| 202 | llvm::cl::opt<bool> useNameLocAsPrefix{ |
| 203 | "mlir-use-nameloc-as-prefix" , llvm::cl::init(Val: false), |
| 204 | llvm::cl::desc("Print SSA IDs using NameLocs as prefixes" )}; |
| 205 | }; |
| 206 | } // namespace |
| 207 | |
| 208 | static llvm::ManagedStatic<AsmPrinterOptions> clOptions; |
| 209 | |
| 210 | /// Register a set of useful command-line options that can be used to configure |
| 211 | /// various flags within the AsmPrinter. |
| 212 | void mlir::registerAsmPrinterCLOptions() { |
| 213 | // Make sure that the options struct has been initialized. |
| 214 | *clOptions; |
| 215 | } |
| 216 | |
| 217 | /// Initialize the printing flags with default supplied by the cl::opts above. |
| 218 | OpPrintingFlags::OpPrintingFlags() |
| 219 | : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), |
| 220 | printGenericOpFormFlag(false), skipRegionsFlag(false), |
| 221 | assumeVerifiedFlag(false), printLocalScope(false), |
| 222 | printValueUsersFlag(false), printUniqueSSAIDsFlag(false), |
| 223 | useNameLocAsPrefix(false) { |
| 224 | // Initialize based upon command line options, if they are available. |
| 225 | if (!clOptions.isConstructed()) |
| 226 | return; |
| 227 | if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) |
| 228 | elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; |
| 229 | if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) |
| 230 | elementsAttrHexElementLimit = |
| 231 | clOptions->printElementsAttrWithHexIfLarger.getValue(); |
| 232 | if (clOptions->elideResourceStringsIfLarger.getNumOccurrences()) |
| 233 | resourceStringCharLimit = clOptions->elideResourceStringsIfLarger; |
| 234 | printDebugInfoFlag = clOptions->printDebugInfoOpt; |
| 235 | printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; |
| 236 | printGenericOpFormFlag = clOptions->printGenericOpFormOpt; |
| 237 | assumeVerifiedFlag = clOptions->assumeVerifiedOpt; |
| 238 | printLocalScope = clOptions->printLocalScopeOpt; |
| 239 | skipRegionsFlag = clOptions->skipRegionsOpt; |
| 240 | printValueUsersFlag = clOptions->printValueUsers; |
| 241 | printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; |
| 242 | useNameLocAsPrefix = clOptions->useNameLocAsPrefix; |
| 243 | } |
| 244 | |
| 245 | /// Enable the elision of large elements attributes, by printing a '...' |
| 246 | /// instead of the element data, when the number of elements is greater than |
| 247 | /// `largeElementLimit`. Note: The IR generated with this option is not |
| 248 | /// parsable. |
| 249 | OpPrintingFlags & |
| 250 | OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { |
| 251 | elementsAttrElementLimit = largeElementLimit; |
| 252 | return *this; |
| 253 | } |
| 254 | |
| 255 | OpPrintingFlags & |
| 256 | OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) { |
| 257 | elementsAttrHexElementLimit = largeElementLimit; |
| 258 | return *this; |
| 259 | } |
| 260 | |
| 261 | OpPrintingFlags & |
| 262 | OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) { |
| 263 | resourceStringCharLimit = largeResourceLimit; |
| 264 | return *this; |
| 265 | } |
| 266 | |
| 267 | /// Enable printing of debug information. If 'prettyForm' is set to true, |
| 268 | /// debug information is printed in a more readable 'pretty' form. |
| 269 | OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable, |
| 270 | bool prettyForm) { |
| 271 | printDebugInfoFlag = enable; |
| 272 | printDebugInfoPrettyFormFlag = prettyForm; |
| 273 | return *this; |
| 274 | } |
| 275 | |
| 276 | /// Always print operations in the generic form. |
| 277 | OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { |
| 278 | printGenericOpFormFlag = enable; |
| 279 | return *this; |
| 280 | } |
| 281 | |
| 282 | /// Always skip Regions. |
| 283 | OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { |
| 284 | skipRegionsFlag = skip; |
| 285 | return *this; |
| 286 | } |
| 287 | |
| 288 | /// Do not verify the operation when using custom operation printers. |
| 289 | OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) { |
| 290 | assumeVerifiedFlag = enable; |
| 291 | return *this; |
| 292 | } |
| 293 | |
| 294 | /// Use local scope when printing the operation. This allows for using the |
| 295 | /// printer in a more localized and thread-safe setting, but may not necessarily |
| 296 | /// be identical of what the IR will look like when dumping the full module. |
| 297 | OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) { |
| 298 | printLocalScope = enable; |
| 299 | return *this; |
| 300 | } |
| 301 | |
| 302 | /// Print users of values as comments. |
| 303 | OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) { |
| 304 | printValueUsersFlag = enable; |
| 305 | return *this; |
| 306 | } |
| 307 | |
| 308 | /// Print unique SSA ID numbers for values, block arguments and naming conflicts |
| 309 | /// across all regions |
| 310 | OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) { |
| 311 | printUniqueSSAIDsFlag = enable; |
| 312 | return *this; |
| 313 | } |
| 314 | |
| 315 | /// Return if the given ElementsAttr should be elided. |
| 316 | bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { |
| 317 | return elementsAttrElementLimit && |
| 318 | *elementsAttrElementLimit < int64_t(attr.getNumElements()) && |
| 319 | !llvm::isa<SplatElementsAttr>(attr); |
| 320 | } |
| 321 | |
| 322 | /// Return if the given ElementsAttr should be printed as hex string. |
| 323 | bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const { |
| 324 | // -1 is used to disable hex printing. |
| 325 | return (elementsAttrHexElementLimit != -1) && |
| 326 | (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) && |
| 327 | !llvm::isa<SplatElementsAttr>(attr); |
| 328 | } |
| 329 | |
| 330 | OpPrintingFlags &OpPrintingFlags::printNameLocAsPrefix(bool enable) { |
| 331 | useNameLocAsPrefix = enable; |
| 332 | return *this; |
| 333 | } |
| 334 | |
| 335 | /// Return the size limit for printing large ElementsAttr. |
| 336 | std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { |
| 337 | return elementsAttrElementLimit; |
| 338 | } |
| 339 | |
| 340 | /// Return the size limit for printing large ElementsAttr as hex string. |
| 341 | int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const { |
| 342 | return elementsAttrHexElementLimit; |
| 343 | } |
| 344 | |
| 345 | /// Return the size limit for printing large ElementsAttr. |
| 346 | std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const { |
| 347 | return resourceStringCharLimit; |
| 348 | } |
| 349 | |
| 350 | /// Return if debug information should be printed. |
| 351 | bool OpPrintingFlags::shouldPrintDebugInfo() const { |
| 352 | return printDebugInfoFlag; |
| 353 | } |
| 354 | |
| 355 | /// Return if debug information should be printed in the pretty form. |
| 356 | bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { |
| 357 | return printDebugInfoPrettyFormFlag; |
| 358 | } |
| 359 | |
| 360 | /// Return if operations should be printed in the generic form. |
| 361 | bool OpPrintingFlags::shouldPrintGenericOpForm() const { |
| 362 | return printGenericOpFormFlag; |
| 363 | } |
| 364 | |
| 365 | /// Return if Region should be skipped. |
| 366 | bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; } |
| 367 | |
| 368 | /// Return if operation verification should be skipped. |
| 369 | bool OpPrintingFlags::shouldAssumeVerified() const { |
| 370 | return assumeVerifiedFlag; |
| 371 | } |
| 372 | |
| 373 | /// Return if the printer should use local scope when dumping the IR. |
| 374 | bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } |
| 375 | |
| 376 | /// Return if the printer should print users of values. |
| 377 | bool OpPrintingFlags::shouldPrintValueUsers() const { |
| 378 | return printValueUsersFlag; |
| 379 | } |
| 380 | |
| 381 | /// Return if the printer should use unique IDs. |
| 382 | bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { |
| 383 | return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); |
| 384 | } |
| 385 | |
| 386 | /// Return if the printer should use NameLocs as prefixes when printing SSA IDs. |
| 387 | bool OpPrintingFlags::shouldUseNameLocAsPrefix() const { |
| 388 | return useNameLocAsPrefix; |
| 389 | } |
| 390 | |
| 391 | //===----------------------------------------------------------------------===// |
| 392 | // NewLineCounter |
| 393 | //===----------------------------------------------------------------------===// |
| 394 | |
| 395 | namespace { |
| 396 | /// This class is a simple formatter that emits a new line when inputted into a |
| 397 | /// stream, that enables counting the number of newlines emitted. This class |
| 398 | /// should be used whenever emitting newlines in the printer. |
| 399 | struct NewLineCounter { |
| 400 | unsigned curLine = 1; |
| 401 | }; |
| 402 | |
| 403 | static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { |
| 404 | ++newLine.curLine; |
| 405 | return os << '\n'; |
| 406 | } |
| 407 | } // namespace |
| 408 | |
| 409 | //===----------------------------------------------------------------------===// |
| 410 | // AsmPrinter::Impl |
| 411 | //===----------------------------------------------------------------------===// |
| 412 | |
| 413 | namespace mlir { |
| 414 | class AsmPrinter::Impl { |
| 415 | public: |
| 416 | Impl(raw_ostream &os, AsmStateImpl &state); |
| 417 | explicit Impl(Impl &other) : Impl(other.os, other.state) {} |
| 418 | |
| 419 | /// Returns the output stream of the printer. |
| 420 | raw_ostream &getStream() { return os; } |
| 421 | |
| 422 | template <typename Container, typename UnaryFunctor> |
| 423 | inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { |
| 424 | llvm::interleaveComma(c, os, eachFn); |
| 425 | } |
| 426 | |
| 427 | /// This enum describes the different kinds of elision for the type of an |
| 428 | /// attribute when printing it. |
| 429 | enum class AttrTypeElision { |
| 430 | /// The type must not be elided, |
| 431 | Never, |
| 432 | /// The type may be elided when it matches the default used in the parser |
| 433 | /// (for example i64 is the default for integer attributes). |
| 434 | May, |
| 435 | /// The type must be elided. |
| 436 | Must |
| 437 | }; |
| 438 | |
| 439 | /// Print the given attribute or an alias. |
| 440 | void printAttribute(Attribute attr, |
| 441 | AttrTypeElision typeElision = AttrTypeElision::Never); |
| 442 | /// Print the given attribute without considering an alias. |
| 443 | void printAttributeImpl(Attribute attr, |
| 444 | AttrTypeElision typeElision = AttrTypeElision::Never); |
| 445 | |
| 446 | /// Print the alias for the given attribute, return failure if no alias could |
| 447 | /// be printed. |
| 448 | LogicalResult printAlias(Attribute attr); |
| 449 | |
| 450 | /// Print the given type or an alias. |
| 451 | void printType(Type type); |
| 452 | /// Print the given type. |
| 453 | void printTypeImpl(Type type); |
| 454 | |
| 455 | /// Print the alias for the given type, return failure if no alias could |
| 456 | /// be printed. |
| 457 | LogicalResult printAlias(Type type); |
| 458 | |
| 459 | /// Print the given location to the stream. If `allowAlias` is true, this |
| 460 | /// allows for the internal location to use an attribute alias. |
| 461 | void printLocation(LocationAttr loc, bool allowAlias = false); |
| 462 | |
| 463 | /// Print a reference to the given resource that is owned by the given |
| 464 | /// dialect. |
| 465 | void printResourceHandle(const AsmDialectResourceHandle &resource); |
| 466 | |
| 467 | void printAffineMap(AffineMap map); |
| 468 | void |
| 469 | printAffineExpr(AffineExpr expr, |
| 470 | function_ref<void(unsigned, bool)> printValueName = nullptr); |
| 471 | void printAffineConstraint(AffineExpr expr, bool isEq); |
| 472 | void printIntegerSet(IntegerSet set); |
| 473 | |
| 474 | LogicalResult pushCyclicPrinting(const void *opaquePointer); |
| 475 | |
| 476 | void popCyclicPrinting(); |
| 477 | |
| 478 | void printDimensionList(ArrayRef<int64_t> shape); |
| 479 | |
| 480 | protected: |
| 481 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 482 | ArrayRef<StringRef> elidedAttrs = {}, |
| 483 | bool withKeyword = false); |
| 484 | void printNamedAttribute(NamedAttribute attr); |
| 485 | void printTrailingLocation(Location loc, bool allowAlias = true); |
| 486 | void printLocationInternal(LocationAttr loc, bool pretty = false, |
| 487 | bool isTopLevel = false); |
| 488 | |
| 489 | /// Print a dense elements attribute. If 'allowHex' is true, a hex string is |
| 490 | /// used instead of individual elements when the elements attr is large. |
| 491 | void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); |
| 492 | |
| 493 | /// Print a dense string elements attribute. |
| 494 | void printDenseStringElementsAttr(DenseStringElementsAttr attr); |
| 495 | |
| 496 | /// Print a dense elements attribute. If 'allowHex' is true, a hex string is |
| 497 | /// used instead of individual elements when the elements attr is large. |
| 498 | void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, |
| 499 | bool allowHex); |
| 500 | |
| 501 | /// Print a dense array attribute. |
| 502 | void printDenseArrayAttr(DenseArrayAttr attr); |
| 503 | |
| 504 | void printDialectAttribute(Attribute attr); |
| 505 | void printDialectType(Type type); |
| 506 | |
| 507 | /// Print an escaped string, wrapped with "". |
| 508 | void printEscapedString(StringRef str); |
| 509 | |
| 510 | /// Print a hex string, wrapped with "". |
| 511 | void printHexString(StringRef str); |
| 512 | void printHexString(ArrayRef<char> data); |
| 513 | |
| 514 | /// This enum is used to represent the binding strength of the enclosing |
| 515 | /// context that an AffineExprStorage is being printed in, so we can |
| 516 | /// intelligently produce parens. |
| 517 | enum class BindingStrength { |
| 518 | Weak, // + and - |
| 519 | Strong, // All other binary operators. |
| 520 | }; |
| 521 | void printAffineExprInternal( |
| 522 | AffineExpr expr, BindingStrength enclosingTightness, |
| 523 | function_ref<void(unsigned, bool)> printValueName = nullptr); |
| 524 | |
| 525 | /// The output stream for the printer. |
| 526 | raw_ostream &os; |
| 527 | |
| 528 | /// An underlying assembly printer state. |
| 529 | AsmStateImpl &state; |
| 530 | |
| 531 | /// A set of flags to control the printer's behavior. |
| 532 | OpPrintingFlags printerFlags; |
| 533 | |
| 534 | /// A tracker for the number of new lines emitted during printing. |
| 535 | NewLineCounter newLine; |
| 536 | }; |
| 537 | } // namespace mlir |
| 538 | |
| 539 | //===----------------------------------------------------------------------===// |
| 540 | // AliasInitializer |
| 541 | //===----------------------------------------------------------------------===// |
| 542 | |
| 543 | namespace { |
| 544 | /// This class represents a specific instance of a symbol Alias. |
| 545 | class SymbolAlias { |
| 546 | public: |
| 547 | SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, |
| 548 | bool isDeferrable) |
| 549 | : name(name), suffixIndex(suffixIndex), isType(isType), |
| 550 | isDeferrable(isDeferrable) {} |
| 551 | |
| 552 | /// Print this alias to the given stream. |
| 553 | void print(raw_ostream &os) const { |
| 554 | os << (isType ? "!" : "#" ) << name; |
| 555 | if (suffixIndex) { |
| 556 | if (isdigit(name.back())) |
| 557 | os << '_'; |
| 558 | os << suffixIndex; |
| 559 | } |
| 560 | } |
| 561 | |
| 562 | /// Returns true if this is a type alias. |
| 563 | bool isTypeAlias() const { return isType; } |
| 564 | |
| 565 | /// Returns true if this alias supports deferred resolution when parsing. |
| 566 | bool canBeDeferred() const { return isDeferrable; } |
| 567 | |
| 568 | private: |
| 569 | /// The main name of the alias. |
| 570 | StringRef name; |
| 571 | /// The suffix index of the alias. |
| 572 | uint32_t suffixIndex : 30; |
| 573 | /// A flag indicating whether this alias is for a type. |
| 574 | bool isType : 1; |
| 575 | /// A flag indicating whether this alias may be deferred or not. |
| 576 | bool isDeferrable : 1; |
| 577 | |
| 578 | public: |
| 579 | /// Used to avoid printing incomplete aliases for recursive types. |
| 580 | bool isPrinted = false; |
| 581 | }; |
| 582 | |
| 583 | /// This class represents a utility that initializes the set of attribute and |
| 584 | /// type aliases, without the need to store the extra information within the |
| 585 | /// main AliasState class or pass it around via function arguments. |
| 586 | class AliasInitializer { |
| 587 | public: |
| 588 | AliasInitializer( |
| 589 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, |
| 590 | llvm::BumpPtrAllocator &aliasAllocator) |
| 591 | : interfaces(interfaces), aliasAllocator(aliasAllocator), |
| 592 | aliasOS(aliasBuffer) {} |
| 593 | |
| 594 | void initialize(Operation *op, const OpPrintingFlags &printerFlags, |
| 595 | llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias); |
| 596 | |
| 597 | /// Visit the given attribute to see if it has an alias. `canBeDeferred` is |
| 598 | /// set to true if the originator of this attribute can resolve the alias |
| 599 | /// after parsing has completed (e.g. in the case of operation locations). |
| 600 | /// `elideType` indicates if the type of the attribute should be skipped when |
| 601 | /// looking for nested aliases. Returns the maximum alias depth of the |
| 602 | /// attribute, and the alias index of this attribute. |
| 603 | std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false, |
| 604 | bool elideType = false) { |
| 605 | return visitImpl(value: attr, aliases, canBeDeferred, printArgs&: elideType); |
| 606 | } |
| 607 | |
| 608 | /// Visit the given type to see if it has an alias. `canBeDeferred` is |
| 609 | /// set to true if the originator of this attribute can resolve the alias |
| 610 | /// after parsing has completed. Returns the maximum alias depth of the type, |
| 611 | /// and the alias index of this type. |
| 612 | std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) { |
| 613 | return visitImpl(value: type, aliases, canBeDeferred); |
| 614 | } |
| 615 | |
| 616 | private: |
| 617 | struct InProgressAliasInfo { |
| 618 | InProgressAliasInfo() |
| 619 | : aliasDepth(0), isType(false), canBeDeferred(false) {} |
| 620 | InProgressAliasInfo(StringRef alias) |
| 621 | : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {} |
| 622 | |
| 623 | bool operator<(const InProgressAliasInfo &rhs) const { |
| 624 | // Order first by depth, then by attr/type kind, and then by name. |
| 625 | if (aliasDepth != rhs.aliasDepth) |
| 626 | return aliasDepth < rhs.aliasDepth; |
| 627 | if (isType != rhs.isType) |
| 628 | return isType; |
| 629 | return alias < rhs.alias; |
| 630 | } |
| 631 | |
| 632 | /// The alias for the attribute or type, or std::nullopt if the value has no |
| 633 | /// alias. |
| 634 | std::optional<StringRef> alias; |
| 635 | /// The alias depth of this attribute or type, i.e. an indication of the |
| 636 | /// relative ordering of when to print this alias. |
| 637 | unsigned aliasDepth : 30; |
| 638 | /// If this alias represents a type or an attribute. |
| 639 | bool isType : 1; |
| 640 | /// If this alias can be deferred or not. |
| 641 | bool canBeDeferred : 1; |
| 642 | /// Indices for child aliases. |
| 643 | SmallVector<size_t> childIndices; |
| 644 | }; |
| 645 | |
| 646 | /// Visit the given attribute or type to see if it has an alias. |
| 647 | /// `canBeDeferred` is set to true if the originator of this value can resolve |
| 648 | /// the alias after parsing has completed (e.g. in the case of operation |
| 649 | /// locations). Returns the maximum alias depth of the value, and its alias |
| 650 | /// index. |
| 651 | template <typename T, typename... PrintArgs> |
| 652 | std::pair<size_t, size_t> |
| 653 | visitImpl(T value, |
| 654 | llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
| 655 | bool canBeDeferred, PrintArgs &&...printArgs); |
| 656 | |
| 657 | /// Mark the given alias as non-deferrable. |
| 658 | void markAliasNonDeferrable(size_t aliasIndex); |
| 659 | |
| 660 | /// Try to generate an alias for the provided symbol. If an alias is |
| 661 | /// generated, the provided alias mapping and reverse mapping are updated. |
| 662 | template <typename T> |
| 663 | void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); |
| 664 | |
| 665 | /// Uniques the given alias name within the printer by generating name index |
| 666 | /// used as alias name suffix. |
| 667 | static unsigned |
| 668 | uniqueAliasNameIndex(StringRef alias, llvm::StringMap<unsigned> &nameCounts, |
| 669 | llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases); |
| 670 | |
| 671 | /// Given a collection of aliases and symbols, initialize a mapping from a |
| 672 | /// symbol to a given alias. |
| 673 | static void initializeAliases( |
| 674 | llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
| 675 | llvm::MapVector<const void *, SymbolAlias> &symbolToAlias); |
| 676 | |
| 677 | /// The set of asm interfaces within the context. |
| 678 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; |
| 679 | |
| 680 | /// An allocator used for alias names. |
| 681 | llvm::BumpPtrAllocator &aliasAllocator; |
| 682 | |
| 683 | /// The set of built aliases. |
| 684 | llvm::MapVector<const void *, InProgressAliasInfo> aliases; |
| 685 | |
| 686 | /// Storage and stream used when generating an alias. |
| 687 | SmallString<32> aliasBuffer; |
| 688 | llvm::raw_svector_ostream aliasOS; |
| 689 | }; |
| 690 | |
| 691 | /// This class implements a dummy OpAsmPrinter that doesn't print any output, |
| 692 | /// and merely collects the attributes and types that *would* be printed in a |
| 693 | /// normal print invocation so that we can generate proper aliases. This allows |
| 694 | /// for us to generate aliases only for the attributes and types that would be |
| 695 | /// in the output, and trims down unnecessary output. |
| 696 | class DummyAliasOperationPrinter : private OpAsmPrinter { |
| 697 | public: |
| 698 | explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, |
| 699 | AliasInitializer &initializer) |
| 700 | : printerFlags(printerFlags), initializer(initializer) {} |
| 701 | |
| 702 | /// Prints the entire operation with the custom assembly form, if available, |
| 703 | /// or the generic assembly form, otherwise. |
| 704 | void printCustomOrGenericOp(Operation *op) override { |
| 705 | // Visit the operation location. |
| 706 | if (printerFlags.shouldPrintDebugInfo()) |
| 707 | initializer.visit(attr: op->getLoc(), /*canBeDeferred=*/true); |
| 708 | |
| 709 | // If requested, always print the generic form. |
| 710 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 711 | op->getName().printAssembly(op, p&: *this, /*defaultDialect=*/"" ); |
| 712 | return; |
| 713 | } |
| 714 | |
| 715 | // Otherwise print with the generic assembly form. |
| 716 | printGenericOp(op); |
| 717 | } |
| 718 | |
| 719 | private: |
| 720 | /// Print the given operation in the generic form. |
| 721 | void printGenericOp(Operation *op, bool printOpName = true) override { |
| 722 | // Consider nested operations for aliases. |
| 723 | if (!printerFlags.shouldSkipRegions()) { |
| 724 | for (Region ®ion : op->getRegions()) |
| 725 | printRegion(region, /*printEntryBlockArgs=*/true, |
| 726 | /*printBlockTerminators=*/true); |
| 727 | } |
| 728 | |
| 729 | // Visit all the types used in the operation. |
| 730 | for (Type type : op->getOperandTypes()) |
| 731 | printType(type); |
| 732 | for (Type type : op->getResultTypes()) |
| 733 | printType(type); |
| 734 | |
| 735 | // Consider the attributes of the operation for aliases. |
| 736 | for (const NamedAttribute &attr : op->getAttrs()) |
| 737 | printAttribute(attr: attr.getValue()); |
| 738 | } |
| 739 | |
| 740 | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
| 741 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
| 742 | /// operation of the block is not printed. |
| 743 | void print(Block *block, bool printBlockArgs = true, |
| 744 | bool printBlockTerminator = true) { |
| 745 | // Consider the types of the block arguments for aliases if 'printBlockArgs' |
| 746 | // is set to true. |
| 747 | if (printBlockArgs) { |
| 748 | for (BlockArgument arg : block->getArguments()) { |
| 749 | printType(type: arg.getType()); |
| 750 | |
| 751 | // Visit the argument location. |
| 752 | if (printerFlags.shouldPrintDebugInfo()) |
| 753 | // TODO: Allow deferring argument locations. |
| 754 | initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false); |
| 755 | } |
| 756 | } |
| 757 | |
| 758 | // Consider the operations within this block, ignoring the terminator if |
| 759 | // requested. |
| 760 | bool hasTerminator = |
| 761 | !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
| 762 | auto range = llvm::make_range( |
| 763 | x: block->begin(), |
| 764 | y: std::prev(x: block->end(), |
| 765 | n: (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
| 766 | for (Operation &op : range) |
| 767 | printCustomOrGenericOp(op: &op); |
| 768 | } |
| 769 | |
| 770 | /// Print the given region. |
| 771 | void printRegion(Region ®ion, bool printEntryBlockArgs, |
| 772 | bool printBlockTerminators, |
| 773 | bool printEmptyBlock = false) override { |
| 774 | if (region.empty()) |
| 775 | return; |
| 776 | if (printerFlags.shouldSkipRegions()) { |
| 777 | os << "{...}" ; |
| 778 | return; |
| 779 | } |
| 780 | |
| 781 | auto *entryBlock = ®ion.front(); |
| 782 | print(block: entryBlock, printBlockArgs: printEntryBlockArgs, printBlockTerminator: printBlockTerminators); |
| 783 | for (Block &b : llvm::drop_begin(RangeOrContainer&: region, N: 1)) |
| 784 | print(block: &b); |
| 785 | } |
| 786 | |
| 787 | void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, |
| 788 | bool omitType) override { |
| 789 | printType(type: arg.getType()); |
| 790 | // Visit the argument location. |
| 791 | if (printerFlags.shouldPrintDebugInfo()) |
| 792 | // TODO: Allow deferring argument locations. |
| 793 | initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false); |
| 794 | } |
| 795 | |
| 796 | /// Consider the given type to be printed for an alias. |
| 797 | void printType(Type type) override { initializer.visit(type); } |
| 798 | |
| 799 | /// Consider the given attribute to be printed for an alias. |
| 800 | void printAttribute(Attribute attr) override { initializer.visit(attr); } |
| 801 | void printAttributeWithoutType(Attribute attr) override { |
| 802 | printAttribute(attr); |
| 803 | } |
| 804 | LogicalResult printAlias(Attribute attr) override { |
| 805 | initializer.visit(attr); |
| 806 | return success(); |
| 807 | } |
| 808 | LogicalResult printAlias(Type type) override { |
| 809 | initializer.visit(type); |
| 810 | return success(); |
| 811 | } |
| 812 | |
| 813 | /// Consider the given location to be printed for an alias. |
| 814 | void printOptionalLocationSpecifier(Location loc) override { |
| 815 | printAttribute(attr: loc); |
| 816 | } |
| 817 | |
| 818 | /// Print the given set of attributes with names not included within |
| 819 | /// 'elidedAttrs'. |
| 820 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 821 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 822 | if (attrs.empty()) |
| 823 | return; |
| 824 | if (elidedAttrs.empty()) { |
| 825 | for (const NamedAttribute &attr : attrs) |
| 826 | printAttribute(attr: attr.getValue()); |
| 827 | return; |
| 828 | } |
| 829 | llvm::SmallDenseSet<StringRef> (elidedAttrs.begin(), |
| 830 | elidedAttrs.end()); |
| 831 | for (const NamedAttribute &attr : attrs) |
| 832 | if (!elidedAttrsSet.contains(V: attr.getName().strref())) |
| 833 | printAttribute(attr: attr.getValue()); |
| 834 | } |
| 835 | void printOptionalAttrDictWithKeyword( |
| 836 | ArrayRef<NamedAttribute> attrs, |
| 837 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 838 | printOptionalAttrDict(attrs, elidedAttrs); |
| 839 | } |
| 840 | |
| 841 | /// Return a null stream as the output stream, this will ignore any data fed |
| 842 | /// to it. |
| 843 | raw_ostream &getStream() const override { return os; } |
| 844 | |
| 845 | /// The following are hooks of `OpAsmPrinter` that are not necessary for |
| 846 | /// determining potential aliases. |
| 847 | void printFloat(const APFloat &) override {} |
| 848 | void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} |
| 849 | void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} |
| 850 | void printNewline() override {} |
| 851 | void increaseIndent() override {} |
| 852 | void decreaseIndent() override {} |
| 853 | void printOperand(Value) override {} |
| 854 | void printOperand(Value, raw_ostream &os) override { |
| 855 | // Users expect the output string to have at least the prefixed % to signal |
| 856 | // a value name. To maintain this invariant, emit a name even if it is |
| 857 | // guaranteed to go unused. |
| 858 | os << "%" ; |
| 859 | } |
| 860 | void printKeywordOrString(StringRef) override {} |
| 861 | void printString(StringRef) override {} |
| 862 | void printResourceHandle(const AsmDialectResourceHandle &) override {} |
| 863 | void printSymbolName(StringRef) override {} |
| 864 | void printSuccessor(Block *) override {} |
| 865 | void printSuccessorAndUseList(Block *, ValueRange) override {} |
| 866 | void shadowRegionArgs(Region &, ValueRange) override {} |
| 867 | |
| 868 | /// The printer flags to use when determining potential aliases. |
| 869 | const OpPrintingFlags &printerFlags; |
| 870 | |
| 871 | /// The initializer to use when identifying aliases. |
| 872 | AliasInitializer &initializer; |
| 873 | |
| 874 | /// A dummy output stream. |
| 875 | mutable llvm::raw_null_ostream os; |
| 876 | }; |
| 877 | |
| 878 | class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { |
| 879 | public: |
| 880 | explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer, |
| 881 | bool canBeDeferred, |
| 882 | SmallVectorImpl<size_t> &childIndices) |
| 883 | : initializer(initializer), canBeDeferred(canBeDeferred), |
| 884 | childIndices(childIndices) {} |
| 885 | |
| 886 | /// Print the given attribute/type, visiting any nested aliases that would be |
| 887 | /// generated as part of printing. Returns the maximum alias depth found while |
| 888 | /// printing the given value. |
| 889 | template <typename T, typename... PrintArgs> |
| 890 | size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) { |
| 891 | printAndVisitNestedAliasesImpl(value, printArgs...); |
| 892 | return maxAliasDepth; |
| 893 | } |
| 894 | |
| 895 | private: |
| 896 | /// Print the given attribute/type, visiting any nested aliases that would be |
| 897 | /// generated as part of printing. |
| 898 | void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) { |
| 899 | if (!isa<BuiltinDialect>(Val: attr.getDialect())) { |
| 900 | attr.getDialect().printAttribute(attr, *this); |
| 901 | |
| 902 | // Process the builtin attributes. |
| 903 | } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr, |
| 904 | IntegerSetAttr, UnitAttr>(Val: attr)) { |
| 905 | return; |
| 906 | } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) { |
| 907 | printAttribute(attr: distinctAttr.getReferencedAttr()); |
| 908 | } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) { |
| 909 | for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { |
| 910 | printAttribute(nestedAttr.getName()); |
| 911 | printAttribute(nestedAttr.getValue()); |
| 912 | } |
| 913 | } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { |
| 914 | for (Attribute nestedAttr : arrayAttr.getValue()) |
| 915 | printAttribute(nestedAttr); |
| 916 | } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) { |
| 917 | printType(type: typeAttr.getValue()); |
| 918 | } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) { |
| 919 | printAttribute(attr: locAttr.getFallbackLocation()); |
| 920 | } else if (auto locAttr = dyn_cast<NameLoc>(attr)) { |
| 921 | if (!isa<UnknownLoc>(locAttr.getChildLoc())) |
| 922 | printAttribute(attr: locAttr.getChildLoc()); |
| 923 | } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) { |
| 924 | printAttribute(attr: locAttr.getCallee()); |
| 925 | printAttribute(attr: locAttr.getCaller()); |
| 926 | } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) { |
| 927 | if (Attribute metadata = locAttr.getMetadata()) |
| 928 | printAttribute(attr: metadata); |
| 929 | for (Location nestedLoc : locAttr.getLocations()) |
| 930 | printAttribute(nestedLoc); |
| 931 | } |
| 932 | |
| 933 | // Don't print the type if we must elide it, or if it is a None type. |
| 934 | if (!elideType) { |
| 935 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { |
| 936 | Type attrType = typedAttr.getType(); |
| 937 | if (!llvm::isa<NoneType>(Val: attrType)) |
| 938 | printType(type: attrType); |
| 939 | } |
| 940 | } |
| 941 | } |
| 942 | void printAndVisitNestedAliasesImpl(Type type) { |
| 943 | if (!isa<BuiltinDialect>(Val: type.getDialect())) |
| 944 | return type.getDialect().printType(type, *this); |
| 945 | |
| 946 | // Only visit the layout of memref if it isn't the identity. |
| 947 | if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) { |
| 948 | printType(type: memrefTy.getElementType()); |
| 949 | MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
| 950 | if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) |
| 951 | printAttribute(attr: memrefTy.getLayout()); |
| 952 | if (memrefTy.getMemorySpace()) |
| 953 | printAttribute(attr: memrefTy.getMemorySpace()); |
| 954 | return; |
| 955 | } |
| 956 | |
| 957 | // For most builtin types, we can simply walk the sub elements. |
| 958 | auto visitFn = [&](auto element) { |
| 959 | if (element) |
| 960 | (void)printAlias(element); |
| 961 | }; |
| 962 | type.walkImmediateSubElements(walkAttrsFn: visitFn, walkTypesFn: visitFn); |
| 963 | } |
| 964 | |
| 965 | /// Consider the given type to be printed for an alias. |
| 966 | void printType(Type type) override { |
| 967 | recordAliasResult(aliasDepthAndIndex: initializer.visit(type, canBeDeferred)); |
| 968 | } |
| 969 | |
| 970 | /// Consider the given attribute to be printed for an alias. |
| 971 | void printAttribute(Attribute attr) override { |
| 972 | recordAliasResult(aliasDepthAndIndex: initializer.visit(attr, canBeDeferred)); |
| 973 | } |
| 974 | void printAttributeWithoutType(Attribute attr) override { |
| 975 | recordAliasResult( |
| 976 | aliasDepthAndIndex: initializer.visit(attr, canBeDeferred, /*elideType=*/true)); |
| 977 | } |
| 978 | LogicalResult printAlias(Attribute attr) override { |
| 979 | printAttribute(attr); |
| 980 | return success(); |
| 981 | } |
| 982 | LogicalResult printAlias(Type type) override { |
| 983 | printType(type); |
| 984 | return success(); |
| 985 | } |
| 986 | |
| 987 | /// Record the alias result of a child element. |
| 988 | void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) { |
| 989 | childIndices.push_back(Elt: aliasDepthAndIndex.second); |
| 990 | if (aliasDepthAndIndex.first > maxAliasDepth) |
| 991 | maxAliasDepth = aliasDepthAndIndex.first; |
| 992 | } |
| 993 | |
| 994 | /// Return a null stream as the output stream, this will ignore any data fed |
| 995 | /// to it. |
| 996 | raw_ostream &getStream() const override { return os; } |
| 997 | |
| 998 | /// The following are hooks of `DialectAsmPrinter` that are not necessary for |
| 999 | /// determining potential aliases. |
| 1000 | void printFloat(const APFloat &) override {} |
| 1001 | void printKeywordOrString(StringRef) override {} |
| 1002 | void printString(StringRef) override {} |
| 1003 | void printSymbolName(StringRef) override {} |
| 1004 | void printResourceHandle(const AsmDialectResourceHandle &) override {} |
| 1005 | |
| 1006 | LogicalResult pushCyclicPrinting(const void *opaquePointer) override { |
| 1007 | return success(IsSuccess: cyclicPrintingStack.insert(X: opaquePointer)); |
| 1008 | } |
| 1009 | |
| 1010 | void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } |
| 1011 | |
| 1012 | /// Stack of potentially cyclic mutable attributes or type currently being |
| 1013 | /// printed. |
| 1014 | SetVector<const void *> cyclicPrintingStack; |
| 1015 | |
| 1016 | /// The initializer to use when identifying aliases. |
| 1017 | AliasInitializer &initializer; |
| 1018 | |
| 1019 | /// If the aliases visited by this printer can be deferred. |
| 1020 | bool canBeDeferred; |
| 1021 | |
| 1022 | /// The indices of child aliases. |
| 1023 | SmallVectorImpl<size_t> &childIndices; |
| 1024 | |
| 1025 | /// The maximum alias depth found by the printer. |
| 1026 | size_t maxAliasDepth = 0; |
| 1027 | |
| 1028 | /// A dummy output stream. |
| 1029 | mutable llvm::raw_null_ostream os; |
| 1030 | }; |
| 1031 | } // namespace |
| 1032 | |
| 1033 | /// Sanitize the given name such that it can be used as a valid identifier. If |
| 1034 | /// the string needs to be modified in any way, the provided buffer is used to |
| 1035 | /// store the new copy, |
| 1036 | static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, |
| 1037 | StringRef allowedPunctChars = "$._-" ) { |
| 1038 | assert(!name.empty() && "Shouldn't have an empty name here" ); |
| 1039 | |
| 1040 | auto validChar = [&](char ch) { |
| 1041 | return llvm::isAlnum(C: ch) || allowedPunctChars.contains(C: ch); |
| 1042 | }; |
| 1043 | |
| 1044 | auto copyNameToBuffer = [&] { |
| 1045 | for (char ch : name) { |
| 1046 | if (validChar(ch)) |
| 1047 | buffer.push_back(Elt: ch); |
| 1048 | else if (ch == ' ') |
| 1049 | buffer.push_back(Elt: '_'); |
| 1050 | else |
| 1051 | buffer.append(RHS: llvm::utohexstr(X: (unsigned char)ch)); |
| 1052 | } |
| 1053 | }; |
| 1054 | |
| 1055 | // Check to see if this name is valid. If it starts with a digit, then it |
| 1056 | // could conflict with the autogenerated numeric ID's, so add an underscore |
| 1057 | // prefix to avoid problems. |
| 1058 | if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { |
| 1059 | buffer.push_back(Elt: '_'); |
| 1060 | copyNameToBuffer(); |
| 1061 | return buffer; |
| 1062 | } |
| 1063 | |
| 1064 | // Check to see that the name consists of only valid identifier characters. |
| 1065 | for (char ch : name) { |
| 1066 | if (!validChar(ch)) { |
| 1067 | copyNameToBuffer(); |
| 1068 | return buffer; |
| 1069 | } |
| 1070 | } |
| 1071 | |
| 1072 | // If there are no invalid characters, return the original name. |
| 1073 | return name; |
| 1074 | } |
| 1075 | |
| 1076 | unsigned AliasInitializer::uniqueAliasNameIndex( |
| 1077 | StringRef alias, llvm::StringMap<unsigned> &nameCounts, |
| 1078 | llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases) { |
| 1079 | if (!usedAliases.count(Key: alias)) { |
| 1080 | usedAliases.insert(key: alias); |
| 1081 | // 0 is not printed in SymbolAlias. |
| 1082 | return 0; |
| 1083 | } |
| 1084 | // Otherwise, we had a conflict - probe until we find a unique name. |
| 1085 | SmallString<64> probeAlias(alias); |
| 1086 | // alias with trailing digit will be printed as _N |
| 1087 | if (isdigit(alias.back())) |
| 1088 | probeAlias.push_back(Elt: '_'); |
| 1089 | // nameCounts start from 1 because 0 is not printed in SymbolAlias. |
| 1090 | if (nameCounts[probeAlias] == 0) |
| 1091 | nameCounts[probeAlias] = 1; |
| 1092 | // This is guaranteed to terminate (and usually in a single iteration) |
| 1093 | // because it generates new names by incrementing nameCounts. |
| 1094 | while (true) { |
| 1095 | unsigned nameIndex = nameCounts[probeAlias]++; |
| 1096 | probeAlias += llvm::utostr(X: nameIndex); |
| 1097 | if (!usedAliases.count(Key: probeAlias)) { |
| 1098 | usedAliases.insert(key: probeAlias); |
| 1099 | return nameIndex; |
| 1100 | } |
| 1101 | // Reset probeAlias to the original alias for the next iteration. |
| 1102 | probeAlias.resize(N: alias.size() + isdigit(alias.back()) ? 1 : 0); |
| 1103 | } |
| 1104 | } |
| 1105 | |
| 1106 | /// Given a collection of aliases and symbols, initialize a mapping from a |
| 1107 | /// symbol to a given alias. |
| 1108 | void AliasInitializer::initializeAliases( |
| 1109 | llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
| 1110 | llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) { |
| 1111 | SmallVector<std::pair<const void *, InProgressAliasInfo>, 0> |
| 1112 | unprocessedAliases = visitedSymbols.takeVector(); |
| 1113 | llvm::stable_sort(Range&: unprocessedAliases, C: llvm::less_second()); |
| 1114 | |
| 1115 | // This keeps track of all of the non-numeric names that are in flight, |
| 1116 | // allowing us to check for duplicates. |
| 1117 | llvm::BumpPtrAllocator usedAliasAllocator; |
| 1118 | llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator); |
| 1119 | |
| 1120 | llvm::StringMap<unsigned> nameCounts; |
| 1121 | for (auto &[symbol, aliasInfo] : unprocessedAliases) { |
| 1122 | if (!aliasInfo.alias) |
| 1123 | continue; |
| 1124 | StringRef alias = *aliasInfo.alias; |
| 1125 | unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases); |
| 1126 | symbolToAlias.insert( |
| 1127 | KV: {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, |
| 1128 | aliasInfo.canBeDeferred)}); |
| 1129 | } |
| 1130 | } |
| 1131 | |
| 1132 | void AliasInitializer::initialize( |
| 1133 | Operation *op, const OpPrintingFlags &printerFlags, |
| 1134 | llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) { |
| 1135 | // Use a dummy printer when walking the IR so that we can collect the |
| 1136 | // attributes/types that will actually be used during printing when |
| 1137 | // considering aliases. |
| 1138 | DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); |
| 1139 | aliasPrinter.printCustomOrGenericOp(op); |
| 1140 | |
| 1141 | // Initialize the aliases. |
| 1142 | initializeAliases(visitedSymbols&: aliases, symbolToAlias&: attrTypeToAlias); |
| 1143 | } |
| 1144 | |
| 1145 | template <typename T, typename... PrintArgs> |
| 1146 | std::pair<size_t, size_t> AliasInitializer::visitImpl( |
| 1147 | T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
| 1148 | bool canBeDeferred, PrintArgs &&...printArgs) { |
| 1149 | auto [it, inserted] = aliases.try_emplace(value.getAsOpaquePointer()); |
| 1150 | size_t aliasIndex = std::distance(aliases.begin(), it); |
| 1151 | if (!inserted) { |
| 1152 | // Make sure that the alias isn't deferred if we don't permit it. |
| 1153 | if (!canBeDeferred) |
| 1154 | markAliasNonDeferrable(aliasIndex); |
| 1155 | return {static_cast<size_t>(it->second.aliasDepth), aliasIndex}; |
| 1156 | } |
| 1157 | |
| 1158 | // Try to generate an alias for this value. |
| 1159 | generateAlias(value, it->second, canBeDeferred); |
| 1160 | it->second.isType = std::is_base_of_v<Type, T>; |
| 1161 | it->second.canBeDeferred = canBeDeferred; |
| 1162 | |
| 1163 | // Print the value, capturing any nested elements that require aliases. |
| 1164 | SmallVector<size_t> childAliases; |
| 1165 | DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases); |
| 1166 | size_t maxAliasDepth = |
| 1167 | printer.printAndVisitNestedAliases(value, printArgs...); |
| 1168 | |
| 1169 | // Make sure to recompute `it` in case the map was reallocated. |
| 1170 | it = std::next(x: aliases.begin(), n: aliasIndex); |
| 1171 | |
| 1172 | // If we had sub elements, update to account for the depth. |
| 1173 | it->second.childIndices = std::move(childAliases); |
| 1174 | if (maxAliasDepth) |
| 1175 | it->second.aliasDepth = maxAliasDepth + 1; |
| 1176 | |
| 1177 | // Propagate the alias depth of the value. |
| 1178 | return {(size_t)it->second.aliasDepth, aliasIndex}; |
| 1179 | } |
| 1180 | |
| 1181 | void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { |
| 1182 | auto *it = std::next(x: aliases.begin(), n: aliasIndex); |
| 1183 | |
| 1184 | // If already marked non-deferrable stop the recursion. |
| 1185 | // All children should already be marked non-deferrable as well. |
| 1186 | if (!it->second.canBeDeferred) |
| 1187 | return; |
| 1188 | |
| 1189 | it->second.canBeDeferred = false; |
| 1190 | |
| 1191 | // Propagate the non-deferrable flag to any child aliases. |
| 1192 | for (size_t childIndex : it->second.childIndices) |
| 1193 | markAliasNonDeferrable(aliasIndex: childIndex); |
| 1194 | } |
| 1195 | |
| 1196 | template <typename T> |
| 1197 | void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, |
| 1198 | bool canBeDeferred) { |
| 1199 | SmallString<32> nameBuffer; |
| 1200 | |
| 1201 | OpAsmDialectInterface::AliasResult symbolInterfaceResult = |
| 1202 | OpAsmDialectInterface::AliasResult::NoAlias; |
| 1203 | using InterfaceT = std::conditional_t<std::is_base_of_v<Attribute, T>, |
| 1204 | OpAsmAttrInterface, OpAsmTypeInterface>; |
| 1205 | if (auto symbolInterface = dyn_cast<InterfaceT>(symbol)) { |
| 1206 | symbolInterfaceResult = symbolInterface.getAlias(aliasOS); |
| 1207 | if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::NoAlias) { |
| 1208 | nameBuffer = std::move(aliasBuffer); |
| 1209 | assert(!nameBuffer.empty() && "expected valid alias name" ); |
| 1210 | } |
| 1211 | } |
| 1212 | |
| 1213 | if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) { |
| 1214 | for (const auto &interface : interfaces) { |
| 1215 | OpAsmDialectInterface::AliasResult result = |
| 1216 | interface.getAlias(symbol, aliasOS); |
| 1217 | if (result == OpAsmDialectInterface::AliasResult::NoAlias) |
| 1218 | continue; |
| 1219 | nameBuffer = std::move(aliasBuffer); |
| 1220 | assert(!nameBuffer.empty() && "expected valid alias name" ); |
| 1221 | if (result == OpAsmDialectInterface::AliasResult::FinalAlias) |
| 1222 | break; |
| 1223 | } |
| 1224 | } |
| 1225 | |
| 1226 | if (nameBuffer.empty()) |
| 1227 | return; |
| 1228 | |
| 1229 | SmallString<16> tempBuffer; |
| 1230 | StringRef name = |
| 1231 | sanitizeIdentifier(name: nameBuffer, buffer&: tempBuffer, /*allowedPunctChars=*/"$_-" ); |
| 1232 | name = name.copy(A&: aliasAllocator); |
| 1233 | alias = InProgressAliasInfo(name); |
| 1234 | } |
| 1235 | |
| 1236 | //===----------------------------------------------------------------------===// |
| 1237 | // AliasState |
| 1238 | //===----------------------------------------------------------------------===// |
| 1239 | |
| 1240 | namespace { |
| 1241 | /// This class manages the state for type and attribute aliases. |
| 1242 | class AliasState { |
| 1243 | public: |
| 1244 | // Initialize the internal aliases. |
| 1245 | void |
| 1246 | initialize(Operation *op, const OpPrintingFlags &printerFlags, |
| 1247 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
| 1248 | |
| 1249 | /// Get an alias for the given attribute if it has one and print it in `os`. |
| 1250 | /// Returns success if an alias was printed, failure otherwise. |
| 1251 | LogicalResult getAlias(Attribute attr, raw_ostream &os) const; |
| 1252 | |
| 1253 | /// Get an alias for the given type if it has one and print it in `os`. |
| 1254 | /// Returns success if an alias was printed, failure otherwise. |
| 1255 | LogicalResult getAlias(Type ty, raw_ostream &os) const; |
| 1256 | |
| 1257 | /// Print all of the referenced aliases that can not be resolved in a deferred |
| 1258 | /// manner. |
| 1259 | void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
| 1260 | printAliases(p, newLine, /*isDeferred=*/false); |
| 1261 | } |
| 1262 | |
| 1263 | /// Print all of the referenced aliases that support deferred resolution. |
| 1264 | void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
| 1265 | printAliases(p, newLine, /*isDeferred=*/true); |
| 1266 | } |
| 1267 | |
| 1268 | private: |
| 1269 | /// Print all of the referenced aliases that support the provided resolution |
| 1270 | /// behavior. |
| 1271 | void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
| 1272 | bool isDeferred); |
| 1273 | |
| 1274 | /// Mapping between attribute/type and alias. |
| 1275 | llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias; |
| 1276 | |
| 1277 | /// An allocator used for alias names. |
| 1278 | llvm::BumpPtrAllocator aliasAllocator; |
| 1279 | }; |
| 1280 | } // namespace |
| 1281 | |
| 1282 | void AliasState::initialize( |
| 1283 | Operation *op, const OpPrintingFlags &printerFlags, |
| 1284 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
| 1285 | AliasInitializer initializer(interfaces, aliasAllocator); |
| 1286 | initializer.initialize(op, printerFlags, attrTypeToAlias); |
| 1287 | } |
| 1288 | |
| 1289 | LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { |
| 1290 | const auto *it = attrTypeToAlias.find(Key: attr.getAsOpaquePointer()); |
| 1291 | if (it == attrTypeToAlias.end()) |
| 1292 | return failure(); |
| 1293 | it->second.print(os); |
| 1294 | return success(); |
| 1295 | } |
| 1296 | |
| 1297 | LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { |
| 1298 | const auto *it = attrTypeToAlias.find(Key: ty.getAsOpaquePointer()); |
| 1299 | if (it == attrTypeToAlias.end()) |
| 1300 | return failure(); |
| 1301 | if (!it->second.isPrinted) |
| 1302 | return failure(); |
| 1303 | |
| 1304 | it->second.print(os); |
| 1305 | return success(); |
| 1306 | } |
| 1307 | |
| 1308 | void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
| 1309 | bool isDeferred) { |
| 1310 | auto filterFn = [=](const auto &aliasIt) { |
| 1311 | return aliasIt.second.canBeDeferred() == isDeferred; |
| 1312 | }; |
| 1313 | for (auto &[opaqueSymbol, alias] : |
| 1314 | llvm::make_filter_range(Range&: attrTypeToAlias, Pred: filterFn)) { |
| 1315 | alias.print(os&: p.getStream()); |
| 1316 | p.getStream() << " = " ; |
| 1317 | |
| 1318 | if (alias.isTypeAlias()) { |
| 1319 | Type type = Type::getFromOpaquePointer(pointer: opaqueSymbol); |
| 1320 | p.printTypeImpl(type); |
| 1321 | alias.isPrinted = true; |
| 1322 | } else { |
| 1323 | // TODO: Support nested aliases in mutable attributes. |
| 1324 | Attribute attr = Attribute::getFromOpaquePointer(ptr: opaqueSymbol); |
| 1325 | if (attr.hasTrait<AttributeTrait::IsMutable>()) |
| 1326 | p.getStream() << attr; |
| 1327 | else |
| 1328 | p.printAttributeImpl(attr); |
| 1329 | } |
| 1330 | |
| 1331 | p.getStream() << newLine; |
| 1332 | } |
| 1333 | } |
| 1334 | |
| 1335 | //===----------------------------------------------------------------------===// |
| 1336 | // SSANameState |
| 1337 | //===----------------------------------------------------------------------===// |
| 1338 | |
| 1339 | namespace { |
| 1340 | /// Info about block printing: a number which is its position in the visitation |
| 1341 | /// order, and a name that is used to print reference to it, e.g. ^bb42. |
| 1342 | struct BlockInfo { |
| 1343 | int ordering; |
| 1344 | StringRef name; |
| 1345 | }; |
| 1346 | |
| 1347 | /// This class manages the state of SSA value names. |
| 1348 | class SSANameState { |
| 1349 | public: |
| 1350 | /// A sentinel value used for values with names set. |
| 1351 | enum : unsigned { NameSentinel = ~0U }; |
| 1352 | |
| 1353 | SSANameState(Operation *op, const OpPrintingFlags &printerFlags); |
| 1354 | SSANameState() = default; |
| 1355 | |
| 1356 | /// Print the SSA identifier for the given value to 'stream'. If |
| 1357 | /// 'printResultNo' is true, it also presents the result number ('#' number) |
| 1358 | /// of this value. |
| 1359 | void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; |
| 1360 | |
| 1361 | /// Print the operation identifier. |
| 1362 | void printOperationID(Operation *op, raw_ostream &stream) const; |
| 1363 | |
| 1364 | /// Return the result indices for each of the result groups registered by this |
| 1365 | /// operation, or empty if none exist. |
| 1366 | ArrayRef<int> getOpResultGroups(Operation *op); |
| 1367 | |
| 1368 | /// Get the info for the given block. |
| 1369 | BlockInfo getBlockInfo(Block *block); |
| 1370 | |
| 1371 | /// Renumber the arguments for the specified region to the same names as the |
| 1372 | /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for |
| 1373 | /// details. |
| 1374 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse); |
| 1375 | |
| 1376 | private: |
| 1377 | /// Number the SSA values within the given IR unit. |
| 1378 | void numberValuesInRegion(Region ®ion); |
| 1379 | void numberValuesInBlock(Block &block); |
| 1380 | void numberValuesInOp(Operation &op); |
| 1381 | |
| 1382 | /// Given a result of an operation 'result', find the result group head |
| 1383 | /// 'lookupValue' and the result of 'result' within that group in |
| 1384 | /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group |
| 1385 | /// has more than 1 result. |
| 1386 | void getResultIDAndNumber(OpResult result, Value &lookupValue, |
| 1387 | std::optional<int> &lookupResultNo) const; |
| 1388 | |
| 1389 | /// Set a special value name for the given value. |
| 1390 | void setValueName(Value value, StringRef name); |
| 1391 | |
| 1392 | /// Uniques the given value name within the printer. If the given name |
| 1393 | /// conflicts, it is automatically renamed. |
| 1394 | StringRef uniqueValueName(StringRef name); |
| 1395 | |
| 1396 | /// This is the value ID for each SSA value. If this returns NameSentinel, |
| 1397 | /// then the valueID has an entry in valueNames. |
| 1398 | DenseMap<Value, unsigned> valueIDs; |
| 1399 | DenseMap<Value, StringRef> valueNames; |
| 1400 | |
| 1401 | /// When printing users of values, an operation without a result might |
| 1402 | /// be the user. This map holds ids for such operations. |
| 1403 | DenseMap<Operation *, unsigned> operationIDs; |
| 1404 | |
| 1405 | /// This is a map of operations that contain multiple named result groups, |
| 1406 | /// i.e. there may be multiple names for the results of the operation. The |
| 1407 | /// value of this map are the result numbers that start a result group. |
| 1408 | DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; |
| 1409 | |
| 1410 | /// This maps blocks to there visitation number in the current region as well |
| 1411 | /// as the string representing their name. |
| 1412 | DenseMap<Block *, BlockInfo> blockNames; |
| 1413 | |
| 1414 | /// This keeps track of all of the non-numeric names that are in flight, |
| 1415 | /// allowing us to check for duplicates. |
| 1416 | /// Note: the value of the map is unused. |
| 1417 | llvm::ScopedHashTable<StringRef, char> usedNames; |
| 1418 | llvm::BumpPtrAllocator usedNameAllocator; |
| 1419 | |
| 1420 | /// This is the next value ID to assign in numbering. |
| 1421 | unsigned nextValueID = 0; |
| 1422 | /// This is the next ID to assign to a region entry block argument. |
| 1423 | unsigned nextArgumentID = 0; |
| 1424 | /// This is the next ID to assign when a name conflict is detected. |
| 1425 | unsigned nextConflictID = 0; |
| 1426 | |
  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
| 1429 | OpPrintingFlags printerFlags; |
| 1430 | }; |
| 1431 | } // namespace |
| 1432 | |
| 1433 | SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) |
| 1434 | : printerFlags(printerFlags) { |
| 1435 | llvm::SaveAndRestore valueIDSaver(nextValueID); |
| 1436 | llvm::SaveAndRestore argumentIDSaver(nextArgumentID); |
| 1437 | llvm::SaveAndRestore conflictIDSaver(nextConflictID); |
| 1438 | |
| 1439 | // The naming context includes `nextValueID`, `nextArgumentID`, |
| 1440 | // `nextConflictID` and `usedNames` scoped HashTable. This information is |
| 1441 | // carried from the parent region. |
| 1442 | using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; |
| 1443 | using NamingContext = |
| 1444 | std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; |
| 1445 | |
| 1446 | // Allocator for UsedNamesScopeTy |
| 1447 | llvm::BumpPtrAllocator allocator; |
| 1448 | |
| 1449 | // Add a scope for the top level operation. |
| 1450 | auto *topLevelNamesScope = |
| 1451 | new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); |
| 1452 | |
| 1453 | SmallVector<NamingContext, 8> nameContext; |
| 1454 | for (Region ®ion : op->getRegions()) |
| 1455 | nameContext.push_back(Elt: std::make_tuple(args: ®ion, args&: nextValueID, args&: nextArgumentID, |
| 1456 | args&: nextConflictID, args&: topLevelNamesScope)); |
| 1457 | |
| 1458 | numberValuesInOp(op&: *op); |
| 1459 | |
| 1460 | while (!nameContext.empty()) { |
| 1461 | Region *region; |
| 1462 | UsedNamesScopeTy *parentScope; |
| 1463 | |
| 1464 | if (printerFlags.shouldPrintUniqueSSAIDs()) |
| 1465 | // To print unique SSA IDs, ignore saved ID counts from parent regions |
| 1466 | std::tie(args&: region, args: std::ignore, args: std::ignore, args: std::ignore, args&: parentScope) = |
| 1467 | nameContext.pop_back_val(); |
| 1468 | else |
| 1469 | std::tie(args&: region, args&: nextValueID, args&: nextArgumentID, args&: nextConflictID, |
| 1470 | args&: parentScope) = nameContext.pop_back_val(); |
| 1471 | |
| 1472 | // When we switch from one subtree to another, pop the scopes(needless) |
| 1473 | // until the parent scope. |
| 1474 | while (usedNames.getCurScope() != parentScope) { |
| 1475 | usedNames.getCurScope()->~UsedNamesScopeTy(); |
| 1476 | assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && |
| 1477 | "top level parentScope must be a nullptr" ); |
| 1478 | } |
| 1479 | |
| 1480 | // Add a scope for the current region. |
| 1481 | auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) |
| 1482 | UsedNamesScopeTy(usedNames); |
| 1483 | |
| 1484 | numberValuesInRegion(region&: *region); |
| 1485 | |
| 1486 | for (Operation &op : region->getOps()) |
| 1487 | for (Region ®ion : op.getRegions()) |
| 1488 | nameContext.push_back(Elt: std::make_tuple(args: ®ion, args&: nextValueID, |
| 1489 | args&: nextArgumentID, args&: nextConflictID, |
| 1490 | args&: curNamesScope)); |
| 1491 | } |
| 1492 | |
| 1493 | // Manually remove all the scopes. |
| 1494 | while (usedNames.getCurScope() != nullptr) |
| 1495 | usedNames.getCurScope()->~UsedNamesScopeTy(); |
| 1496 | } |
| 1497 | |
| 1498 | void SSANameState::printValueID(Value value, bool printResultNo, |
| 1499 | raw_ostream &stream) const { |
| 1500 | if (!value) { |
| 1501 | stream << "<<NULL VALUE>>" ; |
| 1502 | return; |
| 1503 | } |
| 1504 | |
| 1505 | std::optional<int> resultNo; |
| 1506 | auto lookupValue = value; |
| 1507 | |
| 1508 | // If this is an operation result, collect the head lookup value of the result |
| 1509 | // group and the result number of 'result' within that group. |
| 1510 | if (OpResult result = dyn_cast<OpResult>(Val&: value)) |
| 1511 | getResultIDAndNumber(result, lookupValue, lookupResultNo&: resultNo); |
| 1512 | |
| 1513 | auto it = valueIDs.find(Val: lookupValue); |
| 1514 | if (it == valueIDs.end()) { |
| 1515 | stream << "<<UNKNOWN SSA VALUE>>" ; |
| 1516 | return; |
| 1517 | } |
| 1518 | |
| 1519 | stream << '%'; |
| 1520 | if (it->second != NameSentinel) { |
| 1521 | stream << it->second; |
| 1522 | } else { |
| 1523 | auto nameIt = valueNames.find(Val: lookupValue); |
| 1524 | assert(nameIt != valueNames.end() && "Didn't have a name entry?" ); |
| 1525 | stream << nameIt->second; |
| 1526 | } |
| 1527 | |
| 1528 | if (resultNo && printResultNo) |
| 1529 | stream << '#' << *resultNo; |
| 1530 | } |
| 1531 | |
| 1532 | void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { |
| 1533 | auto it = operationIDs.find(Val: op); |
| 1534 | if (it == operationIDs.end()) { |
| 1535 | stream << "<<UNKNOWN OPERATION>>" ; |
| 1536 | } else { |
| 1537 | stream << '%' << it->second; |
| 1538 | } |
| 1539 | } |
| 1540 | |
| 1541 | ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { |
| 1542 | auto it = opResultGroups.find(Val: op); |
| 1543 | return it == opResultGroups.end() ? ArrayRef<int>() : it->second; |
| 1544 | } |
| 1545 | |
| 1546 | BlockInfo SSANameState::getBlockInfo(Block *block) { |
| 1547 | auto it = blockNames.find(Val: block); |
| 1548 | BlockInfo invalidBlock{.ordering: -1, .name: "INVALIDBLOCK" }; |
| 1549 | return it != blockNames.end() ? it->second : invalidBlock; |
| 1550 | } |
| 1551 | |
| 1552 | void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { |
| 1553 | assert(!region.empty() && "cannot shadow arguments of an empty region" ); |
| 1554 | assert(region.getNumArguments() == namesToUse.size() && |
| 1555 | "incorrect number of names passed in" ); |
| 1556 | assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && |
| 1557 | "only KnownIsolatedFromAbove ops can shadow names" ); |
| 1558 | |
| 1559 | SmallVector<char, 16> nameStr; |
| 1560 | for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { |
| 1561 | auto nameToUse = namesToUse[i]; |
| 1562 | if (nameToUse == nullptr) |
| 1563 | continue; |
| 1564 | auto nameToReplace = region.getArgument(i); |
| 1565 | |
| 1566 | nameStr.clear(); |
| 1567 | llvm::raw_svector_ostream nameStream(nameStr); |
| 1568 | printValueID(value: nameToUse, /*printResultNo=*/true, stream&: nameStream); |
| 1569 | |
| 1570 | // Entry block arguments should already have a pretty "arg" name. |
| 1571 | assert(valueIDs[nameToReplace] == NameSentinel); |
| 1572 | |
| 1573 | // Use the name without the leading %. |
| 1574 | auto name = StringRef(nameStream.str()).drop_front(); |
| 1575 | |
| 1576 | // Overwrite the name. |
| 1577 | valueNames[nameToReplace] = name.copy(A&: usedNameAllocator); |
| 1578 | } |
| 1579 | } |
| 1580 | |
| 1581 | namespace { |
| 1582 | /// Try to get value name from value's location, fallback to `name`. |
| 1583 | StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { |
| 1584 | if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>()) |
| 1585 | return maybeNameLoc.getName(); |
| 1586 | return name; |
| 1587 | } |
| 1588 | } // namespace |
| 1589 | |
| 1590 | void SSANameState::numberValuesInRegion(Region ®ion) { |
| 1591 | // Indicates whether OpAsmOpInterface set a name. |
| 1592 | bool opAsmOpInterfaceUsed = false; |
| 1593 | auto setBlockArgNameFn = [&](Value arg, StringRef name) { |
| 1594 | assert(!valueIDs.count(arg) && "arg numbered multiple times" ); |
| 1595 | assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion && |
| 1596 | "arg not defined in current region" ); |
| 1597 | opAsmOpInterfaceUsed = true; |
| 1598 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
| 1599 | name = maybeGetValueNameFromLoc(value: arg, name); |
| 1600 | setValueName(value: arg, name); |
| 1601 | }; |
| 1602 | |
| 1603 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 1604 | if (Operation *op = region.getParentOp()) { |
| 1605 | if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) |
| 1606 | asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); |
| 1607 | // If the OpAsmOpInterface didn't set a name, get name from the type. |
| 1608 | if (!opAsmOpInterfaceUsed) { |
| 1609 | for (BlockArgument arg : region.getArguments()) { |
| 1610 | if (auto interface = dyn_cast<OpAsmTypeInterface>(arg.getType())) { |
| 1611 | interface.getAsmName( |
| 1612 | [&](StringRef name) { setBlockArgNameFn(arg, name); }); |
| 1613 | } |
| 1614 | } |
| 1615 | } |
| 1616 | } |
| 1617 | } |
| 1618 | |
| 1619 | // Number the values within this region in a breadth-first order. |
| 1620 | unsigned nextBlockID = 0; |
| 1621 | for (auto &block : region) { |
| 1622 | // Each block gets a unique ID, and all of the operations within it get |
| 1623 | // numbered as well. |
| 1624 | auto blockInfoIt = blockNames.insert(KV: {&block, {.ordering: -1, .name: "" }}); |
| 1625 | if (blockInfoIt.second) { |
| 1626 | // This block hasn't been named through `getAsmBlockArgumentNames`, use |
| 1627 | // default `^bbNNN` format. |
| 1628 | std::string name; |
| 1629 | llvm::raw_string_ostream(name) << "^bb" << nextBlockID; |
| 1630 | blockInfoIt.first->second.name = StringRef(name).copy(A&: usedNameAllocator); |
| 1631 | } |
| 1632 | blockInfoIt.first->second.ordering = nextBlockID++; |
| 1633 | |
| 1634 | numberValuesInBlock(block); |
| 1635 | } |
| 1636 | } |
| 1637 | |
| 1638 | void SSANameState::numberValuesInBlock(Block &block) { |
| 1639 | // Number the block arguments. We give entry block arguments a special name |
| 1640 | // 'arg'. |
| 1641 | bool isEntryBlock = block.isEntryBlock(); |
| 1642 | SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "" ); |
| 1643 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
| 1644 | for (auto arg : block.getArguments()) { |
| 1645 | if (valueIDs.count(Val: arg)) |
| 1646 | continue; |
| 1647 | if (isEntryBlock) { |
| 1648 | specialNameBuffer.resize(N: strlen(s: "arg" )); |
| 1649 | specialName << nextArgumentID++; |
| 1650 | } |
| 1651 | StringRef specialNameStr = specialName.str(); |
| 1652 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
| 1653 | specialNameStr = maybeGetValueNameFromLoc(value: arg, name: specialNameStr); |
| 1654 | setValueName(value: arg, name: specialNameStr); |
| 1655 | } |
| 1656 | |
| 1657 | // Number the operations in this block. |
| 1658 | for (auto &op : block) |
| 1659 | numberValuesInOp(op); |
| 1660 | } |
| 1661 | |
| 1662 | void SSANameState::numberValuesInOp(Operation &op) { |
| 1663 | // Function used to set the special result names for the operation. |
| 1664 | SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); |
| 1665 | // Indicates whether OpAsmOpInterface set a name. |
| 1666 | bool opAsmOpInterfaceUsed = false; |
| 1667 | auto setResultNameFn = [&](Value result, StringRef name) { |
| 1668 | assert(!valueIDs.count(result) && "result numbered multiple times" ); |
| 1669 | assert(result.getDefiningOp() == &op && "result not defined by 'op'" ); |
| 1670 | opAsmOpInterfaceUsed = true; |
| 1671 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
| 1672 | name = maybeGetValueNameFromLoc(value: result, name); |
| 1673 | setValueName(value: result, name); |
| 1674 | |
| 1675 | // Record the result number for groups not anchored at 0. |
| 1676 | if (int resultNo = llvm::cast<OpResult>(Val&: result).getResultNumber()) |
| 1677 | resultGroups.push_back(Elt: resultNo); |
| 1678 | }; |
| 1679 | // Operations can customize the printing of block names in OpAsmOpInterface. |
| 1680 | auto setBlockNameFn = [&](Block *block, StringRef name) { |
| 1681 | assert(block->getParentOp() == &op && |
| 1682 | "getAsmBlockArgumentNames callback invoked on a block not directly " |
| 1683 | "nested under the current operation" ); |
| 1684 | assert(!blockNames.count(block) && "block numbered multiple times" ); |
| 1685 | SmallString<16> tmpBuffer{"^" }; |
| 1686 | name = sanitizeIdentifier(name, buffer&: tmpBuffer); |
| 1687 | if (name.data() != tmpBuffer.data()) { |
| 1688 | tmpBuffer.append(RHS: name); |
| 1689 | name = tmpBuffer.str(); |
| 1690 | } |
| 1691 | name = name.copy(A&: usedNameAllocator); |
| 1692 | blockNames[block] = {.ordering: -1, .name: name}; |
| 1693 | }; |
| 1694 | |
| 1695 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 1696 | if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { |
| 1697 | asmInterface.getAsmBlockNames(setBlockNameFn); |
| 1698 | asmInterface.getAsmResultNames(setResultNameFn); |
| 1699 | } |
| 1700 | if (!opAsmOpInterfaceUsed) { |
| 1701 | // If the OpAsmOpInterface didn't set a name, and all results have |
| 1702 | // OpAsmTypeInterface, get names from types. |
| 1703 | bool allHaveOpAsmTypeInterface = |
| 1704 | llvm::all_of(Range: op.getResultTypes(), P: [&](Type type) { |
| 1705 | return isa<OpAsmTypeInterface>(type); |
| 1706 | }); |
| 1707 | if (allHaveOpAsmTypeInterface) { |
| 1708 | for (OpResult result : op.getResults()) { |
| 1709 | auto interface = cast<OpAsmTypeInterface>(result.getType()); |
| 1710 | interface.getAsmName( |
| 1711 | [&](StringRef name) { setResultNameFn(result, name); }); |
| 1712 | } |
| 1713 | } |
| 1714 | } |
| 1715 | } |
| 1716 | |
| 1717 | unsigned numResults = op.getNumResults(); |
| 1718 | if (numResults == 0) { |
| 1719 | // If value users should be printed, operations with no result need an id. |
| 1720 | if (printerFlags.shouldPrintValueUsers()) { |
| 1721 | if (operationIDs.try_emplace(Key: &op, Args&: nextValueID).second) |
| 1722 | ++nextValueID; |
| 1723 | } |
| 1724 | return; |
| 1725 | } |
| 1726 | Value resultBegin = op.getResult(idx: 0); |
| 1727 | |
| 1728 | if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(Val: resultBegin)) { |
| 1729 | if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) { |
| 1730 | setValueName(value: resultBegin, name: nameLoc.getName()); |
| 1731 | } |
| 1732 | } |
| 1733 | |
| 1734 | // If the first result wasn't numbered, give it a default number. |
| 1735 | if (valueIDs.try_emplace(Key: resultBegin, Args&: nextValueID).second) |
| 1736 | ++nextValueID; |
| 1737 | |
| 1738 | // If this operation has multiple result groups, mark it. |
| 1739 | if (resultGroups.size() != 1) { |
| 1740 | llvm::array_pod_sort(Start: resultGroups.begin(), End: resultGroups.end()); |
| 1741 | opResultGroups.try_emplace(Key: &op, Args: std::move(resultGroups)); |
| 1742 | } |
| 1743 | } |
| 1744 | |
| 1745 | void SSANameState::getResultIDAndNumber( |
| 1746 | OpResult result, Value &lookupValue, |
| 1747 | std::optional<int> &lookupResultNo) const { |
| 1748 | Operation *owner = result.getOwner(); |
| 1749 | if (owner->getNumResults() == 1) |
| 1750 | return; |
| 1751 | int resultNo = result.getResultNumber(); |
| 1752 | |
| 1753 | // If this operation has multiple result groups, we will need to find the |
| 1754 | // one corresponding to this result. |
| 1755 | auto resultGroupIt = opResultGroups.find(Val: owner); |
| 1756 | if (resultGroupIt == opResultGroups.end()) { |
| 1757 | // If not, just use the first result. |
| 1758 | lookupResultNo = resultNo; |
| 1759 | lookupValue = owner->getResult(idx: 0); |
| 1760 | return; |
| 1761 | } |
| 1762 | |
| 1763 | // Find the correct index using a binary search, as the groups are ordered. |
| 1764 | ArrayRef<int> resultGroups = resultGroupIt->second; |
| 1765 | const auto *it = llvm::upper_bound(Range&: resultGroups, Value&: resultNo); |
| 1766 | int groupResultNo = 0, groupSize = 0; |
| 1767 | |
| 1768 | // If there are no smaller elements, the last result group is the lookup. |
| 1769 | if (it == resultGroups.end()) { |
| 1770 | groupResultNo = resultGroups.back(); |
| 1771 | groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); |
| 1772 | } else { |
| 1773 | // Otherwise, the previous element is the lookup. |
| 1774 | groupResultNo = *std::prev(x: it); |
| 1775 | groupSize = *it - groupResultNo; |
| 1776 | } |
| 1777 | |
| 1778 | // We only record the result number for a group of size greater than 1. |
| 1779 | if (groupSize != 1) |
| 1780 | lookupResultNo = resultNo - groupResultNo; |
| 1781 | lookupValue = owner->getResult(idx: groupResultNo); |
| 1782 | } |
| 1783 | |
| 1784 | void SSANameState::setValueName(Value value, StringRef name) { |
| 1785 | // If the name is empty, the value uses the default numbering. |
| 1786 | if (name.empty()) { |
| 1787 | valueIDs[value] = nextValueID++; |
| 1788 | return; |
| 1789 | } |
| 1790 | |
| 1791 | valueIDs[value] = NameSentinel; |
| 1792 | valueNames[value] = uniqueValueName(name); |
| 1793 | } |
| 1794 | |
| 1795 | StringRef SSANameState::uniqueValueName(StringRef name) { |
| 1796 | SmallString<16> tmpBuffer; |
| 1797 | name = sanitizeIdentifier(name, buffer&: tmpBuffer); |
| 1798 | |
| 1799 | // Check to see if this name is already unique. |
| 1800 | if (!usedNames.count(Key: name)) { |
| 1801 | name = name.copy(A&: usedNameAllocator); |
| 1802 | } else { |
| 1803 | // Otherwise, we had a conflict - probe until we find a unique name. This |
| 1804 | // is guaranteed to terminate (and usually in a single iteration) because it |
| 1805 | // generates new names by incrementing nextConflictID. |
| 1806 | SmallString<64> probeName(name); |
| 1807 | probeName.push_back(Elt: '_'); |
| 1808 | while (true) { |
| 1809 | probeName += llvm::utostr(X: nextConflictID++); |
| 1810 | if (!usedNames.count(Key: probeName)) { |
| 1811 | name = probeName.str().copy(A&: usedNameAllocator); |
| 1812 | break; |
| 1813 | } |
| 1814 | probeName.resize(N: name.size() + 1); |
| 1815 | } |
| 1816 | } |
| 1817 | |
| 1818 | usedNames.insert(Key: name, Val: char()); |
| 1819 | return name; |
| 1820 | } |
| 1821 | |
| 1822 | //===----------------------------------------------------------------------===// |
| 1823 | // DistinctState |
| 1824 | //===----------------------------------------------------------------------===// |
| 1825 | |
| 1826 | namespace { |
| 1827 | /// This class manages the state for distinct attributes. |
| 1828 | class DistinctState { |
| 1829 | public: |
| 1830 | /// Returns a unique identifier for the given distinct attribute. |
| 1831 | uint64_t getId(DistinctAttr distinctAttr); |
| 1832 | |
| 1833 | private: |
| 1834 | uint64_t distinctCounter = 0; |
| 1835 | DenseMap<DistinctAttr, uint64_t> distinctAttrMap; |
| 1836 | }; |
| 1837 | } // namespace |
| 1838 | |
| 1839 | uint64_t DistinctState::getId(DistinctAttr distinctAttr) { |
| 1840 | auto [it, inserted] = |
| 1841 | distinctAttrMap.try_emplace(Key: distinctAttr, Args&: distinctCounter); |
| 1842 | if (inserted) |
| 1843 | distinctCounter++; |
| 1844 | return it->getSecond(); |
| 1845 | } |
| 1846 | |
| 1847 | //===----------------------------------------------------------------------===// |
| 1848 | // Resources |
| 1849 | //===----------------------------------------------------------------------===// |
| 1850 | |
| 1851 | AsmParsedResourceEntry::~AsmParsedResourceEntry() = default; |
| 1852 | AsmResourceBuilder::~AsmResourceBuilder() = default; |
| 1853 | AsmResourceParser::~AsmResourceParser() = default; |
| 1854 | AsmResourcePrinter::~AsmResourcePrinter() = default; |
| 1855 | |
| 1856 | StringRef mlir::toString(AsmResourceEntryKind kind) { |
| 1857 | switch (kind) { |
| 1858 | case AsmResourceEntryKind::Blob: |
| 1859 | return "blob" ; |
| 1860 | case AsmResourceEntryKind::Bool: |
| 1861 | return "bool" ; |
| 1862 | case AsmResourceEntryKind::String: |
| 1863 | return "string" ; |
| 1864 | } |
| 1865 | llvm_unreachable("unknown AsmResourceEntryKind" ); |
| 1866 | } |
| 1867 | |
| 1868 | AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { |
| 1869 | std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()]; |
| 1870 | if (!collection) |
| 1871 | collection = std::make_unique<ResourceCollection>(args&: key); |
| 1872 | return *collection; |
| 1873 | } |
| 1874 | |
| 1875 | std::vector<std::unique_ptr<AsmResourcePrinter>> |
| 1876 | FallbackAsmResourceMap::getPrinters() { |
| 1877 | std::vector<std::unique_ptr<AsmResourcePrinter>> printers; |
| 1878 | for (auto &it : keyToResources) { |
| 1879 | ResourceCollection *collection = it.second.get(); |
| 1880 | auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { |
| 1881 | return collection->buildResources(op, builder); |
| 1882 | }; |
| 1883 | printers.emplace_back( |
| 1884 | args: AsmResourcePrinter::fromCallable(name: collection->getName(), printFn&: buildValues)); |
| 1885 | } |
| 1886 | return printers; |
| 1887 | } |
| 1888 | |
| 1889 | LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( |
| 1890 | AsmParsedResourceEntry &entry) { |
| 1891 | switch (entry.getKind()) { |
| 1892 | case AsmResourceEntryKind::Blob: { |
| 1893 | FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(); |
| 1894 | if (failed(Result: blob)) |
| 1895 | return failure(); |
| 1896 | resources.emplace_back(Args: entry.getKey(), Args: std::move(*blob)); |
| 1897 | return success(); |
| 1898 | } |
| 1899 | case AsmResourceEntryKind::Bool: { |
| 1900 | FailureOr<bool> value = entry.parseAsBool(); |
| 1901 | if (failed(Result: value)) |
| 1902 | return failure(); |
| 1903 | resources.emplace_back(Args: entry.getKey(), Args&: *value); |
| 1904 | break; |
| 1905 | } |
| 1906 | case AsmResourceEntryKind::String: { |
| 1907 | FailureOr<std::string> str = entry.parseAsString(); |
| 1908 | if (failed(Result: str)) |
| 1909 | return failure(); |
| 1910 | resources.emplace_back(Args: entry.getKey(), Args: std::move(*str)); |
| 1911 | break; |
| 1912 | } |
| 1913 | } |
| 1914 | return success(); |
| 1915 | } |
| 1916 | |
| 1917 | void FallbackAsmResourceMap::ResourceCollection::buildResources( |
| 1918 | Operation *op, AsmResourceBuilder &builder) const { |
| 1919 | for (const auto &entry : resources) { |
| 1920 | if (const auto *value = std::get_if<AsmResourceBlob>(ptr: &entry.value)) |
| 1921 | builder.buildBlob(key: entry.key, blob: *value); |
| 1922 | else if (const auto *value = std::get_if<bool>(ptr: &entry.value)) |
| 1923 | builder.buildBool(key: entry.key, data: *value); |
| 1924 | else if (const auto *value = std::get_if<std::string>(ptr: &entry.value)) |
| 1925 | builder.buildString(key: entry.key, data: *value); |
| 1926 | else |
| 1927 | llvm_unreachable("unknown AsmResourceEntryKind" ); |
| 1928 | } |
| 1929 | } |
| 1930 | |
| 1931 | //===----------------------------------------------------------------------===// |
| 1932 | // AsmState |
| 1933 | //===----------------------------------------------------------------------===// |
| 1934 | |
| 1935 | namespace mlir { |
| 1936 | namespace detail { |
| 1937 | class AsmStateImpl { |
| 1938 | public: |
| 1939 | explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, |
| 1940 | AsmState::LocationMap *locationMap) |
| 1941 | : interfaces(op->getContext()), nameState(op, printerFlags), |
| 1942 | printerFlags(printerFlags), locationMap(locationMap) {} |
| 1943 | explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
| 1944 | AsmState::LocationMap *locationMap) |
| 1945 | : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} |
| 1946 | |
| 1947 | /// Initialize the alias state to enable the printing of aliases. |
| 1948 | void initializeAliases(Operation *op) { |
| 1949 | aliasState.initialize(op, printerFlags, interfaces); |
| 1950 | } |
| 1951 | |
| 1952 | /// Get the state used for aliases. |
| 1953 | AliasState &getAliasState() { return aliasState; } |
| 1954 | |
| 1955 | /// Get the state used for SSA names. |
| 1956 | SSANameState &getSSANameState() { return nameState; } |
| 1957 | |
| 1958 | /// Get the state used for distinct attribute identifiers. |
| 1959 | DistinctState &getDistinctState() { return distinctState; } |
| 1960 | |
| 1961 | /// Return the dialects within the context that implement |
| 1962 | /// OpAsmDialectInterface. |
| 1963 | DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() { |
| 1964 | return interfaces; |
| 1965 | } |
| 1966 | |
| 1967 | /// Return the non-dialect resource printers. |
| 1968 | auto getResourcePrinters() { |
| 1969 | return llvm::make_pointee_range(Range&: externalResourcePrinters); |
| 1970 | } |
| 1971 | |
| 1972 | /// Get the printer flags. |
| 1973 | const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } |
| 1974 | |
| 1975 | /// Register the location, line and column, within the buffer that the given |
| 1976 | /// operation was printed at. |
| 1977 | void registerOperationLocation(Operation *op, unsigned line, unsigned col) { |
| 1978 | if (locationMap) |
| 1979 | (*locationMap)[op] = std::make_pair(x&: line, y&: col); |
| 1980 | } |
| 1981 | |
| 1982 | /// Return the referenced dialect resources within the printer. |
| 1983 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
| 1984 | getDialectResources() { |
| 1985 | return dialectResources; |
| 1986 | } |
| 1987 | |
| 1988 | LogicalResult pushCyclicPrinting(const void *opaquePointer) { |
| 1989 | return success(IsSuccess: cyclicPrintingStack.insert(X: opaquePointer)); |
| 1990 | } |
| 1991 | |
| 1992 | void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } |
| 1993 | |
| 1994 | private: |
| 1995 | /// Collection of OpAsm interfaces implemented in the context. |
| 1996 | DialectInterfaceCollection<OpAsmDialectInterface> interfaces; |
| 1997 | |
| 1998 | /// A collection of non-dialect resource printers. |
| 1999 | SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; |
| 2000 | |
| 2001 | /// A set of dialect resources that were referenced during printing. |
| 2002 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources; |
| 2003 | |
| 2004 | /// The state used for attribute and type aliases. |
| 2005 | AliasState aliasState; |
| 2006 | |
| 2007 | /// The state used for SSA value names. |
| 2008 | SSANameState nameState; |
| 2009 | |
| 2010 | /// The state used for distinct attribute identifiers. |
| 2011 | DistinctState distinctState; |
| 2012 | |
| 2013 | /// Flags that control op output. |
| 2014 | OpPrintingFlags printerFlags; |
| 2015 | |
| 2016 | /// An optional location map to be populated. |
| 2017 | AsmState::LocationMap *locationMap; |
| 2018 | |
| 2019 | /// Stack of potentially cyclic mutable attributes or type currently being |
| 2020 | /// printed. |
| 2021 | SetVector<const void *> cyclicPrintingStack; |
| 2022 | |
| 2023 | // Allow direct access to the impl fields. |
| 2024 | friend AsmState; |
| 2025 | }; |
| 2026 | |
| 2027 | template <typename Range> |
| 2028 | void printDimensionList(raw_ostream &stream, Range &&shape) { |
| 2029 | llvm::interleave( |
| 2030 | shape, stream, |
| 2031 | [&stream](const auto &dimSize) { |
| 2032 | if (ShapedType::isDynamic(dimSize)) |
| 2033 | stream << "?" ; |
| 2034 | else |
| 2035 | stream << dimSize; |
| 2036 | }, |
| 2037 | "x" ); |
| 2038 | } |
| 2039 | |
| 2040 | } // namespace detail |
| 2041 | } // namespace mlir |
| 2042 | |
| 2043 | /// Verifies the operation and switches to generic op printing if verification |
| 2044 | /// fails. We need to do this because custom print functions may fail for |
| 2045 | /// invalid ops. |
| 2046 | static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, |
| 2047 | OpPrintingFlags printerFlags) { |
| 2048 | if (printerFlags.shouldPrintGenericOpForm() || |
| 2049 | printerFlags.shouldAssumeVerified()) |
| 2050 | return printerFlags; |
| 2051 | |
| 2052 | // Ignore errors emitted by the verifier. We check the thread id to avoid |
| 2053 | // consuming other threads' errors. |
| 2054 | auto parentThreadId = llvm::get_threadid(); |
| 2055 | ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { |
| 2056 | if (parentThreadId == llvm::get_threadid()) { |
| 2057 | LLVM_DEBUG({ |
| 2058 | diag.print(llvm::dbgs()); |
| 2059 | llvm::dbgs() << "\n" ; |
| 2060 | }); |
| 2061 | return success(); |
| 2062 | } |
| 2063 | return failure(); |
| 2064 | }); |
| 2065 | if (failed(Result: verify(op))) { |
| 2066 | LLVM_DEBUG(llvm::dbgs() |
| 2067 | << DEBUG_TYPE << ": '" << op->getName() |
| 2068 | << "' failed to verify and will be printed in generic form\n" ); |
| 2069 | printerFlags.printGenericOpForm(); |
| 2070 | } |
| 2071 | |
| 2072 | return printerFlags; |
| 2073 | } |
| 2074 | |
| 2075 | AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, |
| 2076 | LocationMap *locationMap, FallbackAsmResourceMap *map) |
| 2077 | : impl(std::make_unique<AsmStateImpl>( |
| 2078 | args&: op, args: verifyOpAndAdjustFlags(op, printerFlags), args&: locationMap)) { |
| 2079 | if (map) |
| 2080 | attachFallbackResourcePrinter(map&: *map); |
| 2081 | } |
| 2082 | AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
| 2083 | LocationMap *locationMap, FallbackAsmResourceMap *map) |
| 2084 | : impl(std::make_unique<AsmStateImpl>(args&: ctx, args: printerFlags, args&: locationMap)) { |
| 2085 | if (map) |
| 2086 | attachFallbackResourcePrinter(map&: *map); |
| 2087 | } |
| 2088 | AsmState::~AsmState() = default; |
| 2089 | |
| 2090 | const OpPrintingFlags &AsmState::getPrinterFlags() const { |
| 2091 | return impl->getPrinterFlags(); |
| 2092 | } |
| 2093 | |
| 2094 | void AsmState::attachResourcePrinter( |
| 2095 | std::unique_ptr<AsmResourcePrinter> printer) { |
| 2096 | impl->externalResourcePrinters.emplace_back(Args: std::move(printer)); |
| 2097 | } |
| 2098 | |
| 2099 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
| 2100 | AsmState::getDialectResources() const { |
| 2101 | return impl->getDialectResources(); |
| 2102 | } |
| 2103 | |
| 2104 | //===----------------------------------------------------------------------===// |
| 2105 | // AsmPrinter::Impl |
| 2106 | //===----------------------------------------------------------------------===// |
| 2107 | |
| 2108 | AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state) |
| 2109 | : os(os), state(state), printerFlags(state.getPrinterFlags()) {} |
| 2110 | |
| 2111 | void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { |
| 2112 | // Check to see if we are printing debug information. |
| 2113 | if (!printerFlags.shouldPrintDebugInfo()) |
| 2114 | return; |
| 2115 | |
| 2116 | os << " " ; |
| 2117 | printLocation(loc, /*allowAlias=*/allowAlias); |
| 2118 | } |
| 2119 | |
| 2120 | void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, |
| 2121 | bool isTopLevel) { |
| 2122 | // If this isn't a top-level location, check for an alias. |
| 2123 | if (!isTopLevel && succeeded(Result: state.getAliasState().getAlias(attr: loc, os))) |
| 2124 | return; |
| 2125 | |
| 2126 | TypeSwitch<LocationAttr>(loc) |
| 2127 | .Case<OpaqueLoc>([&](OpaqueLoc loc) { |
| 2128 | printLocationInternal(loc.getFallbackLocation(), pretty); |
| 2129 | }) |
| 2130 | .Case<UnknownLoc>([&](UnknownLoc loc) { |
| 2131 | if (pretty) |
| 2132 | os << "[unknown]" ; |
| 2133 | else |
| 2134 | os << "unknown" ; |
| 2135 | }) |
| 2136 | .Case<FileLineColRange>([&](FileLineColRange loc) { |
| 2137 | if (pretty) |
| 2138 | os << loc.getFilename().getValue(); |
| 2139 | else |
| 2140 | printEscapedString(loc.getFilename()); |
| 2141 | if (loc.getEndColumn() == loc.getStartColumn() && |
| 2142 | loc.getStartLine() == loc.getEndLine()) { |
| 2143 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn(); |
| 2144 | return; |
| 2145 | } |
| 2146 | if (loc.getStartLine() == loc.getEndLine()) { |
| 2147 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() |
| 2148 | << " to :" << loc.getEndColumn(); |
| 2149 | return; |
| 2150 | } |
| 2151 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to " |
| 2152 | << loc.getEndLine() << ':' << loc.getEndColumn(); |
| 2153 | }) |
| 2154 | .Case<NameLoc>([&](NameLoc loc) { |
| 2155 | printEscapedString(loc.getName()); |
| 2156 | |
| 2157 | // Print the child if it isn't unknown. |
| 2158 | auto childLoc = loc.getChildLoc(); |
| 2159 | if (!llvm::isa<UnknownLoc>(childLoc)) { |
| 2160 | os << '('; |
| 2161 | printLocationInternal(childLoc, pretty); |
| 2162 | os << ')'; |
| 2163 | } |
| 2164 | }) |
| 2165 | .Case<CallSiteLoc>([&](CallSiteLoc loc) { |
| 2166 | Location caller = loc.getCaller(); |
| 2167 | Location callee = loc.getCallee(); |
| 2168 | if (!pretty) |
| 2169 | os << "callsite(" ; |
| 2170 | printLocationInternal(callee, pretty); |
| 2171 | if (pretty) { |
| 2172 | if (llvm::isa<NameLoc>(callee)) { |
| 2173 | if (llvm::isa<FileLineColLoc>(caller)) { |
| 2174 | os << " at " ; |
| 2175 | } else { |
| 2176 | os << newLine << " at " ; |
| 2177 | } |
| 2178 | } else { |
| 2179 | os << newLine << " at " ; |
| 2180 | } |
| 2181 | } else { |
| 2182 | os << " at " ; |
| 2183 | } |
| 2184 | printLocationInternal(caller, pretty); |
| 2185 | if (!pretty) |
| 2186 | os << ")" ; |
| 2187 | }) |
| 2188 | .Case<FusedLoc>([&](FusedLoc loc) { |
| 2189 | if (!pretty) |
| 2190 | os << "fused" ; |
| 2191 | if (Attribute metadata = loc.getMetadata()) { |
| 2192 | os << '<'; |
| 2193 | printAttribute(metadata); |
| 2194 | os << '>'; |
| 2195 | } |
| 2196 | os << '['; |
| 2197 | interleave( |
| 2198 | loc.getLocations(), |
| 2199 | [&](Location loc) { printLocationInternal(loc, pretty); }, |
| 2200 | [&]() { os << ", " ; }); |
| 2201 | os << ']'; |
| 2202 | }) |
| 2203 | .Default([&](LocationAttr loc) { |
| 2204 | // Assumes that this is a dialect-specific attribute and prints it |
| 2205 | // directly. |
| 2206 | printAttribute(loc); |
| 2207 | }); |
| 2208 | } |
| 2209 | |
| 2210 | /// Print a floating point value in a way that the parser will be able to |
| 2211 | /// round-trip losslessly. |
| 2212 | static void printFloatValue(const APFloat &apValue, raw_ostream &os, |
| 2213 | bool *printedHex = nullptr) { |
| 2214 | // We would like to output the FP constant value in exponential notation, |
| 2215 | // but we cannot do this if doing so will lose precision. Check here to |
| 2216 | // make sure that we only output it in exponential format if we can parse |
| 2217 | // the value back and get the same value. |
| 2218 | bool isInf = apValue.isInfinity(); |
| 2219 | bool isNaN = apValue.isNaN(); |
| 2220 | if (!isInf && !isNaN) { |
| 2221 | SmallString<128> strValue; |
| 2222 | apValue.toString(Str&: strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, |
| 2223 | /*TruncateZero=*/false); |
| 2224 | |
| 2225 | // Check to make sure that the stringized number is not some string like |
| 2226 | // "Inf" or NaN, that atof will accept, but the lexer will not. Check |
| 2227 | // that the string matches the "[-+]?[0-9]" regex. |
| 2228 | assert(((strValue[0] >= '0' && strValue[0] <= '9') || |
| 2229 | ((strValue[0] == '-' || strValue[0] == '+') && |
| 2230 | (strValue[1] >= '0' && strValue[1] <= '9'))) && |
| 2231 | "[-+]?[0-9] regex does not match!" ); |
| 2232 | |
| 2233 | // Parse back the stringized version and check that the value is equal |
| 2234 | // (i.e., there is no precision loss). |
| 2235 | if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(RHS: apValue)) { |
| 2236 | os << strValue; |
| 2237 | return; |
| 2238 | } |
| 2239 | |
| 2240 | // If it is not, use the default format of APFloat instead of the |
| 2241 | // exponential notation. |
| 2242 | strValue.clear(); |
| 2243 | apValue.toString(Str&: strValue); |
| 2244 | |
| 2245 | // Make sure that we can parse the default form as a float. |
| 2246 | if (strValue.str().contains(C: '.')) { |
| 2247 | os << strValue; |
| 2248 | return; |
| 2249 | } |
| 2250 | } |
| 2251 | |
| 2252 | // Print special values in hexadecimal format. The sign bit should be included |
| 2253 | // in the literal. |
| 2254 | if (printedHex) |
| 2255 | *printedHex = true; |
| 2256 | SmallVector<char, 16> str; |
| 2257 | APInt apInt = apValue.bitcastToAPInt(); |
| 2258 | apInt.toString(Str&: str, /*Radix=*/16, /*Signed=*/false, |
| 2259 | /*formatAsCLiteral=*/true); |
| 2260 | os << str; |
| 2261 | } |
| 2262 | |
| 2263 | void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { |
| 2264 | if (printerFlags.shouldPrintDebugInfoPrettyForm()) |
| 2265 | return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true); |
| 2266 | |
| 2267 | os << "loc(" ; |
| 2268 | if (!allowAlias || failed(Result: printAlias(attr: loc))) |
| 2269 | printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true); |
| 2270 | os << ')'; |
| 2271 | } |
| 2272 | |
| 2273 | /// Returns true if the given dialect symbol data is simple enough to print in |
| 2274 | /// the pretty form. This is essentially when the symbol takes the form: |
| 2275 | /// identifier (`<` body `>`)? |
| 2276 | static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { |
| 2277 | // The name must start with an identifier. |
| 2278 | if (symName.empty() || !isalpha(symName.front())) |
| 2279 | return false; |
| 2280 | |
| 2281 | // Ignore all the characters that are valid in an identifier in the symbol |
| 2282 | // name. |
| 2283 | symName = symName.drop_while( |
| 2284 | F: [](char c) { return llvm::isAlnum(C: c) || c == '.' || c == '_'; }); |
| 2285 | if (symName.empty()) |
| 2286 | return true; |
| 2287 | |
| 2288 | // If we got to an unexpected character, then it must be a <>. Check that the |
| 2289 | // rest of the symbol is wrapped within <>. |
| 2290 | return symName.front() == '<' && symName.back() == '>'; |
| 2291 | } |
| 2292 | |
| 2293 | /// Print the given dialect symbol to the stream. |
| 2294 | static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, |
| 2295 | StringRef dialectName, StringRef symString) { |
| 2296 | os << symPrefix << dialectName; |
| 2297 | |
| 2298 | // If this symbol name is simple enough, print it directly in pretty form, |
| 2299 | // otherwise, we print it as an escaped string. |
| 2300 | if (isDialectSymbolSimpleEnoughForPrettyForm(symName: symString)) { |
| 2301 | os << '.' << symString; |
| 2302 | return; |
| 2303 | } |
| 2304 | |
| 2305 | os << '<' << symString << '>'; |
| 2306 | } |
| 2307 | |
| 2308 | /// Returns true if the given string can be represented as a bare identifier. |
| 2309 | static bool isBareIdentifier(StringRef name) { |
| 2310 | // By making this unsigned, the value passed in to isalnum will always be |
| 2311 | // in the range 0-255. This is important when building with MSVC because |
| 2312 | // its implementation will assert. This situation can arise when dealing |
| 2313 | // with UTF-8 multibyte characters. |
| 2314 | if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) |
| 2315 | return false; |
| 2316 | return llvm::all_of(Range: name.drop_front(), P: [](unsigned char c) { |
| 2317 | return isalnum(c) || c == '_' || c == '$' || c == '.'; |
| 2318 | }); |
| 2319 | } |
| 2320 | |
| 2321 | /// Print the given string as a keyword, or a quoted and escaped string if it |
| 2322 | /// has any special or non-printable characters in it. |
| 2323 | static void printKeywordOrString(StringRef keyword, raw_ostream &os) { |
| 2324 | // If it can be represented as a bare identifier, write it directly. |
| 2325 | if (isBareIdentifier(name: keyword)) { |
| 2326 | os << keyword; |
| 2327 | return; |
| 2328 | } |
| 2329 | |
| 2330 | // Otherwise, output the keyword wrapped in quotes with proper escaping. |
| 2331 | os << "\"" ; |
| 2332 | printEscapedString(Name: keyword, Out&: os); |
| 2333 | os << '"'; |
| 2334 | } |
| 2335 | |
| 2336 | /// Print the given string as a symbol reference. A symbol reference is |
| 2337 | /// represented as a string prefixed with '@'. The reference is surrounded with |
| 2338 | /// ""'s and escaped if it has any special or non-printable characters in it. |
| 2339 | static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { |
| 2340 | if (symbolRef.empty()) { |
| 2341 | os << "@<<INVALID EMPTY SYMBOL>>" ; |
| 2342 | return; |
| 2343 | } |
| 2344 | os << '@'; |
| 2345 | printKeywordOrString(keyword: symbolRef, os); |
| 2346 | } |
| 2347 | |
| 2348 | // Print out a valid ElementsAttr that is succinct and can represent any |
| 2349 | // potential shape/type, for use when eliding a large ElementsAttr. |
| 2350 | // |
| 2351 | // We choose to use a dense resource ElementsAttr literal with conspicuous |
| 2352 | // content to hopefully alert readers to the fact that this has been elided. |
| 2353 | static void printElidedElementsAttr(raw_ostream &os) { |
| 2354 | os << R"(dense_resource<__elided__>)" ; |
| 2355 | } |
| 2356 | |
| 2357 | void AsmPrinter::Impl::printResourceHandle( |
| 2358 | const AsmDialectResourceHandle &resource) { |
| 2359 | auto *interface = cast<OpAsmDialectInterface>(Val: resource.getDialect()); |
| 2360 | ::printKeywordOrString(keyword: interface->getResourceKey(handle: resource), os); |
| 2361 | state.getDialectResources()[resource.getDialect()].insert(X: resource); |
| 2362 | } |
| 2363 | |
| 2364 | LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { |
| 2365 | return state.getAliasState().getAlias(attr, os); |
| 2366 | } |
| 2367 | |
| 2368 | LogicalResult AsmPrinter::Impl::printAlias(Type type) { |
| 2369 | return state.getAliasState().getAlias(ty: type, os); |
| 2370 | } |
| 2371 | |
| 2372 | void AsmPrinter::Impl::printAttribute(Attribute attr, |
| 2373 | AttrTypeElision typeElision) { |
| 2374 | if (!attr) { |
| 2375 | os << "<<NULL ATTRIBUTE>>" ; |
| 2376 | return; |
| 2377 | } |
| 2378 | |
| 2379 | // Try to print an alias for this attribute. |
| 2380 | if (succeeded(Result: printAlias(attr))) |
| 2381 | return; |
| 2382 | return printAttributeImpl(attr, typeElision); |
| 2383 | } |
| 2384 | |
| 2385 | void AsmPrinter::Impl::printAttributeImpl(Attribute attr, |
| 2386 | AttrTypeElision typeElision) { |
| 2387 | if (!isa<BuiltinDialect>(Val: attr.getDialect())) { |
| 2388 | printDialectAttribute(attr); |
| 2389 | } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) { |
| 2390 | printDialectSymbol(os, "#" , opaqueAttr.getDialectNamespace(), |
| 2391 | opaqueAttr.getAttrData()); |
| 2392 | } else if (llvm::isa<UnitAttr>(Val: attr)) { |
| 2393 | os << "unit" ; |
| 2394 | return; |
| 2395 | } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) { |
| 2396 | os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<" ; |
| 2397 | if (!llvm::isa<UnitAttr>(Val: distinctAttr.getReferencedAttr())) { |
| 2398 | printAttribute(attr: distinctAttr.getReferencedAttr()); |
| 2399 | } |
| 2400 | os << '>'; |
| 2401 | return; |
| 2402 | } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) { |
| 2403 | os << '{'; |
| 2404 | interleaveComma(dictAttr.getValue(), |
| 2405 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
| 2406 | os << '}'; |
| 2407 | |
| 2408 | } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) { |
| 2409 | Type intType = intAttr.getType(); |
| 2410 | if (intType.isSignlessInteger(width: 1)) { |
| 2411 | os << (intAttr.getValue().getBoolValue() ? "true" : "false" ); |
| 2412 | |
| 2413 | // Boolean integer attributes always elides the type. |
| 2414 | return; |
| 2415 | } |
| 2416 | |
| 2417 | // Only print attributes as unsigned if they are explicitly unsigned or are |
| 2418 | // signless 1-bit values. Indexes, signed values, and multi-bit signless |
| 2419 | // values print as signed. |
| 2420 | bool isUnsigned = |
| 2421 | intType.isUnsignedInteger() || intType.isSignlessInteger(width: 1); |
| 2422 | intAttr.getValue().print(os, !isUnsigned); |
| 2423 | |
| 2424 | // IntegerAttr elides the type if I64. |
| 2425 | if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(width: 64)) |
| 2426 | return; |
| 2427 | |
| 2428 | } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) { |
| 2429 | bool printedHex = false; |
| 2430 | printFloatValue(floatAttr.getValue(), os, &printedHex); |
| 2431 | |
| 2432 | // FloatAttr elides the type if F64. |
| 2433 | if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() && |
| 2434 | !printedHex) |
| 2435 | return; |
| 2436 | |
| 2437 | } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) { |
| 2438 | printEscapedString(str: strAttr.getValue()); |
| 2439 | |
| 2440 | } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) { |
| 2441 | os << '['; |
| 2442 | interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { |
| 2443 | printAttribute(attr, typeElision: AttrTypeElision::May); |
| 2444 | }); |
| 2445 | os << ']'; |
| 2446 | |
| 2447 | } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) { |
| 2448 | os << "affine_map<" ; |
| 2449 | affineMapAttr.getValue().print(os); |
| 2450 | os << '>'; |
| 2451 | |
| 2452 | // AffineMap always elides the type. |
| 2453 | return; |
| 2454 | |
| 2455 | } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) { |
| 2456 | os << "affine_set<" ; |
| 2457 | integerSetAttr.getValue().print(os); |
| 2458 | os << '>'; |
| 2459 | |
| 2460 | // IntegerSet always elides the type. |
| 2461 | return; |
| 2462 | |
| 2463 | } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) { |
| 2464 | printType(type: typeAttr.getValue()); |
| 2465 | |
| 2466 | } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) { |
| 2467 | printSymbolReference(refAttr.getRootReference().getValue(), os); |
| 2468 | for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { |
| 2469 | os << "::" ; |
| 2470 | printSymbolReference(nestedRef.getValue(), os); |
| 2471 | } |
| 2472 | |
| 2473 | } else if (auto intOrFpEltAttr = |
| 2474 | llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) { |
| 2475 | if (printerFlags.shouldElideElementsAttr(attr: intOrFpEltAttr)) { |
| 2476 | printElidedElementsAttr(os); |
| 2477 | } else { |
| 2478 | os << "dense<" ; |
| 2479 | printDenseIntOrFPElementsAttr(attr: intOrFpEltAttr, /*allowHex=*/true); |
| 2480 | os << '>'; |
| 2481 | } |
| 2482 | |
| 2483 | } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) { |
| 2484 | if (printerFlags.shouldElideElementsAttr(attr: strEltAttr)) { |
| 2485 | printElidedElementsAttr(os); |
| 2486 | } else { |
| 2487 | os << "dense<" ; |
| 2488 | printDenseStringElementsAttr(attr: strEltAttr); |
| 2489 | os << '>'; |
| 2490 | } |
| 2491 | |
| 2492 | } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) { |
| 2493 | if (printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getIndices()) || |
| 2494 | printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getValues())) { |
| 2495 | printElidedElementsAttr(os); |
| 2496 | } else { |
| 2497 | os << "sparse<" ; |
| 2498 | DenseIntElementsAttr indices = sparseEltAttr.getIndices(); |
| 2499 | if (indices.getNumElements() != 0) { |
| 2500 | printDenseIntOrFPElementsAttr(attr: indices, /*allowHex=*/false); |
| 2501 | os << ", " ; |
| 2502 | printDenseElementsAttr(attr: sparseEltAttr.getValues(), /*allowHex=*/true); |
| 2503 | } |
| 2504 | os << '>'; |
| 2505 | } |
| 2506 | } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) { |
| 2507 | stridedLayoutAttr.print(os); |
| 2508 | } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) { |
| 2509 | os << "array<" ; |
| 2510 | printType(type: denseArrayAttr.getElementType()); |
| 2511 | if (!denseArrayAttr.empty()) { |
| 2512 | os << ": " ; |
| 2513 | printDenseArrayAttr(attr: denseArrayAttr); |
| 2514 | } |
| 2515 | os << ">" ; |
| 2516 | return; |
| 2517 | } else if (auto resourceAttr = |
| 2518 | llvm::dyn_cast<DenseResourceElementsAttr>(attr)) { |
| 2519 | os << "dense_resource<" ; |
| 2520 | printResourceHandle(resource: resourceAttr.getRawHandle()); |
| 2521 | os << ">" ; |
| 2522 | } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(Val&: attr)) { |
| 2523 | printLocation(loc: locAttr); |
| 2524 | } else { |
| 2525 | llvm::report_fatal_error(reason: "Unknown builtin attribute" ); |
| 2526 | } |
| 2527 | // Don't print the type if we must elide it, or if it is a None type. |
| 2528 | if (typeElision != AttrTypeElision::Must) { |
| 2529 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { |
| 2530 | Type attrType = typedAttr.getType(); |
| 2531 | if (!llvm::isa<NoneType>(Val: attrType)) { |
| 2532 | os << " : " ; |
| 2533 | printType(type: attrType); |
| 2534 | } |
| 2535 | } |
| 2536 | } |
| 2537 | } |
| 2538 | |
| 2539 | /// Print the integer element of a DenseElementsAttr. |
| 2540 | static void printDenseIntElement(const APInt &value, raw_ostream &os, |
| 2541 | Type type) { |
| 2542 | if (type.isInteger(width: 1)) |
| 2543 | os << (value.getBoolValue() ? "true" : "false" ); |
| 2544 | else |
| 2545 | value.print(OS&: os, isSigned: !type.isUnsignedInteger()); |
| 2546 | } |
| 2547 | |
| 2548 | static void |
| 2549 | printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, |
| 2550 | function_ref<void(unsigned)> printEltFn) { |
| 2551 | // Special case for 0-d and splat tensors. |
| 2552 | if (isSplat) |
| 2553 | return printEltFn(0); |
| 2554 | |
| 2555 | // Special case for degenerate tensors. |
| 2556 | auto numElements = type.getNumElements(); |
| 2557 | if (numElements == 0) |
| 2558 | return; |
| 2559 | |
| 2560 | // We use a mixed-radix counter to iterate through the shape. When we bump a |
| 2561 | // non-least-significant digit, we emit a close bracket. When we next emit an |
| 2562 | // element we re-open all closed brackets. |
| 2563 | |
| 2564 | // The mixed-radix counter, with radices in 'shape'. |
| 2565 | int64_t rank = type.getRank(); |
| 2566 | SmallVector<unsigned, 4> counter(rank, 0); |
| 2567 | // The number of brackets that have been opened and not closed. |
| 2568 | unsigned openBrackets = 0; |
| 2569 | |
| 2570 | auto shape = type.getShape(); |
| 2571 | auto bumpCounter = [&] { |
| 2572 | // Bump the least significant digit. |
| 2573 | ++counter[rank - 1]; |
| 2574 | // Iterate backwards bubbling back the increment. |
| 2575 | for (unsigned i = rank - 1; i > 0; --i) |
| 2576 | if (counter[i] >= shape[i]) { |
| 2577 | // Index 'i' is rolled over. Bump (i-1) and close a bracket. |
| 2578 | counter[i] = 0; |
| 2579 | ++counter[i - 1]; |
| 2580 | --openBrackets; |
| 2581 | os << ']'; |
| 2582 | } |
| 2583 | }; |
| 2584 | |
| 2585 | for (unsigned idx = 0, e = numElements; idx != e; ++idx) { |
| 2586 | if (idx != 0) |
| 2587 | os << ", " ; |
| 2588 | while (openBrackets++ < rank) |
| 2589 | os << '['; |
| 2590 | openBrackets = rank; |
| 2591 | printEltFn(idx); |
| 2592 | bumpCounter(); |
| 2593 | } |
| 2594 | while (openBrackets-- > 0) |
| 2595 | os << ']'; |
| 2596 | } |
| 2597 | |
| 2598 | void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, |
| 2599 | bool allowHex) { |
| 2600 | if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) |
| 2601 | return printDenseStringElementsAttr(attr: stringAttr); |
| 2602 | |
| 2603 | printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr), |
| 2604 | allowHex); |
| 2605 | } |
| 2606 | |
| 2607 | void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( |
| 2608 | DenseIntOrFPElementsAttr attr, bool allowHex) { |
| 2609 | auto type = attr.getType(); |
| 2610 | auto elementType = type.getElementType(); |
| 2611 | |
| 2612 | // Check to see if we should format this attribute as a hex string. |
| 2613 | if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr: attr)) { |
| 2614 | ArrayRef<char> rawData = attr.getRawData(); |
| 2615 | if (llvm::endianness::native == llvm::endianness::big) { |
| 2616 | // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE |
| 2617 | // machines. It is converted here to print in LE format. |
| 2618 | SmallVector<char, 64> outDataVec(rawData.size()); |
| 2619 | MutableArrayRef<char> convRawData(outDataVec); |
| 2620 | DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
| 2621 | rawData, convRawData, type); |
| 2622 | printHexString(data: convRawData); |
| 2623 | } else { |
| 2624 | printHexString(data: rawData); |
| 2625 | } |
| 2626 | |
| 2627 | return; |
| 2628 | } |
| 2629 | |
| 2630 | if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) { |
| 2631 | Type complexElementType = complexTy.getElementType(); |
| 2632 | // Note: The if and else below had a common lambda function which invoked |
| 2633 | // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 |
| 2634 | // and hence was replaced. |
| 2635 | if (llvm::isa<IntegerType>(Val: complexElementType)) { |
| 2636 | auto valueIt = attr.value_begin<std::complex<APInt>>(); |
| 2637 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| 2638 | auto complexValue = *(valueIt + index); |
| 2639 | os << "(" ; |
| 2640 | printDenseIntElement(complexValue.real(), os, complexElementType); |
| 2641 | os << "," ; |
| 2642 | printDenseIntElement(complexValue.imag(), os, complexElementType); |
| 2643 | os << ")" ; |
| 2644 | }); |
| 2645 | } else { |
| 2646 | auto valueIt = attr.value_begin<std::complex<APFloat>>(); |
| 2647 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| 2648 | auto complexValue = *(valueIt + index); |
| 2649 | os << "(" ; |
| 2650 | printFloatValue(complexValue.real(), os); |
| 2651 | os << "," ; |
| 2652 | printFloatValue(complexValue.imag(), os); |
| 2653 | os << ")" ; |
| 2654 | }); |
| 2655 | } |
| 2656 | } else if (elementType.isIntOrIndex()) { |
| 2657 | auto valueIt = attr.value_begin<APInt>(); |
| 2658 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| 2659 | printDenseIntElement(*(valueIt + index), os, elementType); |
| 2660 | }); |
| 2661 | } else { |
| 2662 | assert(llvm::isa<FloatType>(elementType) && "unexpected element type" ); |
| 2663 | auto valueIt = attr.value_begin<APFloat>(); |
| 2664 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| 2665 | printFloatValue(*(valueIt + index), os); |
| 2666 | }); |
| 2667 | } |
| 2668 | } |
| 2669 | |
| 2670 | void AsmPrinter::Impl::printDenseStringElementsAttr( |
| 2671 | DenseStringElementsAttr attr) { |
| 2672 | ArrayRef<StringRef> data = attr.getRawStringData(); |
| 2673 | auto printFn = [&](unsigned index) { printEscapedString(str: data[index]); }; |
| 2674 | printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); |
| 2675 | } |
| 2676 | |
| 2677 | void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { |
| 2678 | Type type = attr.getElementType(); |
| 2679 | unsigned bitwidth = type.isInteger(width: 1) ? 8 : type.getIntOrFloatBitWidth(); |
| 2680 | unsigned byteSize = bitwidth / 8; |
| 2681 | ArrayRef<char> data = attr.getRawData(); |
| 2682 | |
| 2683 | auto printElementAt = [&](unsigned i) { |
| 2684 | APInt value(bitwidth, 0); |
| 2685 | if (bitwidth) { |
| 2686 | llvm::LoadIntFromMemory( |
| 2687 | IntVal&: value, Src: reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i), |
| 2688 | LoadBytes: byteSize); |
| 2689 | } |
| 2690 | // Print the data as-is or as a float. |
| 2691 | if (type.isIntOrIndex()) { |
| 2692 | printDenseIntElement(value, os&: getStream(), type); |
| 2693 | } else { |
| 2694 | APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value); |
| 2695 | printFloatValue(apValue: fltVal, os&: getStream()); |
| 2696 | } |
| 2697 | }; |
| 2698 | llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(), |
| 2699 | printElementAt); |
| 2700 | } |
| 2701 | |
| 2702 | void AsmPrinter::Impl::printType(Type type) { |
| 2703 | if (!type) { |
| 2704 | os << "<<NULL TYPE>>" ; |
| 2705 | return; |
| 2706 | } |
| 2707 | |
| 2708 | // Try to print an alias for this type. |
| 2709 | if (succeeded(Result: printAlias(type))) |
| 2710 | return; |
| 2711 | return printTypeImpl(type); |
| 2712 | } |
| 2713 | |
/// Print `type` in its builtin syntax without consulting the alias table
/// (printType tries aliases first and then calls this). Types that match none
/// of the builtin cases below are delegated to printDialectType.
void AsmPrinter::Impl::printTypeImpl(Type type) {
  TypeSwitch<Type>(type)
      // Opaque types print their dialect namespace and raw type data.
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      // Builtin scalar types print as a fixed keyword.
      .Case<IndexType>([&](Type) { os << "index"; })
      .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; })
      .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
      .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
      .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
      .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
      .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      // Integers: `s`/`u` prefix for signed/unsigned, none for signless, then
      // `i<width>`.
      .Case<IntegerType>([&](IntegerType integerTy) {
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      // Functions: `(inputs) -> result`. A single result that is not itself a
      // function type is printed bare. Zero or multiple results, and a single
      // function-typed result, are parenthesized (presumably to keep nested
      // function types unambiguous).
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        ArrayRef<Type> results = funcTy.getResults();
        if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      // Vectors: every dimension is followed by 'x'; scalable dimensions are
      // wrapped in square brackets. `scalableDims` may be empty, in which case
      // no dimension is treated as scalable.
      .Case<VectorType>([&](VectorType vectorTy) {
        auto scalableDims = vectorTy.getScalableDims();
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned dimIdx = 0;
        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << '[';
          os << vShape[dimIdx];
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << ']';
          os << 'x';
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      // Ranked tensors: the shape is printed by printDimensionList, and the
      // 'x' before the element type is only emitted for a non-empty shape. The
      // encoding attribute is appended only when set.
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        printDimensionList(tensorTy.getShape());
        if (!tensorTy.getShape().empty())
          os << 'x';
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      // Ranked memrefs: the layout is omitted only when it is an identity
      // AffineMapAttr, and the memory space only when it is null.
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        printDimensionList(memrefTy.getShape());
        if (!memrefTy.getShape().empty())
          os << 'x';
        printType(memrefTy.getElementType());
        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
        if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Any non-builtin type is printed by its owning dialect.
      .Default([&](Type type) { return printDialectType(type); });
}
| 2836 | |
| 2837 | void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 2838 | ArrayRef<StringRef> elidedAttrs, |
| 2839 | bool withKeyword) { |
| 2840 | // If there are no attributes, then there is nothing to be done. |
| 2841 | if (attrs.empty()) |
| 2842 | return; |
| 2843 | |
| 2844 | // Functor used to print a filtered attribute list. |
| 2845 | auto printFilteredAttributesFn = [&](auto filteredAttrs) { |
| 2846 | // Print the 'attributes' keyword if necessary. |
| 2847 | if (withKeyword) |
| 2848 | os << " attributes" ; |
| 2849 | |
| 2850 | // Otherwise, print them all out in braces. |
| 2851 | os << " {" ; |
| 2852 | interleaveComma(filteredAttrs, |
| 2853 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
| 2854 | os << '}'; |
| 2855 | }; |
| 2856 | |
| 2857 | // If no attributes are elided, we can directly print with no filtering. |
| 2858 | if (elidedAttrs.empty()) |
| 2859 | return printFilteredAttributesFn(attrs); |
| 2860 | |
| 2861 | // Otherwise, filter out any attributes that shouldn't be included. |
| 2862 | llvm::SmallDenseSet<StringRef> (elidedAttrs.begin(), |
| 2863 | elidedAttrs.end()); |
| 2864 | auto filteredAttrs = llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) { |
| 2865 | return !elidedAttrsSet.contains(attr.getName().strref()); |
| 2866 | }); |
| 2867 | if (!filteredAttrs.empty()) |
| 2868 | printFilteredAttributesFn(filteredAttrs); |
| 2869 | } |
| 2870 | void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { |
| 2871 | // Print the name without quotes if possible. |
| 2872 | ::printKeywordOrString(keyword: attr.getName().strref(), os); |
| 2873 | |
| 2874 | // Pretty printing elides the attribute value for unit attributes. |
| 2875 | if (llvm::isa<UnitAttr>(Val: attr.getValue())) |
| 2876 | return; |
| 2877 | |
| 2878 | os << " = " ; |
| 2879 | printAttribute(attr: attr.getValue()); |
| 2880 | } |
| 2881 | |
| 2882 | void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { |
| 2883 | auto &dialect = attr.getDialect(); |
| 2884 | |
| 2885 | // Ask the dialect to serialize the attribute to a string. |
| 2886 | std::string attrName; |
| 2887 | { |
| 2888 | llvm::raw_string_ostream attrNameStr(attrName); |
| 2889 | Impl subPrinter(attrNameStr, state); |
| 2890 | DialectAsmPrinter printer(subPrinter); |
| 2891 | dialect.printAttribute(attr, printer); |
| 2892 | } |
| 2893 | printDialectSymbol(os, symPrefix: "#" , dialectName: dialect.getNamespace(), symString: attrName); |
| 2894 | } |
| 2895 | |
| 2896 | void AsmPrinter::Impl::printDialectType(Type type) { |
| 2897 | auto &dialect = type.getDialect(); |
| 2898 | |
| 2899 | // Ask the dialect to serialize the type to a string. |
| 2900 | std::string typeName; |
| 2901 | { |
| 2902 | llvm::raw_string_ostream typeNameStr(typeName); |
| 2903 | Impl subPrinter(typeNameStr, state); |
| 2904 | DialectAsmPrinter printer(subPrinter); |
| 2905 | dialect.printType(type, printer); |
| 2906 | } |
| 2907 | printDialectSymbol(os, symPrefix: "!" , dialectName: dialect.getNamespace(), symString: typeName); |
| 2908 | } |
| 2909 | |
| 2910 | void AsmPrinter::Impl::printEscapedString(StringRef str) { |
| 2911 | os << "\"" ; |
| 2912 | llvm::printEscapedString(Name: str, Out&: os); |
| 2913 | os << "\"" ; |
| 2914 | } |
| 2915 | |
| 2916 | void AsmPrinter::Impl::printHexString(StringRef str) { |
| 2917 | os << "\"0x" << llvm::toHex(Input: str) << "\"" ; |
| 2918 | } |
| 2919 | void AsmPrinter::Impl::printHexString(ArrayRef<char> data) { |
| 2920 | printHexString(str: StringRef(data.data(), data.size())); |
| 2921 | } |
| 2922 | |
| 2923 | LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { |
| 2924 | return state.pushCyclicPrinting(opaquePointer); |
| 2925 | } |
| 2926 | |
| 2927 | void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } |
| 2928 | |
| 2929 | void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) { |
| 2930 | detail::printDimensionList(stream&: os, shape); |
| 2931 | } |
| 2932 | |
| 2933 | //===--------------------------------------------------------------------===// |
| 2934 | // AsmPrinter |
| 2935 | //===--------------------------------------------------------------------===// |
| 2936 | |
// The destructor is defaulted out of line; no custom cleanup is performed here.
AsmPrinter::~AsmPrinter() = default;
| 2938 | |
| 2939 | raw_ostream &AsmPrinter::getStream() const { |
| 2940 | assert(impl && "expected AsmPrinter::getStream to be overriden" ); |
| 2941 | return impl->getStream(); |
| 2942 | } |
| 2943 | |
| 2944 | /// Print the given floating point value in a stablized form. |
| 2945 | void AsmPrinter::printFloat(const APFloat &value) { |
| 2946 | assert(impl && "expected AsmPrinter::printFloat to be overriden" ); |
| 2947 | printFloatValue(apValue: value, os&: impl->getStream()); |
| 2948 | } |
| 2949 | |
| 2950 | void AsmPrinter::printType(Type type) { |
| 2951 | assert(impl && "expected AsmPrinter::printType to be overriden" ); |
| 2952 | impl->printType(type); |
| 2953 | } |
| 2954 | |
| 2955 | void AsmPrinter::printAttribute(Attribute attr) { |
| 2956 | assert(impl && "expected AsmPrinter::printAttribute to be overriden" ); |
| 2957 | impl->printAttribute(attr); |
| 2958 | } |
| 2959 | |
| 2960 | LogicalResult AsmPrinter::printAlias(Attribute attr) { |
| 2961 | assert(impl && "expected AsmPrinter::printAlias to be overriden" ); |
| 2962 | return impl->printAlias(attr); |
| 2963 | } |
| 2964 | |
| 2965 | LogicalResult AsmPrinter::printAlias(Type type) { |
| 2966 | assert(impl && "expected AsmPrinter::printAlias to be overriden" ); |
| 2967 | return impl->printAlias(type); |
| 2968 | } |
| 2969 | |
| 2970 | void AsmPrinter::printAttributeWithoutType(Attribute attr) { |
| 2971 | assert(impl && |
| 2972 | "expected AsmPrinter::printAttributeWithoutType to be overriden" ); |
| 2973 | impl->printAttribute(attr, typeElision: Impl::AttrTypeElision::Must); |
| 2974 | } |
| 2975 | |
| 2976 | void AsmPrinter::printKeywordOrString(StringRef keyword) { |
| 2977 | assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden" ); |
| 2978 | ::printKeywordOrString(keyword, os&: impl->getStream()); |
| 2979 | } |
| 2980 | |
| 2981 | void AsmPrinter::printString(StringRef keyword) { |
| 2982 | assert(impl && "expected AsmPrinter::printString to be overriden" ); |
| 2983 | *this << '"'; |
| 2984 | printEscapedString(Name: keyword, Out&: getStream()); |
| 2985 | *this << '"'; |
| 2986 | } |
| 2987 | |
| 2988 | void AsmPrinter::printSymbolName(StringRef symbolRef) { |
| 2989 | assert(impl && "expected AsmPrinter::printSymbolName to be overriden" ); |
| 2990 | ::printSymbolReference(symbolRef, os&: impl->getStream()); |
| 2991 | } |
| 2992 | |
| 2993 | void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { |
| 2994 | assert(impl && "expected AsmPrinter::printResourceHandle to be overriden" ); |
| 2995 | impl->printResourceHandle(resource); |
| 2996 | } |
| 2997 | |
| 2998 | void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) { |
| 2999 | detail::printDimensionList(stream&: getStream(), shape); |
| 3000 | } |
| 3001 | |
| 3002 | LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { |
| 3003 | return impl->pushCyclicPrinting(opaquePointer); |
| 3004 | } |
| 3005 | |
| 3006 | void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } |
| 3007 | |
| 3008 | //===----------------------------------------------------------------------===// |
| 3009 | // Affine expressions and maps |
| 3010 | //===----------------------------------------------------------------------===// |
| 3011 | |
| 3012 | void AsmPrinter::Impl::printAffineExpr( |
| 3013 | AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { |
| 3014 | printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3015 | } |
| 3016 | |
| 3017 | void AsmPrinter::Impl::printAffineExprInternal( |
| 3018 | AffineExpr expr, BindingStrength enclosingTightness, |
| 3019 | function_ref<void(unsigned, bool)> printValueName) { |
| 3020 | const char *binopSpelling = nullptr; |
| 3021 | switch (expr.getKind()) { |
| 3022 | case AffineExprKind::SymbolId: { |
| 3023 | unsigned pos = cast<AffineSymbolExpr>(Val&: expr).getPosition(); |
| 3024 | if (printValueName) |
| 3025 | printValueName(pos, /*isSymbol=*/true); |
| 3026 | else |
| 3027 | os << 's' << pos; |
| 3028 | return; |
| 3029 | } |
| 3030 | case AffineExprKind::DimId: { |
| 3031 | unsigned pos = cast<AffineDimExpr>(Val&: expr).getPosition(); |
| 3032 | if (printValueName) |
| 3033 | printValueName(pos, /*isSymbol=*/false); |
| 3034 | else |
| 3035 | os << 'd' << pos; |
| 3036 | return; |
| 3037 | } |
| 3038 | case AffineExprKind::Constant: |
| 3039 | os << cast<AffineConstantExpr>(Val&: expr).getValue(); |
| 3040 | return; |
| 3041 | case AffineExprKind::Add: |
| 3042 | binopSpelling = " + " ; |
| 3043 | break; |
| 3044 | case AffineExprKind::Mul: |
| 3045 | binopSpelling = " * " ; |
| 3046 | break; |
| 3047 | case AffineExprKind::FloorDiv: |
| 3048 | binopSpelling = " floordiv " ; |
| 3049 | break; |
| 3050 | case AffineExprKind::CeilDiv: |
| 3051 | binopSpelling = " ceildiv " ; |
| 3052 | break; |
| 3053 | case AffineExprKind::Mod: |
| 3054 | binopSpelling = " mod " ; |
| 3055 | break; |
| 3056 | } |
| 3057 | |
| 3058 | auto binOp = cast<AffineBinaryOpExpr>(Val&: expr); |
| 3059 | AffineExpr lhsExpr = binOp.getLHS(); |
| 3060 | AffineExpr rhsExpr = binOp.getRHS(); |
| 3061 | |
| 3062 | // Handle tightly binding binary operators. |
| 3063 | if (binOp.getKind() != AffineExprKind::Add) { |
| 3064 | if (enclosingTightness == BindingStrength::Strong) |
| 3065 | os << '('; |
| 3066 | |
| 3067 | // Pretty print multiplication with -1. |
| 3068 | auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr); |
| 3069 | if (rhsConst && binOp.getKind() == AffineExprKind::Mul && |
| 3070 | rhsConst.getValue() == -1) { |
| 3071 | os << "-" ; |
| 3072 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
| 3073 | if (enclosingTightness == BindingStrength::Strong) |
| 3074 | os << ')'; |
| 3075 | return; |
| 3076 | } |
| 3077 | |
| 3078 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
| 3079 | |
| 3080 | os << binopSpelling; |
| 3081 | printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
| 3082 | |
| 3083 | if (enclosingTightness == BindingStrength::Strong) |
| 3084 | os << ')'; |
| 3085 | return; |
| 3086 | } |
| 3087 | |
| 3088 | // Print out special "pretty" forms for add. |
| 3089 | if (enclosingTightness == BindingStrength::Strong) |
| 3090 | os << '('; |
| 3091 | |
| 3092 | // Pretty print addition to a product that has a negative operand as a |
| 3093 | // subtraction. |
| 3094 | if (auto rhs = dyn_cast<AffineBinaryOpExpr>(Val&: rhsExpr)) { |
| 3095 | if (rhs.getKind() == AffineExprKind::Mul) { |
| 3096 | AffineExpr rrhsExpr = rhs.getRHS(); |
| 3097 | if (auto rrhs = dyn_cast<AffineConstantExpr>(Val&: rrhsExpr)) { |
| 3098 | if (rrhs.getValue() == -1) { |
| 3099 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, |
| 3100 | printValueName); |
| 3101 | os << " - " ; |
| 3102 | if (rhs.getLHS().getKind() == AffineExprKind::Add) { |
| 3103 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong, |
| 3104 | printValueName); |
| 3105 | } else { |
| 3106 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Weak, |
| 3107 | printValueName); |
| 3108 | } |
| 3109 | |
| 3110 | if (enclosingTightness == BindingStrength::Strong) |
| 3111 | os << ')'; |
| 3112 | return; |
| 3113 | } |
| 3114 | |
| 3115 | if (rrhs.getValue() < -1) { |
| 3116 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, |
| 3117 | printValueName); |
| 3118 | os << " - " ; |
| 3119 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong, |
| 3120 | printValueName); |
| 3121 | os << " * " << -rrhs.getValue(); |
| 3122 | if (enclosingTightness == BindingStrength::Strong) |
| 3123 | os << ')'; |
| 3124 | return; |
| 3125 | } |
| 3126 | } |
| 3127 | } |
| 3128 | } |
| 3129 | |
| 3130 | // Pretty print addition to a negative number as a subtraction. |
| 3131 | if (auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr)) { |
| 3132 | if (rhsConst.getValue() < 0) { |
| 3133 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3134 | os << " - " << -rhsConst.getValue(); |
| 3135 | if (enclosingTightness == BindingStrength::Strong) |
| 3136 | os << ')'; |
| 3137 | return; |
| 3138 | } |
| 3139 | } |
| 3140 | |
| 3141 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3142 | |
| 3143 | os << " + " ; |
| 3144 | printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
| 3145 | |
| 3146 | if (enclosingTightness == BindingStrength::Strong) |
| 3147 | os << ')'; |
| 3148 | } |
| 3149 | |
| 3150 | void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { |
| 3151 | printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak); |
| 3152 | isEq ? os << " == 0" : os << " >= 0" ; |
| 3153 | } |
| 3154 | |
| 3155 | void AsmPrinter::Impl::printAffineMap(AffineMap map) { |
| 3156 | // Dimension identifiers. |
| 3157 | os << '('; |
| 3158 | for (int i = 0; i < (int)map.getNumDims() - 1; ++i) |
| 3159 | os << 'd' << i << ", " ; |
| 3160 | if (map.getNumDims() >= 1) |
| 3161 | os << 'd' << map.getNumDims() - 1; |
| 3162 | os << ')'; |
| 3163 | |
| 3164 | // Symbolic identifiers. |
| 3165 | if (map.getNumSymbols() != 0) { |
| 3166 | os << '['; |
| 3167 | for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) |
| 3168 | os << 's' << i << ", " ; |
| 3169 | if (map.getNumSymbols() >= 1) |
| 3170 | os << 's' << map.getNumSymbols() - 1; |
| 3171 | os << ']'; |
| 3172 | } |
| 3173 | |
| 3174 | // Result affine expressions. |
| 3175 | os << " -> (" ; |
| 3176 | interleaveComma(c: map.getResults(), |
| 3177 | eachFn: [&](AffineExpr expr) { printAffineExpr(expr); }); |
| 3178 | os << ')'; |
| 3179 | } |
| 3180 | |
| 3181 | void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { |
| 3182 | // Dimension identifiers. |
| 3183 | os << '('; |
| 3184 | for (unsigned i = 1; i < set.getNumDims(); ++i) |
| 3185 | os << 'd' << i - 1 << ", " ; |
| 3186 | if (set.getNumDims() >= 1) |
| 3187 | os << 'd' << set.getNumDims() - 1; |
| 3188 | os << ')'; |
| 3189 | |
| 3190 | // Symbolic identifiers. |
| 3191 | if (set.getNumSymbols() != 0) { |
| 3192 | os << '['; |
| 3193 | for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) |
| 3194 | os << 's' << i << ", " ; |
| 3195 | if (set.getNumSymbols() >= 1) |
| 3196 | os << 's' << set.getNumSymbols() - 1; |
| 3197 | os << ']'; |
| 3198 | } |
| 3199 | |
| 3200 | // Print constraints. |
| 3201 | os << " : (" ; |
| 3202 | int numConstraints = set.getNumConstraints(); |
| 3203 | for (int i = 1; i < numConstraints; ++i) { |
| 3204 | printAffineConstraint(expr: set.getConstraint(idx: i - 1), isEq: set.isEq(idx: i - 1)); |
| 3205 | os << ", " ; |
| 3206 | } |
| 3207 | if (numConstraints >= 1) |
| 3208 | printAffineConstraint(expr: set.getConstraint(idx: numConstraints - 1), |
| 3209 | isEq: set.isEq(idx: numConstraints - 1)); |
| 3210 | os << ')'; |
| 3211 | } |
| 3212 | |
| 3213 | //===----------------------------------------------------------------------===// |
| 3214 | // OperationPrinter |
| 3215 | //===----------------------------------------------------------------------===// |
| 3216 | |
| 3217 | namespace { |
| 3218 | /// This class contains the logic for printing operations, regions, and blocks. |
| 3219 | class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { |
| 3220 | public: |
| 3221 | using Impl = AsmPrinter::Impl; |
| 3222 | using Impl::printType; |
| 3223 | |
| 3224 | explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) |
| 3225 | : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {} |
| 3226 | |
| 3227 | /// Print the given top-level operation. |
| 3228 | void printTopLevelOperation(Operation *op); |
| 3229 | |
| 3230 | /// Print the given operation, including its left-hand side and its right-hand |
| 3231 | /// side, with its indent and location. |
| 3232 | void printFullOpWithIndentAndLoc(Operation *op); |
| 3233 | /// Print the given operation, including its left-hand side and its right-hand |
| 3234 | /// side, but not including indentation and location. |
| 3235 | void printFullOp(Operation *op); |
| 3236 | /// Print the right-hand size of the given operation in the custom or generic |
| 3237 | /// form. |
| 3238 | void printCustomOrGenericOp(Operation *op) override; |
| 3239 | /// Print the right-hand side of the given operation in the generic form. |
| 3240 | void printGenericOp(Operation *op, bool printOpName) override; |
| 3241 | |
| 3242 | /// Print the name of the given block. |
| 3243 | void printBlockName(Block *block); |
| 3244 | |
| 3245 | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
| 3246 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
| 3247 | /// operation of the block is not printed. |
| 3248 | void print(Block *block, bool printBlockArgs = true, |
| 3249 | bool printBlockTerminator = true); |
| 3250 | |
| 3251 | /// Print the ID of the given value, optionally with its result number. |
| 3252 | void printValueID(Value value, bool printResultNo = true, |
| 3253 | raw_ostream *streamOverride = nullptr) const; |
| 3254 | |
| 3255 | /// Print the ID of the given operation. |
| 3256 | void printOperationID(Operation *op, |
| 3257 | raw_ostream *streamOverride = nullptr) const; |
| 3258 | |
| 3259 | //===--------------------------------------------------------------------===// |
| 3260 | // OpAsmPrinter methods |
| 3261 | //===--------------------------------------------------------------------===// |
| 3262 | |
| 3263 | /// Print a loc(...) specifier if printing debug info is enabled. Locations |
| 3264 | /// may be deferred with an alias. |
| 3265 | void printOptionalLocationSpecifier(Location loc) override { |
| 3266 | printTrailingLocation(loc); |
| 3267 | } |
| 3268 | |
| 3269 | /// Print a newline and indent the printer to the start of the current |
| 3270 | /// operation. |
| 3271 | void printNewline() override { |
| 3272 | os << newLine; |
| 3273 | os.indent(NumSpaces: currentIndent); |
| 3274 | } |
| 3275 | |
| 3276 | /// Increase indentation. |
| 3277 | void increaseIndent() override { currentIndent += indentWidth; } |
| 3278 | |
| 3279 | /// Decrease indentation. |
| 3280 | void decreaseIndent() override { currentIndent -= indentWidth; } |
| 3281 | |
| 3282 | /// Print a block argument in the usual format of: |
| 3283 | /// %ssaName : type {attr1=42} loc("here") |
| 3284 | /// where location printing is controlled by the standard internal option. |
| 3285 | /// You may pass omitType=true to not print a type, and pass an empty |
| 3286 | /// attribute list if you don't care for attributes. |
| 3287 | void printRegionArgument(BlockArgument arg, |
| 3288 | ArrayRef<NamedAttribute> argAttrs = {}, |
| 3289 | bool omitType = false) override; |
| 3290 | |
| 3291 | /// Print the ID for the given value. |
| 3292 | void printOperand(Value value) override { printValueID(value); } |
| 3293 | void printOperand(Value value, raw_ostream &os) override { |
| 3294 | printValueID(value, /*printResultNo=*/true, streamOverride: &os); |
| 3295 | } |
| 3296 | |
| 3297 | /// Print an optional attribute dictionary with a given set of elided values. |
| 3298 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| 3299 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 3300 | Impl::printOptionalAttrDict(attrs, elidedAttrs); |
| 3301 | } |
| 3302 | void printOptionalAttrDictWithKeyword( |
| 3303 | ArrayRef<NamedAttribute> attrs, |
| 3304 | ArrayRef<StringRef> elidedAttrs = {}) override { |
| 3305 | Impl::printOptionalAttrDict(attrs, elidedAttrs, |
| 3306 | /*withKeyword=*/true); |
| 3307 | } |
| 3308 | |
| 3309 | /// Print the given successor. |
| 3310 | void printSuccessor(Block *successor) override; |
| 3311 | |
| 3312 | /// Print an operation successor with the operands used for the block |
| 3313 | /// arguments. |
| 3314 | void printSuccessorAndUseList(Block *successor, |
| 3315 | ValueRange succOperands) override; |
| 3316 | |
| 3317 | /// Print the given region. |
| 3318 | void printRegion(Region ®ion, bool printEntryBlockArgs, |
| 3319 | bool printBlockTerminators, bool printEmptyBlock) override; |
| 3320 | |
| 3321 | /// Renumber the arguments for the specified region to the same names as the |
| 3322 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
| 3323 | /// operations. If any entry in namesToUse is null, the corresponding |
| 3324 | /// argument name is left alone. |
| 3325 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { |
| 3326 | state.getSSANameState().shadowRegionArgs(region, namesToUse); |
| 3327 | } |
| 3328 | |
| 3329 | /// Print the given affine map with the symbol and dimension operands printed |
| 3330 | /// inline with the map. |
| 3331 | void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
| 3332 | ValueRange operands) override; |
| 3333 | |
| 3334 | /// Print the given affine expression with the symbol and dimension operands |
| 3335 | /// printed inline with the expression. |
| 3336 | void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
| 3337 | ValueRange symOperands) override; |
| 3338 | |
| 3339 | /// Print users of this operation or id of this operation if it has no result. |
| 3340 | void printUsersComment(Operation *op); |
| 3341 | |
| 3342 | /// Print users of this block arg. |
| 3343 | void printUsersComment(BlockArgument arg); |
| 3344 | |
| 3345 | /// Print the users of a value. |
| 3346 | void printValueUsers(Value value); |
| 3347 | |
| 3348 | /// Print either the ids of the result values or the id of the operation if |
| 3349 | /// the operation has no results. |
| 3350 | void printUserIDs(Operation *user, bool prefixComma = false); |
| 3351 | |
| 3352 | private: |
| 3353 | /// This class represents a resource builder implementation for the MLIR |
| 3354 | /// textual assembly format. |
| 3355 | class ResourceBuilder : public AsmResourceBuilder { |
| 3356 | public: |
| 3357 | using ValueFn = function_ref<void(raw_ostream &)>; |
| 3358 | using PrintFn = function_ref<void(StringRef, ValueFn)>; |
| 3359 | |
| 3360 | ResourceBuilder(PrintFn printFn) : printFn(printFn) {} |
| 3361 | ~ResourceBuilder() override = default; |
| 3362 | |
| 3363 | void buildBool(StringRef key, bool data) final { |
| 3364 | printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false" ); }); |
| 3365 | } |
| 3366 | |
| 3367 | void buildString(StringRef key, StringRef data) final { |
| 3368 | printFn(key, [&](raw_ostream &os) { |
| 3369 | os << "\"" ; |
| 3370 | llvm::printEscapedString(Name: data, Out&: os); |
| 3371 | os << "\"" ; |
| 3372 | }); |
| 3373 | } |
| 3374 | |
| 3375 | void buildBlob(StringRef key, ArrayRef<char> data, |
| 3376 | uint32_t dataAlignment) final { |
| 3377 | printFn(key, [&](raw_ostream &os) { |
| 3378 | // Store the blob in a hex string containing the alignment and the data. |
| 3379 | llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); |
| 3380 | os << "\"0x" |
| 3381 | << llvm::toHex(Input: StringRef(reinterpret_cast<char *>(&dataAlignmentLE), |
| 3382 | sizeof(dataAlignment))) |
| 3383 | << llvm::toHex(Input: StringRef(data.data(), data.size())) << "\"" ; |
| 3384 | }); |
| 3385 | } |
| 3386 | |
| 3387 | private: |
| 3388 | PrintFn printFn; |
| 3389 | }; |
| 3390 | |
| 3391 | /// Print the metadata dictionary for the file, eliding it if it is empty. |
| 3392 | void printFileMetadataDictionary(Operation *op); |
| 3393 | |
| 3394 | /// Print the resource sections for the file metadata dictionary. |
| 3395 | /// `checkAddMetadataDict` is used to indicate that metadata is going to be |
| 3396 | /// added, and the file metadata dictionary should be started if it hasn't |
| 3397 | /// yet. |
| 3398 | void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict, |
| 3399 | Operation *op); |
| 3400 | |
| 3401 | // Contains the stack of default dialects to use when printing regions. |
| 3402 | // A new dialect is pushed to the stack before parsing regions nested under an |
| 3403 | // operation implementing `OpAsmOpInterface`, and popped when done. At the |
| 3404 | // top-level we start with "builtin" as the default, so that the top-level |
| 3405 | // `module` operation prints as-is. |
| 3406 | SmallVector<StringRef> defaultDialectStack{"builtin" }; |
| 3407 | |
| 3408 | /// The number of spaces used for indenting nested operations. |
| 3409 | const static unsigned indentWidth = 2; |
| 3410 | |
| 3411 | // This is the current indentation level for nested structures. |
| 3412 | unsigned currentIndent = 0; |
| 3413 | }; |
| 3414 | } // namespace |
| 3415 | |
| 3416 | void OperationPrinter::printTopLevelOperation(Operation *op) { |
| 3417 | // Output the aliases at the top level that can't be deferred. |
| 3418 | state.getAliasState().printNonDeferredAliases(p&: *this, newLine); |
| 3419 | |
| 3420 | // Print the module. |
| 3421 | printFullOpWithIndentAndLoc(op); |
| 3422 | os << newLine; |
| 3423 | |
| 3424 | // Output the aliases at the top level that can be deferred. |
| 3425 | state.getAliasState().printDeferredAliases(p&: *this, newLine); |
| 3426 | |
| 3427 | // Output any file level metadata. |
| 3428 | printFileMetadataDictionary(op); |
| 3429 | } |
| 3430 | |
| 3431 | void OperationPrinter::printFileMetadataDictionary(Operation *op) { |
| 3432 | bool sawMetadataEntry = false; |
| 3433 | auto checkAddMetadataDict = [&] { |
| 3434 | if (!std::exchange(obj&: sawMetadataEntry, new_val: true)) |
| 3435 | os << newLine << "{-#" << newLine; |
| 3436 | }; |
| 3437 | |
| 3438 | // Add the various types of metadata. |
| 3439 | printResourceFileMetadata(checkAddMetadataDict, op); |
| 3440 | |
| 3441 | // If the file dictionary exists, close it. |
| 3442 | if (sawMetadataEntry) |
| 3443 | os << newLine << "#-}" << newLine; |
| 3444 | } |
| 3445 | |
| 3446 | void OperationPrinter::printResourceFileMetadata( |
| 3447 | function_ref<void()> checkAddMetadataDict, Operation *op) { |
| 3448 | // Functor used to add data entries to the file metadata dictionary. |
| 3449 | bool hadResource = false; |
| 3450 | bool needResourceComma = false; |
| 3451 | bool needEntryComma = false; |
| 3452 | auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, |
| 3453 | auto &&...providerArgs) { |
| 3454 | bool hadEntry = false; |
| 3455 | auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { |
| 3456 | checkAddMetadataDict(); |
| 3457 | |
| 3458 | std::string resourceStr; |
| 3459 | auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; }; |
| 3460 | std::optional<uint64_t> charLimit = |
| 3461 | printerFlags.getLargeResourceStringLimit(); |
| 3462 | if (charLimit.has_value()) { |
| 3463 | // Don't compute resourceStr when charLimit is 0. |
| 3464 | if (charLimit.value() == 0) |
| 3465 | return; |
| 3466 | |
| 3467 | llvm::raw_string_ostream ss(resourceStr); |
| 3468 | valueFn(ss); |
| 3469 | |
| 3470 | // Only print entry if its string is small enough. |
| 3471 | if (resourceStr.size() > charLimit.value()) |
| 3472 | return; |
| 3473 | |
| 3474 | // Don't recompute resourceStr when valueFn is called below. |
| 3475 | valueFn = printResourceStr; |
| 3476 | } |
| 3477 | |
| 3478 | // Emit the top-level resource entry if we haven't yet. |
| 3479 | if (!std::exchange(obj&: hadResource, new_val: true)) { |
| 3480 | if (needResourceComma) |
| 3481 | os << "," << newLine; |
| 3482 | os << " " << dictName << "_resources: {" << newLine; |
| 3483 | } |
| 3484 | // Emit the parent resource entry if we haven't yet. |
| 3485 | if (!std::exchange(obj&: hadEntry, new_val: true)) { |
| 3486 | if (needEntryComma) |
| 3487 | os << "," << newLine; |
| 3488 | os << " " << name << ": {" << newLine; |
| 3489 | } else { |
| 3490 | os << "," << newLine; |
| 3491 | } |
| 3492 | os << " " ; |
| 3493 | ::printKeywordOrString(keyword: key, os); |
| 3494 | os << ": " ; |
| 3495 | // Call printResourceStr or original valueFn, depending on charLimit. |
| 3496 | valueFn(os); |
| 3497 | }; |
| 3498 | ResourceBuilder entryBuilder(printFn); |
| 3499 | provider.buildResources(op, providerArgs..., entryBuilder); |
| 3500 | |
| 3501 | needEntryComma |= hadEntry; |
| 3502 | if (hadEntry) |
| 3503 | os << newLine << " }" ; |
| 3504 | }; |
| 3505 | |
| 3506 | // Print the `dialect_resources` section if we have any dialects with |
| 3507 | // resources. |
| 3508 | for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { |
| 3509 | auto &dialectResources = state.getDialectResources(); |
| 3510 | StringRef name = interface.getDialect()->getNamespace(); |
| 3511 | auto it = dialectResources.find(Val: interface.getDialect()); |
| 3512 | if (it != dialectResources.end()) |
| 3513 | processProvider("dialect" , name, interface, it->second); |
| 3514 | else |
| 3515 | processProvider("dialect" , name, interface, |
| 3516 | SetVector<AsmDialectResourceHandle>()); |
| 3517 | } |
| 3518 | if (hadResource) |
| 3519 | os << newLine << " }" ; |
| 3520 | |
| 3521 | // Print the `external_resources` section if we have any external clients with |
| 3522 | // resources. |
| 3523 | needEntryComma = false; |
| 3524 | needResourceComma = hadResource; |
| 3525 | hadResource = false; |
| 3526 | for (const auto &printer : state.getResourcePrinters()) |
| 3527 | processProvider("external" , printer.getName(), printer); |
| 3528 | if (hadResource) |
| 3529 | os << newLine << " }" ; |
| 3530 | } |
| 3531 | |
| 3532 | /// Print a block argument in the usual format of: |
| 3533 | /// %ssaName : type {attr1=42} loc("here") |
| 3534 | /// where location printing is controlled by the standard internal option. |
| 3535 | /// You may pass omitType=true to not print a type, and pass an empty |
| 3536 | /// attribute list if you don't care for attributes. |
| 3537 | void OperationPrinter::printRegionArgument(BlockArgument arg, |
| 3538 | ArrayRef<NamedAttribute> argAttrs, |
| 3539 | bool omitType) { |
| 3540 | printOperand(value: arg); |
| 3541 | if (!omitType) { |
| 3542 | os << ": " ; |
| 3543 | printType(type: arg.getType()); |
| 3544 | } |
| 3545 | printOptionalAttrDict(attrs: argAttrs); |
| 3546 | // TODO: We should allow location aliases on block arguments. |
| 3547 | printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false); |
| 3548 | } |
| 3549 | |
| 3550 | void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { |
| 3551 | // Track the location of this operation. |
| 3552 | state.registerOperationLocation(op, line: newLine.curLine, col: currentIndent); |
| 3553 | |
| 3554 | os.indent(NumSpaces: currentIndent); |
| 3555 | printFullOp(op); |
| 3556 | printTrailingLocation(loc: op->getLoc()); |
| 3557 | if (printerFlags.shouldPrintValueUsers()) |
| 3558 | printUsersComment(op); |
| 3559 | } |
| 3560 | |
| 3561 | void OperationPrinter::printFullOp(Operation *op) { |
| 3562 | if (size_t numResults = op->getNumResults()) { |
| 3563 | auto printResultGroup = [&](size_t resultNo, size_t resultCount) { |
| 3564 | printValueID(value: op->getResult(idx: resultNo), /*printResultNo=*/false); |
| 3565 | if (resultCount > 1) |
| 3566 | os << ':' << resultCount; |
| 3567 | }; |
| 3568 | |
| 3569 | // Check to see if this operation has multiple result groups. |
| 3570 | ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op); |
| 3571 | if (!resultGroups.empty()) { |
| 3572 | // Interleave the groups excluding the last one, this one will be handled |
| 3573 | // separately. |
| 3574 | interleaveComma(c: llvm::seq<int>(Begin: 0, End: resultGroups.size() - 1), eachFn: [&](int i) { |
| 3575 | printResultGroup(resultGroups[i], |
| 3576 | resultGroups[i + 1] - resultGroups[i]); |
| 3577 | }); |
| 3578 | os << ", " ; |
| 3579 | printResultGroup(resultGroups.back(), numResults - resultGroups.back()); |
| 3580 | |
| 3581 | } else { |
| 3582 | printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); |
| 3583 | } |
| 3584 | |
| 3585 | os << " = " ; |
| 3586 | } |
| 3587 | |
| 3588 | printCustomOrGenericOp(op); |
| 3589 | } |
| 3590 | |
| 3591 | void OperationPrinter::(Operation *op) { |
| 3592 | unsigned numResults = op->getNumResults(); |
| 3593 | if (!numResults && op->getNumOperands()) { |
| 3594 | os << " // id: " ; |
| 3595 | printOperationID(op); |
| 3596 | } else if (numResults && op->use_empty()) { |
| 3597 | os << " // unused" ; |
| 3598 | } else if (numResults && !op->use_empty()) { |
| 3599 | // Print "user" if the operation has one result used to compute one other |
| 3600 | // result, or is used in one operation with no result. |
| 3601 | unsigned usedInNResults = 0; |
| 3602 | unsigned usedInNOperations = 0; |
| 3603 | SmallPtrSet<Operation *, 1> userSet; |
| 3604 | for (Operation *user : op->getUsers()) { |
| 3605 | if (userSet.insert(Ptr: user).second) { |
| 3606 | ++usedInNOperations; |
| 3607 | usedInNResults += user->getNumResults(); |
| 3608 | } |
| 3609 | } |
| 3610 | |
| 3611 | // We already know that users is not empty. |
| 3612 | bool exactlyOneUniqueUse = |
| 3613 | usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; |
| 3614 | os << " // " << (exactlyOneUniqueUse ? "user" : "users" ) << ": " ; |
| 3615 | bool shouldPrintBrackets = numResults > 1; |
| 3616 | auto printOpResult = [&](OpResult opResult) { |
| 3617 | if (shouldPrintBrackets) |
| 3618 | os << "(" ; |
| 3619 | printValueUsers(value: opResult); |
| 3620 | if (shouldPrintBrackets) |
| 3621 | os << ")" ; |
| 3622 | }; |
| 3623 | |
| 3624 | interleaveComma(c: op->getResults(), eachFn: printOpResult); |
| 3625 | } |
| 3626 | } |
| 3627 | |
| 3628 | void OperationPrinter::(BlockArgument arg) { |
| 3629 | os << "// " ; |
| 3630 | printValueID(value: arg); |
| 3631 | if (arg.use_empty()) { |
| 3632 | os << " is unused" ; |
| 3633 | } else { |
| 3634 | os << " is used by " ; |
| 3635 | printValueUsers(value: arg); |
| 3636 | } |
| 3637 | os << newLine; |
| 3638 | } |
| 3639 | |
| 3640 | void OperationPrinter::printValueUsers(Value value) { |
| 3641 | if (value.use_empty()) |
| 3642 | os << "unused" ; |
| 3643 | |
| 3644 | // One value might be used as the operand of an operation more than once. |
| 3645 | // Only print the operations results once in that case. |
| 3646 | SmallPtrSet<Operation *, 1> userSet; |
| 3647 | for (auto [index, user] : enumerate(First: value.getUsers())) { |
| 3648 | if (userSet.insert(Ptr: user).second) |
| 3649 | printUserIDs(user, prefixComma: index); |
| 3650 | } |
| 3651 | } |
| 3652 | |
| 3653 | void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { |
| 3654 | if (prefixComma) |
| 3655 | os << ", " ; |
| 3656 | |
| 3657 | if (!user->getNumResults()) { |
| 3658 | printOperationID(op: user); |
| 3659 | } else { |
| 3660 | interleaveComma(c: user->getResults(), |
| 3661 | eachFn: [this](Value result) { printValueID(value: result); }); |
| 3662 | } |
| 3663 | } |
| 3664 | |
| 3665 | void OperationPrinter::printCustomOrGenericOp(Operation *op) { |
| 3666 | // If requested, always print the generic form. |
| 3667 | if (!printerFlags.shouldPrintGenericOpForm()) { |
| 3668 | // Check to see if this is a known operation. If so, use the registered |
| 3669 | // custom printer hook. |
| 3670 | if (auto opInfo = op->getRegisteredInfo()) { |
| 3671 | opInfo->printAssembly(op, p&: *this, defaultDialect: defaultDialectStack.back()); |
| 3672 | return; |
| 3673 | } |
| 3674 | // Otherwise try to dispatch to the dialect, if available. |
| 3675 | if (Dialect *dialect = op->getDialect()) { |
| 3676 | if (auto opPrinter = dialect->getOperationPrinter(op)) { |
| 3677 | // Print the op name first. |
| 3678 | StringRef name = op->getName().getStringRef(); |
| 3679 | // Only drop the default dialect prefix when it cannot lead to |
| 3680 | // ambiguities. |
| 3681 | if (name.count(C: '.') == 1) |
| 3682 | name.consume_front(Prefix: (defaultDialectStack.back() + "." ).str()); |
| 3683 | os << name; |
| 3684 | |
| 3685 | // Print the rest of the op now. |
| 3686 | opPrinter(op, *this); |
| 3687 | return; |
| 3688 | } |
| 3689 | } |
| 3690 | } |
| 3691 | |
| 3692 | // Otherwise print with the generic assembly form. |
| 3693 | printGenericOp(op, /*printOpName=*/true); |
| 3694 | } |
| 3695 | |
| 3696 | void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { |
| 3697 | if (printOpName) |
| 3698 | printEscapedString(str: op->getName().getStringRef()); |
| 3699 | os << '('; |
| 3700 | interleaveComma(c: op->getOperands(), eachFn: [&](Value value) { printValueID(value); }); |
| 3701 | os << ')'; |
| 3702 | |
| 3703 | // For terminators, print the list of successors and their operands. |
| 3704 | if (op->getNumSuccessors() != 0) { |
| 3705 | os << '['; |
| 3706 | interleaveComma(c: op->getSuccessors(), |
| 3707 | eachFn: [&](Block *successor) { printBlockName(block: successor); }); |
| 3708 | os << ']'; |
| 3709 | } |
| 3710 | |
| 3711 | // Print the properties. |
| 3712 | if (Attribute prop = op->getPropertiesAsAttribute()) { |
| 3713 | os << " <" ; |
| 3714 | Impl::printAttribute(attr: prop); |
| 3715 | os << '>'; |
| 3716 | } |
| 3717 | |
| 3718 | // Print regions. |
| 3719 | if (op->getNumRegions() != 0) { |
| 3720 | os << " (" ; |
| 3721 | interleaveComma(c: op->getRegions(), eachFn: [&](Region ®ion) { |
| 3722 | printRegion(region, /*printEntryBlockArgs=*/true, |
| 3723 | /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); |
| 3724 | }); |
| 3725 | os << ')'; |
| 3726 | } |
| 3727 | |
| 3728 | printOptionalAttrDict(attrs: op->getPropertiesStorage() |
| 3729 | ? llvm::to_vector(op->getDiscardableAttrs()) |
| 3730 | : op->getAttrs()); |
| 3731 | |
| 3732 | // Print the type signature of the operation. |
| 3733 | os << " : " ; |
| 3734 | printFunctionalType(op); |
| 3735 | } |
| 3736 | |
| 3737 | void OperationPrinter::printBlockName(Block *block) { |
| 3738 | os << state.getSSANameState().getBlockInfo(block).name; |
| 3739 | } |
| 3740 | |
| 3741 | void OperationPrinter::print(Block *block, bool printBlockArgs, |
| 3742 | bool printBlockTerminator) { |
| 3743 | // Print the block label and argument list if requested. |
| 3744 | if (printBlockArgs) { |
| 3745 | os.indent(NumSpaces: currentIndent); |
| 3746 | printBlockName(block); |
| 3747 | |
| 3748 | // Print the argument list if non-empty. |
| 3749 | if (!block->args_empty()) { |
| 3750 | os << '('; |
| 3751 | interleaveComma(c: block->getArguments(), eachFn: [&](BlockArgument arg) { |
| 3752 | printValueID(value: arg); |
| 3753 | os << ": " ; |
| 3754 | printType(type: arg.getType()); |
| 3755 | // TODO: We should allow location aliases on block arguments. |
| 3756 | printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false); |
| 3757 | }); |
| 3758 | os << ')'; |
| 3759 | } |
| 3760 | os << ':'; |
| 3761 | |
| 3762 | // Print out some context information about the predecessors of this block. |
| 3763 | if (!block->getParent()) { |
| 3764 | os << " // block is not in a region!" ; |
| 3765 | } else if (block->hasNoPredecessors()) { |
| 3766 | if (!block->isEntryBlock()) |
| 3767 | os << " // no predecessors" ; |
| 3768 | } else if (auto *pred = block->getSinglePredecessor()) { |
| 3769 | os << " // pred: " ; |
| 3770 | printBlockName(block: pred); |
| 3771 | } else { |
| 3772 | // We want to print the predecessors in a stable order, not in |
| 3773 | // whatever order the use-list is in, so gather and sort them. |
| 3774 | SmallVector<BlockInfo, 4> predIDs; |
| 3775 | for (auto *pred : block->getPredecessors()) |
| 3776 | predIDs.push_back(Elt: state.getSSANameState().getBlockInfo(block: pred)); |
| 3777 | llvm::sort(C&: predIDs, Comp: [](BlockInfo lhs, BlockInfo rhs) { |
| 3778 | return lhs.ordering < rhs.ordering; |
| 3779 | }); |
| 3780 | |
| 3781 | os << " // " << predIDs.size() << " preds: " ; |
| 3782 | |
| 3783 | interleaveComma(c: predIDs, eachFn: [&](BlockInfo pred) { os << pred.name; }); |
| 3784 | } |
| 3785 | os << newLine; |
| 3786 | } |
| 3787 | |
| 3788 | currentIndent += indentWidth; |
| 3789 | |
| 3790 | if (printerFlags.shouldPrintValueUsers()) { |
| 3791 | for (BlockArgument arg : block->getArguments()) { |
| 3792 | os.indent(NumSpaces: currentIndent); |
| 3793 | printUsersComment(arg); |
| 3794 | } |
| 3795 | } |
| 3796 | |
| 3797 | bool hasTerminator = |
| 3798 | !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
| 3799 | auto range = llvm::make_range( |
| 3800 | x: block->begin(), |
| 3801 | y: std::prev(x: block->end(), |
| 3802 | n: (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
| 3803 | for (auto &op : range) { |
| 3804 | printFullOpWithIndentAndLoc(op: &op); |
| 3805 | os << newLine; |
| 3806 | } |
| 3807 | currentIndent -= indentWidth; |
| 3808 | } |
| 3809 | |
| 3810 | void OperationPrinter::printValueID(Value value, bool printResultNo, |
| 3811 | raw_ostream *streamOverride) const { |
| 3812 | state.getSSANameState().printValueID(value, printResultNo, |
| 3813 | stream&: streamOverride ? *streamOverride : os); |
| 3814 | } |
| 3815 | |
| 3816 | void OperationPrinter::printOperationID(Operation *op, |
| 3817 | raw_ostream *streamOverride) const { |
| 3818 | state.getSSANameState().printOperationID(op, stream&: streamOverride ? *streamOverride |
| 3819 | : os); |
| 3820 | } |
| 3821 | |
| 3822 | void OperationPrinter::printSuccessor(Block *successor) { |
| 3823 | printBlockName(block: successor); |
| 3824 | } |
| 3825 | |
| 3826 | void OperationPrinter::printSuccessorAndUseList(Block *successor, |
| 3827 | ValueRange succOperands) { |
| 3828 | printBlockName(block: successor); |
| 3829 | if (succOperands.empty()) |
| 3830 | return; |
| 3831 | |
| 3832 | os << '('; |
| 3833 | interleaveComma(c: succOperands, |
| 3834 | eachFn: [this](Value operand) { printValueID(value: operand); }); |
| 3835 | os << " : " ; |
| 3836 | interleaveComma(c: succOperands, |
| 3837 | eachFn: [this](Value operand) { printType(type: operand.getType()); }); |
| 3838 | os << ')'; |
| 3839 | } |
| 3840 | |
| 3841 | void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, |
| 3842 | bool printBlockTerminators, |
| 3843 | bool printEmptyBlock) { |
| 3844 | if (printerFlags.shouldSkipRegions()) { |
| 3845 | os << "{...}" ; |
| 3846 | return; |
| 3847 | } |
| 3848 | os << "{" << newLine; |
| 3849 | if (!region.empty()) { |
| 3850 | auto restoreDefaultDialect = |
| 3851 | llvm::make_scope_exit(F: [&]() { defaultDialectStack.pop_back(); }); |
| 3852 | if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) |
| 3853 | defaultDialectStack.push_back(Elt: iface.getDefaultDialect()); |
| 3854 | else |
| 3855 | defaultDialectStack.push_back(Elt: "" ); |
| 3856 | |
| 3857 | auto *entryBlock = ®ion.front(); |
| 3858 | // Force printing the block header if printEmptyBlock is set and the block |
| 3859 | // is empty or if printEntryBlockArgs is set and there are arguments to |
| 3860 | // print. |
| 3861 | bool = |
| 3862 | (printEmptyBlock && entryBlock->empty()) || |
| 3863 | (printEntryBlockArgs && entryBlock->getNumArguments() != 0); |
| 3864 | print(block: entryBlock, printBlockArgs: shouldAlwaysPrintBlockHeader, printBlockTerminator: printBlockTerminators); |
| 3865 | for (auto &b : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1)) |
| 3866 | print(block: &b); |
| 3867 | } |
| 3868 | os.indent(NumSpaces: currentIndent) << "}" ; |
| 3869 | } |
| 3870 | |
| 3871 | void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
| 3872 | ValueRange operands) { |
| 3873 | if (!mapAttr) { |
| 3874 | os << "<<NULL AFFINE MAP>>" ; |
| 3875 | return; |
| 3876 | } |
| 3877 | AffineMap map = mapAttr.getValue(); |
| 3878 | unsigned numDims = map.getNumDims(); |
| 3879 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
| 3880 | unsigned index = isSymbol ? numDims + pos : pos; |
| 3881 | assert(index < operands.size()); |
| 3882 | if (isSymbol) |
| 3883 | os << "symbol(" ; |
| 3884 | printValueID(value: operands[index]); |
| 3885 | if (isSymbol) |
| 3886 | os << ')'; |
| 3887 | }; |
| 3888 | |
| 3889 | interleaveComma(c: map.getResults(), eachFn: [&](AffineExpr expr) { |
| 3890 | printAffineExpr(expr, printValueName); |
| 3891 | }); |
| 3892 | } |
| 3893 | |
| 3894 | void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, |
| 3895 | ValueRange dimOperands, |
| 3896 | ValueRange symOperands) { |
| 3897 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
| 3898 | if (!isSymbol) |
| 3899 | return printValueID(value: dimOperands[pos]); |
| 3900 | os << "symbol(" ; |
| 3901 | printValueID(value: symOperands[pos]); |
| 3902 | os << ')'; |
| 3903 | }; |
| 3904 | printAffineExpr(expr, printValueName); |
| 3905 | } |
| 3906 | |
| 3907 | //===----------------------------------------------------------------------===// |
| 3908 | // print and dump methods |
| 3909 | //===----------------------------------------------------------------------===// |
| 3910 | |
| 3911 | void Attribute::print(raw_ostream &os, bool elideType) const { |
| 3912 | if (!*this) { |
| 3913 | os << "<<NULL ATTRIBUTE>>" ; |
| 3914 | return; |
| 3915 | } |
| 3916 | |
| 3917 | AsmState state(getContext()); |
| 3918 | print(os, state, elideType); |
| 3919 | } |
| 3920 | void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const { |
| 3921 | using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision; |
| 3922 | AsmPrinter::Impl(os, state.getImpl()) |
| 3923 | .printAttribute(attr: *this, typeElision: elideType ? AttrTypeElision::Must |
| 3924 | : AttrTypeElision::Never); |
| 3925 | } |
| 3926 | |
| 3927 | void Attribute::dump() const { |
| 3928 | print(os&: llvm::errs()); |
| 3929 | llvm::errs() << "\n" ; |
| 3930 | } |
| 3931 | |
| 3932 | void Attribute::printStripped(raw_ostream &os, AsmState &state) const { |
| 3933 | if (!*this) { |
| 3934 | os << "<<NULL ATTRIBUTE>>" ; |
| 3935 | return; |
| 3936 | } |
| 3937 | |
| 3938 | AsmPrinter::Impl subPrinter(os, state.getImpl()); |
| 3939 | if (succeeded(Result: subPrinter.printAlias(attr: *this))) |
| 3940 | return; |
| 3941 | |
| 3942 | auto &dialect = this->getDialect(); |
| 3943 | uint64_t posPrior = os.tell(); |
| 3944 | DialectAsmPrinter printer(subPrinter); |
| 3945 | dialect.printAttribute(*this, printer); |
| 3946 | if (posPrior != os.tell()) |
| 3947 | return; |
| 3948 | |
| 3949 | // Fallback to printing with prefix if the above failed to write anything |
| 3950 | // to the output stream. |
| 3951 | print(os, state); |
| 3952 | } |
| 3953 | void Attribute::printStripped(raw_ostream &os) const { |
| 3954 | if (!*this) { |
| 3955 | os << "<<NULL ATTRIBUTE>>" ; |
| 3956 | return; |
| 3957 | } |
| 3958 | |
| 3959 | AsmState state(getContext()); |
| 3960 | printStripped(os, state); |
| 3961 | } |
| 3962 | |
| 3963 | void Type::print(raw_ostream &os) const { |
| 3964 | if (!*this) { |
| 3965 | os << "<<NULL TYPE>>" ; |
| 3966 | return; |
| 3967 | } |
| 3968 | |
| 3969 | AsmState state(getContext()); |
| 3970 | print(os, state); |
| 3971 | } |
| 3972 | void Type::print(raw_ostream &os, AsmState &state) const { |
| 3973 | AsmPrinter::Impl(os, state.getImpl()).printType(type: *this); |
| 3974 | } |
| 3975 | |
| 3976 | void Type::dump() const { |
| 3977 | print(os&: llvm::errs()); |
| 3978 | llvm::errs() << "\n" ; |
| 3979 | } |
| 3980 | |
| 3981 | void AffineMap::dump() const { |
| 3982 | print(os&: llvm::errs()); |
| 3983 | llvm::errs() << "\n" ; |
| 3984 | } |
| 3985 | |
| 3986 | void IntegerSet::dump() const { |
| 3987 | print(os&: llvm::errs()); |
| 3988 | llvm::errs() << "\n" ; |
| 3989 | } |
| 3990 | |
| 3991 | void AffineExpr::print(raw_ostream &os) const { |
| 3992 | if (!expr) { |
| 3993 | os << "<<NULL AFFINE EXPR>>" ; |
| 3994 | return; |
| 3995 | } |
| 3996 | AsmState state(getContext()); |
| 3997 | AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(expr: *this); |
| 3998 | } |
| 3999 | |
| 4000 | void AffineExpr::dump() const { |
| 4001 | print(os&: llvm::errs()); |
| 4002 | llvm::errs() << "\n" ; |
| 4003 | } |
| 4004 | |
| 4005 | void AffineMap::print(raw_ostream &os) const { |
| 4006 | if (!map) { |
| 4007 | os << "<<NULL AFFINE MAP>>" ; |
| 4008 | return; |
| 4009 | } |
| 4010 | AsmState state(getContext()); |
| 4011 | AsmPrinter::Impl(os, state.getImpl()).printAffineMap(map: *this); |
| 4012 | } |
| 4013 | |
| 4014 | void IntegerSet::print(raw_ostream &os) const { |
| 4015 | AsmState state(getContext()); |
| 4016 | AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(set: *this); |
| 4017 | } |
| 4018 | |
| 4019 | void Value::print(raw_ostream &os) const { print(os, flags: OpPrintingFlags()); } |
| 4020 | void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const { |
| 4021 | if (!impl) { |
| 4022 | os << "<<NULL VALUE>>" ; |
| 4023 | return; |
| 4024 | } |
| 4025 | |
| 4026 | if (auto *op = getDefiningOp()) |
| 4027 | return op->print(os, flags); |
| 4028 | // TODO: Improve BlockArgument print'ing. |
| 4029 | BlockArgument arg = llvm::cast<BlockArgument>(Val: *this); |
| 4030 | os << "<block argument> of type '" << arg.getType() |
| 4031 | << "' at index: " << arg.getArgNumber(); |
| 4032 | } |
| 4033 | void Value::print(raw_ostream &os, AsmState &state) const { |
| 4034 | if (!impl) { |
| 4035 | os << "<<NULL VALUE>>" ; |
| 4036 | return; |
| 4037 | } |
| 4038 | |
| 4039 | if (auto *op = getDefiningOp()) |
| 4040 | return op->print(os, state); |
| 4041 | |
| 4042 | // TODO: Improve BlockArgument print'ing. |
| 4043 | BlockArgument arg = llvm::cast<BlockArgument>(Val: *this); |
| 4044 | os << "<block argument> of type '" << arg.getType() |
| 4045 | << "' at index: " << arg.getArgNumber(); |
| 4046 | } |
| 4047 | |
| 4048 | void Value::dump() const { |
| 4049 | print(os&: llvm::errs()); |
| 4050 | llvm::errs() << "\n" ; |
| 4051 | } |
| 4052 | |
| 4053 | void Value::printAsOperand(raw_ostream &os, AsmState &state) const { |
| 4054 | // TODO: This doesn't necessarily capture all potential cases. |
| 4055 | // Currently, region arguments can be shadowed when printing the main |
| 4056 | // operation. If the IR hasn't been printed, this will produce the old SSA |
| 4057 | // name and not the shadowed name. |
| 4058 | state.getImpl().getSSANameState().printValueID(value: *this, /*printResultNo=*/true, |
| 4059 | stream&: os); |
| 4060 | } |
| 4061 | |
| 4062 | static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
| 4063 | do { |
| 4064 | // If we are printing local scope, stop at the first operation that is |
| 4065 | // isolated from above. |
| 4066 | if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| 4067 | break; |
| 4068 | |
| 4069 | // Otherwise, traverse up to the next parent. |
| 4070 | Operation *parentOp = op->getParentOp(); |
| 4071 | if (!parentOp) |
| 4072 | break; |
| 4073 | op = parentOp; |
| 4074 | } while (true); |
| 4075 | return op; |
| 4076 | } |
| 4077 | |
| 4078 | void Value::printAsOperand(raw_ostream &os, |
| 4079 | const OpPrintingFlags &flags) const { |
| 4080 | Operation *op; |
| 4081 | if (auto result = llvm::dyn_cast<OpResult>(Val: *this)) { |
| 4082 | op = result.getOwner(); |
| 4083 | } else { |
| 4084 | op = llvm::cast<BlockArgument>(Val: *this).getOwner()->getParentOp(); |
| 4085 | if (!op) { |
| 4086 | os << "<<UNKNOWN SSA VALUE>>" ; |
| 4087 | return; |
| 4088 | } |
| 4089 | } |
| 4090 | op = findParent(op, shouldUseLocalScope: flags.shouldUseLocalScope()); |
| 4091 | AsmState state(op, flags); |
| 4092 | printAsOperand(os, state); |
| 4093 | } |
| 4094 | |
| 4095 | void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { |
| 4096 | // Find the operation to number from based upon the provided flags. |
| 4097 | Operation *op = findParent(op: this, shouldUseLocalScope: printerFlags.shouldUseLocalScope()); |
| 4098 | AsmState state(op, printerFlags); |
| 4099 | print(os, state); |
| 4100 | } |
| 4101 | void Operation::print(raw_ostream &os, AsmState &state) { |
| 4102 | OperationPrinter printer(os, state.getImpl()); |
| 4103 | if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { |
| 4104 | state.getImpl().initializeAliases(op: this); |
| 4105 | printer.printTopLevelOperation(op: this); |
| 4106 | } else { |
| 4107 | printer.printFullOpWithIndentAndLoc(op: this); |
| 4108 | } |
| 4109 | } |
| 4110 | |
| 4111 | void Operation::dump() { |
| 4112 | print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope()); |
| 4113 | llvm::errs() << "\n" ; |
| 4114 | } |
| 4115 | |
| 4116 | void Operation::dumpPretty() { |
| 4117 | print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope().assumeVerified()); |
| 4118 | llvm::errs() << "\n" ; |
| 4119 | } |
| 4120 | |
| 4121 | void Block::print(raw_ostream &os) { |
| 4122 | Operation *parentOp = getParentOp(); |
| 4123 | if (!parentOp) { |
| 4124 | os << "<<UNLINKED BLOCK>>\n" ; |
| 4125 | return; |
| 4126 | } |
| 4127 | // Get the top-level op. |
| 4128 | while (auto *nextOp = parentOp->getParentOp()) |
| 4129 | parentOp = nextOp; |
| 4130 | |
| 4131 | AsmState state(parentOp); |
| 4132 | print(os, state); |
| 4133 | } |
| 4134 | void Block::print(raw_ostream &os, AsmState &state) { |
| 4135 | OperationPrinter(os, state.getImpl()).print(block: this); |
| 4136 | } |
| 4137 | |
| 4138 | void Block::dump() { print(os&: llvm::errs()); } |
| 4139 | |
| 4140 | /// Print out the name of the block without printing its body. |
| 4141 | void Block::printAsOperand(raw_ostream &os, bool printType) { |
| 4142 | Operation *parentOp = getParentOp(); |
| 4143 | if (!parentOp) { |
| 4144 | os << "<<UNLINKED BLOCK>>\n" ; |
| 4145 | return; |
| 4146 | } |
| 4147 | AsmState state(parentOp); |
| 4148 | printAsOperand(os, state); |
| 4149 | } |
| 4150 | void Block::printAsOperand(raw_ostream &os, AsmState &state) { |
| 4151 | OperationPrinter printer(os, state.getImpl()); |
| 4152 | printer.printBlockName(block: this); |
| 4153 | } |
| 4154 | |
| 4155 | raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) { |
| 4156 | block.print(os); |
| 4157 | return os; |
| 4158 | } |
| 4159 | |
| 4160 | //===--------------------------------------------------------------------===// |
| 4161 | // Custom printers |
| 4162 | //===--------------------------------------------------------------------===// |
| 4163 | namespace mlir { |
| 4164 | |
| 4165 | void printDimensionList(OpAsmPrinter &printer, Operation *op, |
| 4166 | ArrayRef<int64_t> dimensions) { |
| 4167 | if (dimensions.empty()) |
| 4168 | printer << "[" ; |
| 4169 | printer.printDimensionList(shape: dimensions); |
| 4170 | if (dimensions.empty()) |
| 4171 | printer << "]" ; |
| 4172 | } |
| 4173 | |
| 4174 | ParseResult parseDimensionList(OpAsmParser &parser, |
| 4175 | DenseI64ArrayAttr &dimensions) { |
| 4176 | // Empty list case denoted by "[]". |
| 4177 | if (succeeded(Result: parser.parseOptionalLSquare())) { |
| 4178 | if (failed(Result: parser.parseRSquare())) { |
| 4179 | return parser.emitError(loc: parser.getCurrentLocation()) |
| 4180 | << "Failed parsing dimension list." ; |
| 4181 | } |
| 4182 | dimensions = |
| 4183 | DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>()); |
| 4184 | return success(); |
| 4185 | } |
| 4186 | |
| 4187 | // Non-empty list case. |
| 4188 | SmallVector<int64_t> shapeArr; |
| 4189 | if (failed(Result: parser.parseDimensionList(dimensions&: shapeArr, allowDynamic: true, withTrailingX: false))) { |
| 4190 | return parser.emitError(loc: parser.getCurrentLocation()) |
| 4191 | << "Failed parsing dimension list." ; |
| 4192 | } |
| 4193 | if (shapeArr.empty()) { |
| 4194 | return parser.emitError(loc: parser.getCurrentLocation()) |
| 4195 | << "Failed parsing dimension list. Did you mean an empty list? It " |
| 4196 | "must be denoted by \"[]\"." ; |
| 4197 | } |
| 4198 | dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr); |
| 4199 | return success(); |
| 4200 | } |
| 4201 | |
| 4202 | } // namespace mlir |
| 4203 | |