1 | //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the MLIR AsmPrinter class, which is used to implement |
10 | // the various print() methods on the core IR objects. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/IR/AffineExpr.h" |
15 | #include "mlir/IR/AffineMap.h" |
16 | #include "mlir/IR/AsmState.h" |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/BuiltinDialect.h" |
21 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/Dialect.h" |
24 | #include "mlir/IR/DialectImplementation.h" |
25 | #include "mlir/IR/DialectResourceBlobManager.h" |
26 | #include "mlir/IR/IntegerSet.h" |
27 | #include "mlir/IR/MLIRContext.h" |
28 | #include "mlir/IR/OpImplementation.h" |
29 | #include "mlir/IR/Operation.h" |
30 | #include "mlir/IR/Verifier.h" |
31 | #include "llvm/ADT/APFloat.h" |
32 | #include "llvm/ADT/ArrayRef.h" |
33 | #include "llvm/ADT/DenseMap.h" |
34 | #include "llvm/ADT/MapVector.h" |
35 | #include "llvm/ADT/STLExtras.h" |
36 | #include "llvm/ADT/ScopeExit.h" |
37 | #include "llvm/ADT/ScopedHashTable.h" |
38 | #include "llvm/ADT/SetVector.h" |
39 | #include "llvm/ADT/SmallString.h" |
40 | #include "llvm/ADT/StringExtras.h" |
41 | #include "llvm/ADT/StringSet.h" |
42 | #include "llvm/ADT/TypeSwitch.h" |
43 | #include "llvm/Support/CommandLine.h" |
44 | #include "llvm/Support/Debug.h" |
45 | #include "llvm/Support/Endian.h" |
46 | #include "llvm/Support/ManagedStatic.h" |
47 | #include "llvm/Support/Regex.h" |
48 | #include "llvm/Support/SaveAndRestore.h" |
49 | #include "llvm/Support/Threading.h" |
50 | #include "llvm/Support/raw_ostream.h" |
51 | #include <type_traits> |
52 | |
53 | #include <optional> |
54 | #include <tuple> |
55 | |
56 | using namespace mlir; |
57 | using namespace mlir::detail; |
58 | |
59 | #define DEBUG_TYPE "mlir-asm-printer" |
60 | |
61 | void OperationName::print(raw_ostream &os) const { os << getStringRef(); } |
62 | |
63 | void OperationName::dump() const { print(os&: llvm::errs()); } |
64 | |
65 | //===--------------------------------------------------------------------===// |
66 | // AsmParser |
67 | //===--------------------------------------------------------------------===// |
68 | |
69 | AsmParser::~AsmParser() = default; |
70 | DialectAsmParser::~DialectAsmParser() = default; |
71 | OpAsmParser::~OpAsmParser() = default; |
72 | |
73 | MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } |
74 | |
75 | /// Parse a type list. |
76 | /// This is out-of-line to work-around |
77 | /// https://github.com/llvm/llvm-project/issues/62918 |
78 | ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) { |
79 | return parseCommaSeparatedList( |
80 | parseElementFn: [&]() { return parseType(result&: result.emplace_back()); }); |
81 | } |
82 | |
83 | //===----------------------------------------------------------------------===// |
84 | // DialectAsmPrinter |
85 | //===----------------------------------------------------------------------===// |
86 | |
/// Out-of-line defaulted destructor. NOTE(review): presumably defined here to
/// anchor the class's vtable in this translation unit — confirm.
DialectAsmPrinter::~DialectAsmPrinter() = default;
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // OpAsmPrinter |
91 | //===----------------------------------------------------------------------===// |
92 | |
/// Out-of-line defaulted destructor. NOTE(review): presumably defined here to
/// anchor the class's vtable in this translation unit — confirm.
OpAsmPrinter::~OpAsmPrinter() = default;
94 | |
95 | void OpAsmPrinter::printFunctionalType(Operation *op) { |
96 | auto &os = getStream(); |
97 | os << '('; |
98 | llvm::interleaveComma(c: op->getOperands(), os, each_fn: [&](Value operand) { |
99 | // Print the types of null values as <<NULL TYPE>>. |
100 | *this << (operand ? operand.getType() : Type()); |
101 | }); |
102 | os << ") -> " ; |
103 | |
104 | // Print the result list. We don't parenthesize single result types unless |
105 | // it is a function (avoiding a grammar ambiguity). |
106 | bool wrapped = op->getNumResults() != 1; |
107 | if (!wrapped && op->getResult(idx: 0).getType() && |
108 | llvm::isa<FunctionType>(Val: op->getResult(idx: 0).getType())) |
109 | wrapped = true; |
110 | |
111 | if (wrapped) |
112 | os << '('; |
113 | |
114 | llvm::interleaveComma(c: op->getResults(), os, each_fn: [&](const OpResult &result) { |
115 | // Print the types of null values as <<NULL TYPE>>. |
116 | *this << (result ? result.getType() : Type()); |
117 | }); |
118 | |
119 | if (wrapped) |
120 | os << ')'; |
121 | } |
122 | |
123 | //===----------------------------------------------------------------------===// |
124 | // Operation OpAsm interface. |
125 | //===----------------------------------------------------------------------===// |
126 | |
127 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
128 | #include "mlir/IR/OpAsmAttrInterface.cpp.inc" |
129 | #include "mlir/IR/OpAsmOpInterface.cpp.inc" |
130 | #include "mlir/IR/OpAsmTypeInterface.cpp.inc" |
131 | |
132 | LogicalResult |
133 | OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { |
134 | return entry.emitError() << "unknown 'resource' key '" << entry.getKey() |
135 | << "' for dialect '" << getDialect()->getNamespace() |
136 | << "'" ; |
137 | } |
138 | |
139 | //===----------------------------------------------------------------------===// |
140 | // OpPrintingFlags |
141 | //===----------------------------------------------------------------------===// |
142 | |
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
/// for global command line options.
///
/// The struct is only constructed on demand (via the `clOptions` managed
/// static below); `OpPrintingFlags` reads these values only if it has been
/// constructed.
struct AsmPrinterOptions {
  // Threshold (element count) above which DenseElementsAttrs are printed as a
  // hex string; -1 disables hex printing. Only applied when given explicitly
  // on the command line.
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  // Threshold (element count) above which ElementsAttrs are elided as "...".
  // Only applied when given explicitly on the command line.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  // Threshold (characters) above which resource string values are elided.
  // Only applied when given explicitly on the command line.
  llvm::cl::opt<unsigned> elideResourceStringsIfLarger{
      "mlir-elide-resource-strings-if-larger",
      llvm::cl::desc(
          "Elide printing value of resources if string is too long in chars.")};

  // Print location/debug information.
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  // Print debug information in the "pretty" form.
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  // Use the generic op output form in the operation printer even if the custom
  // form is defined.
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  // Skip verification of operations before using their custom printers.
  llvm::cl::opt<bool> assumeVerifiedOpt{
      "mlir-print-assume-verified", llvm::cl::init(false),
      llvm::cl::desc("Skip op verification when using custom printers"),
      llvm::cl::Hidden};

  // Print with local scope: attribute/type/location aliases are not used.
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print with local scope and inline information (eliding "
                     "aliases for attributes, types, and locations)")};

  // Do not print the contents of regions.
  llvm::cl::opt<bool> skipRegionsOpt{
      "mlir-print-skip-regions", llvm::cl::init(false),
      llvm::cl::desc("Skip regions when printing ops.")};

  // Emit the users of results and block arguments as comments.
  llvm::cl::opt<bool> printValueUsers{
      "mlir-print-value-users", llvm::cl::init(false),
      llvm::cl::desc(
          "Print users of operation results and block arguments as a comment")};

  // Number SSA values uniquely across all regions instead of per region.
  llvm::cl::opt<bool> printUniqueSSAIDs{
      "mlir-print-unique-ssa-ids", llvm::cl::init(false),
      llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
                     "and naming conflicts across all regions")};

  // Derive SSA value names from NameLoc locations.
  llvm::cl::opt<bool> useNameLocAsPrefix{
      "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
      llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
};
} // namespace
207 | |
208 | static llvm::ManagedStatic<AsmPrinterOptions> clOptions; |
209 | |
210 | /// Register a set of useful command-line options that can be used to configure |
211 | /// various flags within the AsmPrinter. |
212 | void mlir::registerAsmPrinterCLOptions() { |
213 | // Make sure that the options struct has been initialized. |
214 | *clOptions; |
215 | } |
216 | |
217 | /// Initialize the printing flags with default supplied by the cl::opts above. |
218 | OpPrintingFlags::OpPrintingFlags() |
219 | : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), |
220 | printGenericOpFormFlag(false), skipRegionsFlag(false), |
221 | assumeVerifiedFlag(false), printLocalScope(false), |
222 | printValueUsersFlag(false), printUniqueSSAIDsFlag(false), |
223 | useNameLocAsPrefix(false) { |
224 | // Initialize based upon command line options, if they are available. |
225 | if (!clOptions.isConstructed()) |
226 | return; |
227 | if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) |
228 | elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; |
229 | if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) |
230 | elementsAttrHexElementLimit = |
231 | clOptions->printElementsAttrWithHexIfLarger.getValue(); |
232 | if (clOptions->elideResourceStringsIfLarger.getNumOccurrences()) |
233 | resourceStringCharLimit = clOptions->elideResourceStringsIfLarger; |
234 | printDebugInfoFlag = clOptions->printDebugInfoOpt; |
235 | printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; |
236 | printGenericOpFormFlag = clOptions->printGenericOpFormOpt; |
237 | assumeVerifiedFlag = clOptions->assumeVerifiedOpt; |
238 | printLocalScope = clOptions->printLocalScopeOpt; |
239 | skipRegionsFlag = clOptions->skipRegionsOpt; |
240 | printValueUsersFlag = clOptions->printValueUsers; |
241 | printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; |
242 | useNameLocAsPrefix = clOptions->useNameLocAsPrefix; |
243 | } |
244 | |
245 | /// Enable the elision of large elements attributes, by printing a '...' |
246 | /// instead of the element data, when the number of elements is greater than |
247 | /// `largeElementLimit`. Note: The IR generated with this option is not |
248 | /// parsable. |
249 | OpPrintingFlags & |
250 | OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { |
251 | elementsAttrElementLimit = largeElementLimit; |
252 | return *this; |
253 | } |
254 | |
255 | OpPrintingFlags & |
256 | OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) { |
257 | elementsAttrHexElementLimit = largeElementLimit; |
258 | return *this; |
259 | } |
260 | |
261 | OpPrintingFlags & |
262 | OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) { |
263 | resourceStringCharLimit = largeResourceLimit; |
264 | return *this; |
265 | } |
266 | |
267 | /// Enable printing of debug information. If 'prettyForm' is set to true, |
268 | /// debug information is printed in a more readable 'pretty' form. |
269 | OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable, |
270 | bool prettyForm) { |
271 | printDebugInfoFlag = enable; |
272 | printDebugInfoPrettyFormFlag = prettyForm; |
273 | return *this; |
274 | } |
275 | |
276 | /// Always print operations in the generic form. |
277 | OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { |
278 | printGenericOpFormFlag = enable; |
279 | return *this; |
280 | } |
281 | |
282 | /// Always skip Regions. |
283 | OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { |
284 | skipRegionsFlag = skip; |
285 | return *this; |
286 | } |
287 | |
288 | /// Do not verify the operation when using custom operation printers. |
289 | OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) { |
290 | assumeVerifiedFlag = enable; |
291 | return *this; |
292 | } |
293 | |
294 | /// Use local scope when printing the operation. This allows for using the |
295 | /// printer in a more localized and thread-safe setting, but may not necessarily |
296 | /// be identical of what the IR will look like when dumping the full module. |
297 | OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) { |
298 | printLocalScope = enable; |
299 | return *this; |
300 | } |
301 | |
302 | /// Print users of values as comments. |
303 | OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) { |
304 | printValueUsersFlag = enable; |
305 | return *this; |
306 | } |
307 | |
308 | /// Print unique SSA ID numbers for values, block arguments and naming conflicts |
309 | /// across all regions |
310 | OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) { |
311 | printUniqueSSAIDsFlag = enable; |
312 | return *this; |
313 | } |
314 | |
315 | /// Return if the given ElementsAttr should be elided. |
316 | bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { |
317 | return elementsAttrElementLimit && |
318 | *elementsAttrElementLimit < int64_t(attr.getNumElements()) && |
319 | !llvm::isa<SplatElementsAttr>(attr); |
320 | } |
321 | |
322 | /// Return if the given ElementsAttr should be printed as hex string. |
323 | bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const { |
324 | // -1 is used to disable hex printing. |
325 | return (elementsAttrHexElementLimit != -1) && |
326 | (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) && |
327 | !llvm::isa<SplatElementsAttr>(attr); |
328 | } |
329 | |
330 | OpPrintingFlags &OpPrintingFlags::printNameLocAsPrefix(bool enable) { |
331 | useNameLocAsPrefix = enable; |
332 | return *this; |
333 | } |
334 | |
335 | /// Return the size limit for printing large ElementsAttr. |
336 | std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { |
337 | return elementsAttrElementLimit; |
338 | } |
339 | |
340 | /// Return the size limit for printing large ElementsAttr as hex string. |
341 | int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const { |
342 | return elementsAttrHexElementLimit; |
343 | } |
344 | |
345 | /// Return the size limit for printing large ElementsAttr. |
346 | std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const { |
347 | return resourceStringCharLimit; |
348 | } |
349 | |
350 | /// Return if debug information should be printed. |
351 | bool OpPrintingFlags::shouldPrintDebugInfo() const { |
352 | return printDebugInfoFlag; |
353 | } |
354 | |
355 | /// Return if debug information should be printed in the pretty form. |
356 | bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { |
357 | return printDebugInfoPrettyFormFlag; |
358 | } |
359 | |
360 | /// Return if operations should be printed in the generic form. |
361 | bool OpPrintingFlags::shouldPrintGenericOpForm() const { |
362 | return printGenericOpFormFlag; |
363 | } |
364 | |
365 | /// Return if Region should be skipped. |
366 | bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; } |
367 | |
368 | /// Return if operation verification should be skipped. |
369 | bool OpPrintingFlags::shouldAssumeVerified() const { |
370 | return assumeVerifiedFlag; |
371 | } |
372 | |
373 | /// Return if the printer should use local scope when dumping the IR. |
374 | bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } |
375 | |
376 | /// Return if the printer should print users of values. |
377 | bool OpPrintingFlags::shouldPrintValueUsers() const { |
378 | return printValueUsersFlag; |
379 | } |
380 | |
381 | /// Return if the printer should use unique IDs. |
382 | bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { |
383 | return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); |
384 | } |
385 | |
386 | /// Return if the printer should use NameLocs as prefixes when printing SSA IDs. |
387 | bool OpPrintingFlags::shouldUseNameLocAsPrefix() const { |
388 | return useNameLocAsPrefix; |
389 | } |
390 | |
391 | //===----------------------------------------------------------------------===// |
392 | // NewLineCounter |
393 | //===----------------------------------------------------------------------===// |
394 | |
395 | namespace { |
396 | /// This class is a simple formatter that emits a new line when inputted into a |
397 | /// stream, that enables counting the number of newlines emitted. This class |
398 | /// should be used whenever emitting newlines in the printer. |
399 | struct NewLineCounter { |
400 | unsigned curLine = 1; |
401 | }; |
402 | |
403 | static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { |
404 | ++newLine.curLine; |
405 | return os << '\n'; |
406 | } |
407 | } // namespace |
408 | |
409 | //===----------------------------------------------------------------------===// |
410 | // AsmPrinter::Impl |
411 | //===----------------------------------------------------------------------===// |
412 | |
namespace mlir {
/// Implementation class behind the AsmPrinter: holds the output stream, the
/// shared printer state, the printing flags and a newline counter, and
/// declares the printing routines for attributes, types, locations, affine
/// constructs and resources.
class AsmPrinter::Impl {
public:
  Impl(raw_ostream &os, AsmStateImpl &state);
  /// Create a printer writing to the same stream and using the same state as
  /// `other`, by delegating to the primary constructor.
  explicit Impl(Impl &other) : Impl(other.os, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print the elements of `c` to the stream separated by ", ", invoking
  /// `eachFn` on each element.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided.
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute or an alias.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);
  /// Print the given attribute without considering an alias.
  void printAttributeImpl(Attribute attr,
                          AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Attribute attr);

  /// Print the given type or an alias.
  void printType(Type type);
  /// Print the given type.
  void printTypeImpl(Type type);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print a reference to the given resource that is owned by the given
  /// dialect.
  void printResourceHandle(const AsmDialectResourceHandle &resource);

  /// Affine printing entry points. `printValueName`, when provided, is a
  /// callback taking (unsigned, bool) — presumably (position, isSymbol) for
  /// naming dims/symbols; TODO confirm against the definition.
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

  /// NOTE(review): presumably guard against infinite recursion when printing
  /// self-referential (cyclic) attributes/types; a failure result likely means
  /// `opaquePointer` is already being printed. Confirm against definitions.
  LogicalResult pushCyclicPrinting(const void *opaquePointer);

  void popCyclicPrinting();

  /// Print a shape as a dimension list. NOTE(review): exact format (e.g.
  /// dynamic-dim spelling, separators) is defined elsewhere — confirm.
  void printDimensionList(ArrayRef<int64_t> shape);

protected:
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  void printTrailingLocation(Location loc, bool allowAlias = true);
  void printLocationInternal(LocationAttr loc, bool pretty = false,
                             bool isTopLevel = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  /// Print a dense array attribute.
  void printDenseArrayAttr(DenseArrayAttr attr);

  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// Print an escaped string, wrapped with "".
  void printEscapedString(StringRef str);

  /// Print a hex string, wrapped with "".
  void printHexString(StringRef str);
  void printHexString(ArrayRef<char> data);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// An underlying assembly printer state.
  AsmStateImpl &state;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // namespace mlir
538 | |
539 | //===----------------------------------------------------------------------===// |
540 | // AliasInitializer |
541 | //===----------------------------------------------------------------------===// |
542 | |
543 | namespace { |
544 | /// This class represents a specific instance of a symbol Alias. |
545 | class SymbolAlias { |
546 | public: |
547 | SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, |
548 | bool isDeferrable) |
549 | : name(name), suffixIndex(suffixIndex), isType(isType), |
550 | isDeferrable(isDeferrable) {} |
551 | |
552 | /// Print this alias to the given stream. |
553 | void print(raw_ostream &os) const { |
554 | os << (isType ? "!" : "#" ) << name; |
555 | if (suffixIndex) { |
556 | if (isdigit(name.back())) |
557 | os << '_'; |
558 | os << suffixIndex; |
559 | } |
560 | } |
561 | |
562 | /// Returns true if this is a type alias. |
563 | bool isTypeAlias() const { return isType; } |
564 | |
565 | /// Returns true if this alias supports deferred resolution when parsing. |
566 | bool canBeDeferred() const { return isDeferrable; } |
567 | |
568 | private: |
569 | /// The main name of the alias. |
570 | StringRef name; |
571 | /// The suffix index of the alias. |
572 | uint32_t suffixIndex : 30; |
573 | /// A flag indicating whether this alias is for a type. |
574 | bool isType : 1; |
575 | /// A flag indicating whether this alias may be deferred or not. |
576 | bool isDeferrable : 1; |
577 | |
578 | public: |
579 | /// Used to avoid printing incomplete aliases for recursive types. |
580 | bool isPrinted = false; |
581 | }; |
582 | |
/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// `interfaces` is the set of dialect asm interfaces in the context;
  /// `aliasAllocator` is the allocator used for alias names.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Compute the aliases for `op` under `printerFlags` and store them in
  /// `attrTypeToAlias`. NOTE(review): definition not visible in this chunk —
  /// confirm exact semantics.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  /// `elideType` indicates if the type of the attribute should be skipped when
  /// looking for nested aliases. Returns the maximum alias depth of the
  /// attribute, and the alias index of this attribute.
  std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
                                  bool elideType = false) {
    return visitImpl(attr, aliases, canBeDeferred, elideType);
  }

  /// Visit the given type to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this type can resolve the alias
  /// after parsing has completed. Returns the maximum alias depth of the type,
  /// and the alias index of this type.
  std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
    return visitImpl(type, aliases, canBeDeferred);
  }

private:
  /// Alias information gathered while visiting, before final alias names are
  /// assigned.
  struct InProgressAliasInfo {
    /// No alias: depth 0.
    InProgressAliasInfo()
        : aliasDepth(0), isType(false), canBeDeferred(false) {}
    /// An alias named `alias`, starting at depth 1.
    InProgressAliasInfo(StringRef alias)
        : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {}

    bool operator<(const InProgressAliasInfo &rhs) const {
      // Order first by depth, then by attr/type kind, and then by name.
      if (aliasDepth != rhs.aliasDepth)
        return aliasDepth < rhs.aliasDepth;
      // At equal depth, type aliases sort before attribute aliases.
      if (isType != rhs.isType)
        return isType;
      return alias < rhs.alias;
    }

    /// The alias for the attribute or type, or std::nullopt if the value has no
    /// alias.
    std::optional<StringRef> alias;
    /// The alias depth of this attribute or type, i.e. an indication of the
    /// relative ordering of when to print this alias.
    unsigned aliasDepth : 30;
    /// If this alias represents a type or an attribute.
    bool isType : 1;
    /// If this alias can be deferred or not.
    bool canBeDeferred : 1;
    /// Indices for child aliases.
    SmallVector<size_t> childIndices;
  };

  /// Visit the given attribute or type to see if it has an alias.
  /// `canBeDeferred` is set to true if the originator of this value can resolve
  /// the alias after parsing has completed (e.g. in the case of operation
  /// locations). Returns the maximum alias depth of the value, and its alias
  /// index.
  template <typename T, typename... PrintArgs>
  std::pair<size_t, size_t>
  visitImpl(T value,
            llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
            bool canBeDeferred, PrintArgs &&...printArgs);

  /// Mark the given alias as non-deferrable.
  void markAliasNonDeferrable(size_t aliasIndex);

  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  template <typename T>
  void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);

  /// Uniques the given alias name within the printer by generating name index
  /// used as alias name suffix.
  static unsigned
  uniqueAliasNameIndex(StringRef alias, llvm::StringMap<unsigned> &nameCounts,
                       llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases);

  /// Given a collection of aliases and symbols, initialize a mapping from a
  /// symbol to a given alias.
  static void initializeAliases(
      llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of built aliases.
  llvm::MapVector<const void *, InProgressAliasInfo> aliases;

  /// Storage and stream used when generating an alias. `aliasBuffer` is
  /// declared before `aliasOS`, so it is constructed before the stream that
  /// writes into it.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
690 | |
691 | /// This class implements a dummy OpAsmPrinter that doesn't print any output, |
692 | /// and merely collects the attributes and types that *would* be printed in a |
693 | /// normal print invocation so that we can generate proper aliases. This allows |
694 | /// for us to generate aliases only for the attributes and types that would be |
695 | /// in the output, and trims down unnecessary output. |
696 | class DummyAliasOperationPrinter : private OpAsmPrinter { |
697 | public: |
698 | explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, |
699 | AliasInitializer &initializer) |
700 | : printerFlags(printerFlags), initializer(initializer) {} |
701 | |
702 | /// Prints the entire operation with the custom assembly form, if available, |
703 | /// or the generic assembly form, otherwise. |
704 | void printCustomOrGenericOp(Operation *op) override { |
705 | // Visit the operation location. |
706 | if (printerFlags.shouldPrintDebugInfo()) |
707 | initializer.visit(attr: op->getLoc(), /*canBeDeferred=*/true); |
708 | |
709 | // If requested, always print the generic form. |
710 | if (!printerFlags.shouldPrintGenericOpForm()) { |
711 | op->getName().printAssembly(op, p&: *this, /*defaultDialect=*/"" ); |
712 | return; |
713 | } |
714 | |
715 | // Otherwise print with the generic assembly form. |
716 | printGenericOp(op); |
717 | } |
718 | |
719 | private: |
720 | /// Print the given operation in the generic form. |
721 | void printGenericOp(Operation *op, bool printOpName = true) override { |
722 | // Consider nested operations for aliases. |
723 | if (!printerFlags.shouldSkipRegions()) { |
724 | for (Region ®ion : op->getRegions()) |
725 | printRegion(region, /*printEntryBlockArgs=*/true, |
726 | /*printBlockTerminators=*/true); |
727 | } |
728 | |
729 | // Visit all the types used in the operation. |
730 | for (Type type : op->getOperandTypes()) |
731 | printType(type); |
732 | for (Type type : op->getResultTypes()) |
733 | printType(type); |
734 | |
735 | // Consider the attributes of the operation for aliases. |
736 | for (const NamedAttribute &attr : op->getAttrs()) |
737 | printAttribute(attr: attr.getValue()); |
738 | } |
739 | |
740 | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
741 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
742 | /// operation of the block is not printed. |
743 | void print(Block *block, bool printBlockArgs = true, |
744 | bool printBlockTerminator = true) { |
745 | // Consider the types of the block arguments for aliases if 'printBlockArgs' |
746 | // is set to true. |
747 | if (printBlockArgs) { |
748 | for (BlockArgument arg : block->getArguments()) { |
749 | printType(type: arg.getType()); |
750 | |
751 | // Visit the argument location. |
752 | if (printerFlags.shouldPrintDebugInfo()) |
753 | // TODO: Allow deferring argument locations. |
754 | initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false); |
755 | } |
756 | } |
757 | |
758 | // Consider the operations within this block, ignoring the terminator if |
759 | // requested. |
760 | bool hasTerminator = |
761 | !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
762 | auto range = llvm::make_range( |
763 | x: block->begin(), |
764 | y: std::prev(x: block->end(), |
765 | n: (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
766 | for (Operation &op : range) |
767 | printCustomOrGenericOp(op: &op); |
768 | } |
769 | |
770 | /// Print the given region. |
771 | void printRegion(Region ®ion, bool printEntryBlockArgs, |
772 | bool printBlockTerminators, |
773 | bool printEmptyBlock = false) override { |
774 | if (region.empty()) |
775 | return; |
776 | if (printerFlags.shouldSkipRegions()) { |
777 | os << "{...}" ; |
778 | return; |
779 | } |
780 | |
781 | auto *entryBlock = ®ion.front(); |
782 | print(block: entryBlock, printBlockArgs: printEntryBlockArgs, printBlockTerminator: printBlockTerminators); |
783 | for (Block &b : llvm::drop_begin(RangeOrContainer&: region, N: 1)) |
784 | print(block: &b); |
785 | } |
786 | |
787 | void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, |
788 | bool omitType) override { |
789 | printType(type: arg.getType()); |
790 | // Visit the argument location. |
791 | if (printerFlags.shouldPrintDebugInfo()) |
792 | // TODO: Allow deferring argument locations. |
793 | initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false); |
794 | } |
795 | |
796 | /// Consider the given type to be printed for an alias. |
797 | void printType(Type type) override { initializer.visit(type); } |
798 | |
799 | /// Consider the given attribute to be printed for an alias. |
800 | void printAttribute(Attribute attr) override { initializer.visit(attr); } |
801 | void printAttributeWithoutType(Attribute attr) override { |
802 | printAttribute(attr); |
803 | } |
804 | LogicalResult printAlias(Attribute attr) override { |
805 | initializer.visit(attr); |
806 | return success(); |
807 | } |
808 | LogicalResult printAlias(Type type) override { |
809 | initializer.visit(type); |
810 | return success(); |
811 | } |
812 | |
813 | /// Consider the given location to be printed for an alias. |
814 | void printOptionalLocationSpecifier(Location loc) override { |
815 | printAttribute(attr: loc); |
816 | } |
817 | |
818 | /// Print the given set of attributes with names not included within |
819 | /// 'elidedAttrs'. |
820 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
821 | ArrayRef<StringRef> elidedAttrs = {}) override { |
822 | if (attrs.empty()) |
823 | return; |
824 | if (elidedAttrs.empty()) { |
825 | for (const NamedAttribute &attr : attrs) |
826 | printAttribute(attr: attr.getValue()); |
827 | return; |
828 | } |
829 | llvm::SmallDenseSet<StringRef> (elidedAttrs.begin(), |
830 | elidedAttrs.end()); |
831 | for (const NamedAttribute &attr : attrs) |
832 | if (!elidedAttrsSet.contains(V: attr.getName().strref())) |
833 | printAttribute(attr: attr.getValue()); |
834 | } |
835 | void printOptionalAttrDictWithKeyword( |
836 | ArrayRef<NamedAttribute> attrs, |
837 | ArrayRef<StringRef> elidedAttrs = {}) override { |
838 | printOptionalAttrDict(attrs, elidedAttrs); |
839 | } |
840 | |
841 | /// Return a null stream as the output stream, this will ignore any data fed |
842 | /// to it. |
843 | raw_ostream &getStream() const override { return os; } |
844 | |
845 | /// The following are hooks of `OpAsmPrinter` that are not necessary for |
846 | /// determining potential aliases. |
847 | void printFloat(const APFloat &) override {} |
848 | void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} |
849 | void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} |
850 | void printNewline() override {} |
851 | void increaseIndent() override {} |
852 | void decreaseIndent() override {} |
853 | void printOperand(Value) override {} |
854 | void printOperand(Value, raw_ostream &os) override { |
855 | // Users expect the output string to have at least the prefixed % to signal |
856 | // a value name. To maintain this invariant, emit a name even if it is |
857 | // guaranteed to go unused. |
858 | os << "%" ; |
859 | } |
860 | void printKeywordOrString(StringRef) override {} |
861 | void printString(StringRef) override {} |
862 | void printResourceHandle(const AsmDialectResourceHandle &) override {} |
863 | void printSymbolName(StringRef) override {} |
864 | void printSuccessor(Block *) override {} |
865 | void printSuccessorAndUseList(Block *, ValueRange) override {} |
866 | void shadowRegionArgs(Region &, ValueRange) override {} |
867 | |
868 | /// The printer flags to use when determining potential aliases. |
869 | const OpPrintingFlags &printerFlags; |
870 | |
871 | /// The initializer to use when identifying aliases. |
872 | AliasInitializer &initializer; |
873 | |
874 | /// A dummy output stream. |
875 | mutable llvm::raw_null_ostream os; |
876 | }; |
877 | |
878 | class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { |
879 | public: |
880 | explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer, |
881 | bool canBeDeferred, |
882 | SmallVectorImpl<size_t> &childIndices) |
883 | : initializer(initializer), canBeDeferred(canBeDeferred), |
884 | childIndices(childIndices) {} |
885 | |
886 | /// Print the given attribute/type, visiting any nested aliases that would be |
887 | /// generated as part of printing. Returns the maximum alias depth found while |
888 | /// printing the given value. |
889 | template <typename T, typename... PrintArgs> |
890 | size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) { |
891 | printAndVisitNestedAliasesImpl(value, printArgs...); |
892 | return maxAliasDepth; |
893 | } |
894 | |
895 | private: |
896 | /// Print the given attribute/type, visiting any nested aliases that would be |
897 | /// generated as part of printing. |
898 | void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) { |
899 | if (!isa<BuiltinDialect>(Val: attr.getDialect())) { |
900 | attr.getDialect().printAttribute(attr, *this); |
901 | |
902 | // Process the builtin attributes. |
903 | } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr, |
904 | IntegerSetAttr, UnitAttr>(Val: attr)) { |
905 | return; |
906 | } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) { |
907 | printAttribute(attr: distinctAttr.getReferencedAttr()); |
908 | } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) { |
909 | for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { |
910 | printAttribute(nestedAttr.getName()); |
911 | printAttribute(nestedAttr.getValue()); |
912 | } |
913 | } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { |
914 | for (Attribute nestedAttr : arrayAttr.getValue()) |
915 | printAttribute(nestedAttr); |
916 | } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) { |
917 | printType(type: typeAttr.getValue()); |
918 | } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) { |
919 | printAttribute(attr: locAttr.getFallbackLocation()); |
920 | } else if (auto locAttr = dyn_cast<NameLoc>(attr)) { |
921 | if (!isa<UnknownLoc>(locAttr.getChildLoc())) |
922 | printAttribute(attr: locAttr.getChildLoc()); |
923 | } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) { |
924 | printAttribute(attr: locAttr.getCallee()); |
925 | printAttribute(attr: locAttr.getCaller()); |
926 | } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) { |
927 | if (Attribute metadata = locAttr.getMetadata()) |
928 | printAttribute(attr: metadata); |
929 | for (Location nestedLoc : locAttr.getLocations()) |
930 | printAttribute(nestedLoc); |
931 | } |
932 | |
933 | // Don't print the type if we must elide it, or if it is a None type. |
934 | if (!elideType) { |
935 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { |
936 | Type attrType = typedAttr.getType(); |
937 | if (!llvm::isa<NoneType>(Val: attrType)) |
938 | printType(type: attrType); |
939 | } |
940 | } |
941 | } |
942 | void printAndVisitNestedAliasesImpl(Type type) { |
943 | if (!isa<BuiltinDialect>(Val: type.getDialect())) |
944 | return type.getDialect().printType(type, *this); |
945 | |
946 | // Only visit the layout of memref if it isn't the identity. |
947 | if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) { |
948 | printType(type: memrefTy.getElementType()); |
949 | MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
950 | if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) |
951 | printAttribute(attr: memrefTy.getLayout()); |
952 | if (memrefTy.getMemorySpace()) |
953 | printAttribute(attr: memrefTy.getMemorySpace()); |
954 | return; |
955 | } |
956 | |
957 | // For most builtin types, we can simply walk the sub elements. |
958 | auto visitFn = [&](auto element) { |
959 | if (element) |
960 | (void)printAlias(element); |
961 | }; |
962 | type.walkImmediateSubElements(walkAttrsFn: visitFn, walkTypesFn: visitFn); |
963 | } |
964 | |
965 | /// Consider the given type to be printed for an alias. |
966 | void printType(Type type) override { |
967 | recordAliasResult(aliasDepthAndIndex: initializer.visit(type, canBeDeferred)); |
968 | } |
969 | |
970 | /// Consider the given attribute to be printed for an alias. |
971 | void printAttribute(Attribute attr) override { |
972 | recordAliasResult(aliasDepthAndIndex: initializer.visit(attr, canBeDeferred)); |
973 | } |
974 | void printAttributeWithoutType(Attribute attr) override { |
975 | recordAliasResult( |
976 | aliasDepthAndIndex: initializer.visit(attr, canBeDeferred, /*elideType=*/true)); |
977 | } |
978 | LogicalResult printAlias(Attribute attr) override { |
979 | printAttribute(attr); |
980 | return success(); |
981 | } |
982 | LogicalResult printAlias(Type type) override { |
983 | printType(type); |
984 | return success(); |
985 | } |
986 | |
987 | /// Record the alias result of a child element. |
988 | void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) { |
989 | childIndices.push_back(Elt: aliasDepthAndIndex.second); |
990 | if (aliasDepthAndIndex.first > maxAliasDepth) |
991 | maxAliasDepth = aliasDepthAndIndex.first; |
992 | } |
993 | |
994 | /// Return a null stream as the output stream, this will ignore any data fed |
995 | /// to it. |
996 | raw_ostream &getStream() const override { return os; } |
997 | |
998 | /// The following are hooks of `DialectAsmPrinter` that are not necessary for |
999 | /// determining potential aliases. |
1000 | void printFloat(const APFloat &) override {} |
1001 | void printKeywordOrString(StringRef) override {} |
1002 | void printString(StringRef) override {} |
1003 | void printSymbolName(StringRef) override {} |
1004 | void printResourceHandle(const AsmDialectResourceHandle &) override {} |
1005 | |
1006 | LogicalResult pushCyclicPrinting(const void *opaquePointer) override { |
1007 | return success(IsSuccess: cyclicPrintingStack.insert(X: opaquePointer)); |
1008 | } |
1009 | |
1010 | void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } |
1011 | |
1012 | /// Stack of potentially cyclic mutable attributes or type currently being |
1013 | /// printed. |
1014 | SetVector<const void *> cyclicPrintingStack; |
1015 | |
1016 | /// The initializer to use when identifying aliases. |
1017 | AliasInitializer &initializer; |
1018 | |
1019 | /// If the aliases visited by this printer can be deferred. |
1020 | bool canBeDeferred; |
1021 | |
1022 | /// The indices of child aliases. |
1023 | SmallVectorImpl<size_t> &childIndices; |
1024 | |
1025 | /// The maximum alias depth found by the printer. |
1026 | size_t maxAliasDepth = 0; |
1027 | |
1028 | /// A dummy output stream. |
1029 | mutable llvm::raw_null_ostream os; |
1030 | }; |
1031 | } // namespace |
1032 | |
1033 | /// Sanitize the given name such that it can be used as a valid identifier. If |
1034 | /// the string needs to be modified in any way, the provided buffer is used to |
1035 | /// store the new copy, |
1036 | static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, |
1037 | StringRef allowedPunctChars = "$._-" ) { |
1038 | assert(!name.empty() && "Shouldn't have an empty name here" ); |
1039 | |
1040 | auto validChar = [&](char ch) { |
1041 | return llvm::isAlnum(C: ch) || allowedPunctChars.contains(C: ch); |
1042 | }; |
1043 | |
1044 | auto copyNameToBuffer = [&] { |
1045 | for (char ch : name) { |
1046 | if (validChar(ch)) |
1047 | buffer.push_back(Elt: ch); |
1048 | else if (ch == ' ') |
1049 | buffer.push_back(Elt: '_'); |
1050 | else |
1051 | buffer.append(RHS: llvm::utohexstr(X: (unsigned char)ch)); |
1052 | } |
1053 | }; |
1054 | |
1055 | // Check to see if this name is valid. If it starts with a digit, then it |
1056 | // could conflict with the autogenerated numeric ID's, so add an underscore |
1057 | // prefix to avoid problems. |
1058 | if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { |
1059 | buffer.push_back(Elt: '_'); |
1060 | copyNameToBuffer(); |
1061 | return buffer; |
1062 | } |
1063 | |
1064 | // Check to see that the name consists of only valid identifier characters. |
1065 | for (char ch : name) { |
1066 | if (!validChar(ch)) { |
1067 | copyNameToBuffer(); |
1068 | return buffer; |
1069 | } |
1070 | } |
1071 | |
1072 | // If there are no invalid characters, return the original name. |
1073 | return name; |
1074 | } |
1075 | |
1076 | unsigned AliasInitializer::uniqueAliasNameIndex( |
1077 | StringRef alias, llvm::StringMap<unsigned> &nameCounts, |
1078 | llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases) { |
1079 | if (!usedAliases.count(Key: alias)) { |
1080 | usedAliases.insert(key: alias); |
1081 | // 0 is not printed in SymbolAlias. |
1082 | return 0; |
1083 | } |
1084 | // Otherwise, we had a conflict - probe until we find a unique name. |
1085 | SmallString<64> probeAlias(alias); |
1086 | // alias with trailing digit will be printed as _N |
1087 | if (isdigit(alias.back())) |
1088 | probeAlias.push_back(Elt: '_'); |
1089 | // nameCounts start from 1 because 0 is not printed in SymbolAlias. |
1090 | if (nameCounts[probeAlias] == 0) |
1091 | nameCounts[probeAlias] = 1; |
1092 | // This is guaranteed to terminate (and usually in a single iteration) |
1093 | // because it generates new names by incrementing nameCounts. |
1094 | while (true) { |
1095 | unsigned nameIndex = nameCounts[probeAlias]++; |
1096 | probeAlias += llvm::utostr(X: nameIndex); |
1097 | if (!usedAliases.count(Key: probeAlias)) { |
1098 | usedAliases.insert(key: probeAlias); |
1099 | return nameIndex; |
1100 | } |
1101 | // Reset probeAlias to the original alias for the next iteration. |
1102 | probeAlias.resize(N: alias.size() + isdigit(alias.back()) ? 1 : 0); |
1103 | } |
1104 | } |
1105 | |
1106 | /// Given a collection of aliases and symbols, initialize a mapping from a |
1107 | /// symbol to a given alias. |
1108 | void AliasInitializer::initializeAliases( |
1109 | llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
1110 | llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) { |
1111 | SmallVector<std::pair<const void *, InProgressAliasInfo>, 0> |
1112 | unprocessedAliases = visitedSymbols.takeVector(); |
1113 | llvm::stable_sort(Range&: unprocessedAliases, C: llvm::less_second()); |
1114 | |
1115 | // This keeps track of all of the non-numeric names that are in flight, |
1116 | // allowing us to check for duplicates. |
1117 | llvm::BumpPtrAllocator usedAliasAllocator; |
1118 | llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator); |
1119 | |
1120 | llvm::StringMap<unsigned> nameCounts; |
1121 | for (auto &[symbol, aliasInfo] : unprocessedAliases) { |
1122 | if (!aliasInfo.alias) |
1123 | continue; |
1124 | StringRef alias = *aliasInfo.alias; |
1125 | unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases); |
1126 | symbolToAlias.insert( |
1127 | KV: {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, |
1128 | aliasInfo.canBeDeferred)}); |
1129 | } |
1130 | } |
1131 | |
1132 | void AliasInitializer::initialize( |
1133 | Operation *op, const OpPrintingFlags &printerFlags, |
1134 | llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) { |
1135 | // Use a dummy printer when walking the IR so that we can collect the |
1136 | // attributes/types that will actually be used during printing when |
1137 | // considering aliases. |
1138 | DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); |
1139 | aliasPrinter.printCustomOrGenericOp(op); |
1140 | |
1141 | // Initialize the aliases. |
1142 | initializeAliases(visitedSymbols&: aliases, symbolToAlias&: attrTypeToAlias); |
1143 | } |
1144 | |
1145 | template <typename T, typename... PrintArgs> |
1146 | std::pair<size_t, size_t> AliasInitializer::visitImpl( |
1147 | T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
1148 | bool canBeDeferred, PrintArgs &&...printArgs) { |
1149 | auto [it, inserted] = aliases.try_emplace(value.getAsOpaquePointer()); |
1150 | size_t aliasIndex = std::distance(aliases.begin(), it); |
1151 | if (!inserted) { |
1152 | // Make sure that the alias isn't deferred if we don't permit it. |
1153 | if (!canBeDeferred) |
1154 | markAliasNonDeferrable(aliasIndex); |
1155 | return {static_cast<size_t>(it->second.aliasDepth), aliasIndex}; |
1156 | } |
1157 | |
1158 | // Try to generate an alias for this value. |
1159 | generateAlias(value, it->second, canBeDeferred); |
1160 | it->second.isType = std::is_base_of_v<Type, T>; |
1161 | it->second.canBeDeferred = canBeDeferred; |
1162 | |
1163 | // Print the value, capturing any nested elements that require aliases. |
1164 | SmallVector<size_t> childAliases; |
1165 | DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases); |
1166 | size_t maxAliasDepth = |
1167 | printer.printAndVisitNestedAliases(value, printArgs...); |
1168 | |
1169 | // Make sure to recompute `it` in case the map was reallocated. |
1170 | it = std::next(x: aliases.begin(), n: aliasIndex); |
1171 | |
1172 | // If we had sub elements, update to account for the depth. |
1173 | it->second.childIndices = std::move(childAliases); |
1174 | if (maxAliasDepth) |
1175 | it->second.aliasDepth = maxAliasDepth + 1; |
1176 | |
1177 | // Propagate the alias depth of the value. |
1178 | return {(size_t)it->second.aliasDepth, aliasIndex}; |
1179 | } |
1180 | |
1181 | void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { |
1182 | auto *it = std::next(x: aliases.begin(), n: aliasIndex); |
1183 | |
1184 | // If already marked non-deferrable stop the recursion. |
1185 | // All children should already be marked non-deferrable as well. |
1186 | if (!it->second.canBeDeferred) |
1187 | return; |
1188 | |
1189 | it->second.canBeDeferred = false; |
1190 | |
1191 | // Propagate the non-deferrable flag to any child aliases. |
1192 | for (size_t childIndex : it->second.childIndices) |
1193 | markAliasNonDeferrable(aliasIndex: childIndex); |
1194 | } |
1195 | |
1196 | template <typename T> |
1197 | void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, |
1198 | bool canBeDeferred) { |
1199 | SmallString<32> nameBuffer; |
1200 | |
1201 | OpAsmDialectInterface::AliasResult symbolInterfaceResult = |
1202 | OpAsmDialectInterface::AliasResult::NoAlias; |
1203 | using InterfaceT = std::conditional_t<std::is_base_of_v<Attribute, T>, |
1204 | OpAsmAttrInterface, OpAsmTypeInterface>; |
1205 | if (auto symbolInterface = dyn_cast<InterfaceT>(symbol)) { |
1206 | symbolInterfaceResult = symbolInterface.getAlias(aliasOS); |
1207 | if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::NoAlias) { |
1208 | nameBuffer = std::move(aliasBuffer); |
1209 | assert(!nameBuffer.empty() && "expected valid alias name" ); |
1210 | } |
1211 | } |
1212 | |
1213 | if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) { |
1214 | for (const auto &interface : interfaces) { |
1215 | OpAsmDialectInterface::AliasResult result = |
1216 | interface.getAlias(symbol, aliasOS); |
1217 | if (result == OpAsmDialectInterface::AliasResult::NoAlias) |
1218 | continue; |
1219 | nameBuffer = std::move(aliasBuffer); |
1220 | assert(!nameBuffer.empty() && "expected valid alias name" ); |
1221 | if (result == OpAsmDialectInterface::AliasResult::FinalAlias) |
1222 | break; |
1223 | } |
1224 | } |
1225 | |
1226 | if (nameBuffer.empty()) |
1227 | return; |
1228 | |
1229 | SmallString<16> tempBuffer; |
1230 | StringRef name = |
1231 | sanitizeIdentifier(name: nameBuffer, buffer&: tempBuffer, /*allowedPunctChars=*/"$_-" ); |
1232 | name = name.copy(A&: aliasAllocator); |
1233 | alias = InProgressAliasInfo(name); |
1234 | } |
1235 | |
1236 | //===----------------------------------------------------------------------===// |
1237 | // AliasState |
1238 | //===----------------------------------------------------------------------===// |
1239 | |
1240 | namespace { |
1241 | /// This class manages the state for type and attribute aliases. |
1242 | class AliasState { |
1243 | public: |
1244 | // Initialize the internal aliases. |
1245 | void |
1246 | initialize(Operation *op, const OpPrintingFlags &printerFlags, |
1247 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
1248 | |
1249 | /// Get an alias for the given attribute if it has one and print it in `os`. |
1250 | /// Returns success if an alias was printed, failure otherwise. |
1251 | LogicalResult getAlias(Attribute attr, raw_ostream &os) const; |
1252 | |
1253 | /// Get an alias for the given type if it has one and print it in `os`. |
1254 | /// Returns success if an alias was printed, failure otherwise. |
1255 | LogicalResult getAlias(Type ty, raw_ostream &os) const; |
1256 | |
1257 | /// Print all of the referenced aliases that can not be resolved in a deferred |
1258 | /// manner. |
1259 | void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
1260 | printAliases(p, newLine, /*isDeferred=*/false); |
1261 | } |
1262 | |
1263 | /// Print all of the referenced aliases that support deferred resolution. |
1264 | void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
1265 | printAliases(p, newLine, /*isDeferred=*/true); |
1266 | } |
1267 | |
1268 | private: |
1269 | /// Print all of the referenced aliases that support the provided resolution |
1270 | /// behavior. |
1271 | void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
1272 | bool isDeferred); |
1273 | |
1274 | /// Mapping between attribute/type and alias. |
1275 | llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias; |
1276 | |
1277 | /// An allocator used for alias names. |
1278 | llvm::BumpPtrAllocator aliasAllocator; |
1279 | }; |
1280 | } // namespace |
1281 | |
1282 | void AliasState::initialize( |
1283 | Operation *op, const OpPrintingFlags &printerFlags, |
1284 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
1285 | AliasInitializer initializer(interfaces, aliasAllocator); |
1286 | initializer.initialize(op, printerFlags, attrTypeToAlias); |
1287 | } |
1288 | |
1289 | LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { |
1290 | const auto *it = attrTypeToAlias.find(Key: attr.getAsOpaquePointer()); |
1291 | if (it == attrTypeToAlias.end()) |
1292 | return failure(); |
1293 | it->second.print(os); |
1294 | return success(); |
1295 | } |
1296 | |
1297 | LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { |
1298 | const auto *it = attrTypeToAlias.find(Key: ty.getAsOpaquePointer()); |
1299 | if (it == attrTypeToAlias.end()) |
1300 | return failure(); |
1301 | if (!it->second.isPrinted) |
1302 | return failure(); |
1303 | |
1304 | it->second.print(os); |
1305 | return success(); |
1306 | } |
1307 | |
1308 | void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
1309 | bool isDeferred) { |
1310 | auto filterFn = [=](const auto &aliasIt) { |
1311 | return aliasIt.second.canBeDeferred() == isDeferred; |
1312 | }; |
1313 | for (auto &[opaqueSymbol, alias] : |
1314 | llvm::make_filter_range(Range&: attrTypeToAlias, Pred: filterFn)) { |
1315 | alias.print(os&: p.getStream()); |
1316 | p.getStream() << " = " ; |
1317 | |
1318 | if (alias.isTypeAlias()) { |
1319 | Type type = Type::getFromOpaquePointer(pointer: opaqueSymbol); |
1320 | p.printTypeImpl(type); |
1321 | alias.isPrinted = true; |
1322 | } else { |
1323 | // TODO: Support nested aliases in mutable attributes. |
1324 | Attribute attr = Attribute::getFromOpaquePointer(ptr: opaqueSymbol); |
1325 | if (attr.hasTrait<AttributeTrait::IsMutable>()) |
1326 | p.getStream() << attr; |
1327 | else |
1328 | p.printAttributeImpl(attr); |
1329 | } |
1330 | |
1331 | p.getStream() << newLine; |
1332 | } |
1333 | } |
1334 | |
1335 | //===----------------------------------------------------------------------===// |
1336 | // SSANameState |
1337 | //===----------------------------------------------------------------------===// |
1338 | |
1339 | namespace { |
/// Info about block printing: a number giving the block's position in the
/// visitation order, and the name used to print references to it, e.g. ^bb42.
struct BlockInfo {
  /// Position of the block in the visitation order.
  int ordering;
  /// Name used when printing references to the block (e.g. `^bb42`).
  StringRef name;
};
1346 | |
1347 | /// This class manages the state of SSA value names. |
1348 | class SSANameState { |
1349 | public: |
1350 | /// A sentinel value used for values with names set. |
1351 | enum : unsigned { NameSentinel = ~0U }; |
1352 | |
1353 | SSANameState(Operation *op, const OpPrintingFlags &printerFlags); |
1354 | SSANameState() = default; |
1355 | |
1356 | /// Print the SSA identifier for the given value to 'stream'. If |
1357 | /// 'printResultNo' is true, it also presents the result number ('#' number) |
1358 | /// of this value. |
1359 | void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; |
1360 | |
1361 | /// Print the operation identifier. |
1362 | void printOperationID(Operation *op, raw_ostream &stream) const; |
1363 | |
1364 | /// Return the result indices for each of the result groups registered by this |
1365 | /// operation, or empty if none exist. |
1366 | ArrayRef<int> getOpResultGroups(Operation *op); |
1367 | |
1368 | /// Get the info for the given block. |
1369 | BlockInfo getBlockInfo(Block *block); |
1370 | |
1371 | /// Renumber the arguments for the specified region to the same names as the |
1372 | /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for |
1373 | /// details. |
1374 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse); |
1375 | |
1376 | private: |
1377 | /// Number the SSA values within the given IR unit. |
1378 | void numberValuesInRegion(Region ®ion); |
1379 | void numberValuesInBlock(Block &block); |
1380 | void numberValuesInOp(Operation &op); |
1381 | |
1382 | /// Given a result of an operation 'result', find the result group head |
1383 | /// 'lookupValue' and the result of 'result' within that group in |
1384 | /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group |
1385 | /// has more than 1 result. |
1386 | void getResultIDAndNumber(OpResult result, Value &lookupValue, |
1387 | std::optional<int> &lookupResultNo) const; |
1388 | |
1389 | /// Set a special value name for the given value. |
1390 | void setValueName(Value value, StringRef name); |
1391 | |
1392 | /// Uniques the given value name within the printer. If the given name |
1393 | /// conflicts, it is automatically renamed. |
1394 | StringRef uniqueValueName(StringRef name); |
1395 | |
1396 | /// This is the value ID for each SSA value. If this returns NameSentinel, |
1397 | /// then the valueID has an entry in valueNames. |
1398 | DenseMap<Value, unsigned> valueIDs; |
1399 | DenseMap<Value, StringRef> valueNames; |
1400 | |
1401 | /// When printing users of values, an operation without a result might |
1402 | /// be the user. This map holds ids for such operations. |
1403 | DenseMap<Operation *, unsigned> operationIDs; |
1404 | |
1405 | /// This is a map of operations that contain multiple named result groups, |
1406 | /// i.e. there may be multiple names for the results of the operation. The |
1407 | /// value of this map are the result numbers that start a result group. |
1408 | DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; |
1409 | |
1410 | /// This maps blocks to there visitation number in the current region as well |
1411 | /// as the string representing their name. |
1412 | DenseMap<Block *, BlockInfo> blockNames; |
1413 | |
1414 | /// This keeps track of all of the non-numeric names that are in flight, |
1415 | /// allowing us to check for duplicates. |
1416 | /// Note: the value of the map is unused. |
1417 | llvm::ScopedHashTable<StringRef, char> usedNames; |
1418 | llvm::BumpPtrAllocator usedNameAllocator; |
1419 | |
1420 | /// This is the next value ID to assign in numbering. |
1421 | unsigned nextValueID = 0; |
1422 | /// This is the next ID to assign to a region entry block argument. |
1423 | unsigned nextArgumentID = 0; |
1424 | /// This is the next ID to assign when a name conflict is detected. |
1425 | unsigned nextConflictID = 0; |
1426 | |
1427 | /// These are the printing flags. They control, eg., whether to print in |
1428 | /// generic form. |
1429 | OpPrintingFlags printerFlags; |
1430 | }; |
1431 | } // namespace |
1432 | |
1433 | SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) |
1434 | : printerFlags(printerFlags) { |
1435 | llvm::SaveAndRestore valueIDSaver(nextValueID); |
1436 | llvm::SaveAndRestore argumentIDSaver(nextArgumentID); |
1437 | llvm::SaveAndRestore conflictIDSaver(nextConflictID); |
1438 | |
1439 | // The naming context includes `nextValueID`, `nextArgumentID`, |
1440 | // `nextConflictID` and `usedNames` scoped HashTable. This information is |
1441 | // carried from the parent region. |
1442 | using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; |
1443 | using NamingContext = |
1444 | std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; |
1445 | |
1446 | // Allocator for UsedNamesScopeTy |
1447 | llvm::BumpPtrAllocator allocator; |
1448 | |
1449 | // Add a scope for the top level operation. |
1450 | auto *topLevelNamesScope = |
1451 | new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); |
1452 | |
1453 | SmallVector<NamingContext, 8> nameContext; |
1454 | for (Region ®ion : op->getRegions()) |
1455 | nameContext.push_back(Elt: std::make_tuple(args: ®ion, args&: nextValueID, args&: nextArgumentID, |
1456 | args&: nextConflictID, args&: topLevelNamesScope)); |
1457 | |
1458 | numberValuesInOp(op&: *op); |
1459 | |
1460 | while (!nameContext.empty()) { |
1461 | Region *region; |
1462 | UsedNamesScopeTy *parentScope; |
1463 | |
1464 | if (printerFlags.shouldPrintUniqueSSAIDs()) |
1465 | // To print unique SSA IDs, ignore saved ID counts from parent regions |
1466 | std::tie(args&: region, args: std::ignore, args: std::ignore, args: std::ignore, args&: parentScope) = |
1467 | nameContext.pop_back_val(); |
1468 | else |
1469 | std::tie(args&: region, args&: nextValueID, args&: nextArgumentID, args&: nextConflictID, |
1470 | args&: parentScope) = nameContext.pop_back_val(); |
1471 | |
1472 | // When we switch from one subtree to another, pop the scopes(needless) |
1473 | // until the parent scope. |
1474 | while (usedNames.getCurScope() != parentScope) { |
1475 | usedNames.getCurScope()->~UsedNamesScopeTy(); |
1476 | assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && |
1477 | "top level parentScope must be a nullptr" ); |
1478 | } |
1479 | |
1480 | // Add a scope for the current region. |
1481 | auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) |
1482 | UsedNamesScopeTy(usedNames); |
1483 | |
1484 | numberValuesInRegion(region&: *region); |
1485 | |
1486 | for (Operation &op : region->getOps()) |
1487 | for (Region ®ion : op.getRegions()) |
1488 | nameContext.push_back(Elt: std::make_tuple(args: ®ion, args&: nextValueID, |
1489 | args&: nextArgumentID, args&: nextConflictID, |
1490 | args&: curNamesScope)); |
1491 | } |
1492 | |
1493 | // Manually remove all the scopes. |
1494 | while (usedNames.getCurScope() != nullptr) |
1495 | usedNames.getCurScope()->~UsedNamesScopeTy(); |
1496 | } |
1497 | |
1498 | void SSANameState::printValueID(Value value, bool printResultNo, |
1499 | raw_ostream &stream) const { |
1500 | if (!value) { |
1501 | stream << "<<NULL VALUE>>" ; |
1502 | return; |
1503 | } |
1504 | |
1505 | std::optional<int> resultNo; |
1506 | auto lookupValue = value; |
1507 | |
1508 | // If this is an operation result, collect the head lookup value of the result |
1509 | // group and the result number of 'result' within that group. |
1510 | if (OpResult result = dyn_cast<OpResult>(Val&: value)) |
1511 | getResultIDAndNumber(result, lookupValue, lookupResultNo&: resultNo); |
1512 | |
1513 | auto it = valueIDs.find(Val: lookupValue); |
1514 | if (it == valueIDs.end()) { |
1515 | stream << "<<UNKNOWN SSA VALUE>>" ; |
1516 | return; |
1517 | } |
1518 | |
1519 | stream << '%'; |
1520 | if (it->second != NameSentinel) { |
1521 | stream << it->second; |
1522 | } else { |
1523 | auto nameIt = valueNames.find(Val: lookupValue); |
1524 | assert(nameIt != valueNames.end() && "Didn't have a name entry?" ); |
1525 | stream << nameIt->second; |
1526 | } |
1527 | |
1528 | if (resultNo && printResultNo) |
1529 | stream << '#' << *resultNo; |
1530 | } |
1531 | |
1532 | void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { |
1533 | auto it = operationIDs.find(Val: op); |
1534 | if (it == operationIDs.end()) { |
1535 | stream << "<<UNKNOWN OPERATION>>" ; |
1536 | } else { |
1537 | stream << '%' << it->second; |
1538 | } |
1539 | } |
1540 | |
1541 | ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { |
1542 | auto it = opResultGroups.find(Val: op); |
1543 | return it == opResultGroups.end() ? ArrayRef<int>() : it->second; |
1544 | } |
1545 | |
1546 | BlockInfo SSANameState::getBlockInfo(Block *block) { |
1547 | auto it = blockNames.find(Val: block); |
1548 | BlockInfo invalidBlock{.ordering: -1, .name: "INVALIDBLOCK" }; |
1549 | return it != blockNames.end() ? it->second : invalidBlock; |
1550 | } |
1551 | |
1552 | void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { |
1553 | assert(!region.empty() && "cannot shadow arguments of an empty region" ); |
1554 | assert(region.getNumArguments() == namesToUse.size() && |
1555 | "incorrect number of names passed in" ); |
1556 | assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && |
1557 | "only KnownIsolatedFromAbove ops can shadow names" ); |
1558 | |
1559 | SmallVector<char, 16> nameStr; |
1560 | for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { |
1561 | auto nameToUse = namesToUse[i]; |
1562 | if (nameToUse == nullptr) |
1563 | continue; |
1564 | auto nameToReplace = region.getArgument(i); |
1565 | |
1566 | nameStr.clear(); |
1567 | llvm::raw_svector_ostream nameStream(nameStr); |
1568 | printValueID(value: nameToUse, /*printResultNo=*/true, stream&: nameStream); |
1569 | |
1570 | // Entry block arguments should already have a pretty "arg" name. |
1571 | assert(valueIDs[nameToReplace] == NameSentinel); |
1572 | |
1573 | // Use the name without the leading %. |
1574 | auto name = StringRef(nameStream.str()).drop_front(); |
1575 | |
1576 | // Overwrite the name. |
1577 | valueNames[nameToReplace] = name.copy(A&: usedNameAllocator); |
1578 | } |
1579 | } |
1580 | |
1581 | namespace { |
1582 | /// Try to get value name from value's location, fallback to `name`. |
1583 | StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { |
1584 | if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>()) |
1585 | return maybeNameLoc.getName(); |
1586 | return name; |
1587 | } |
1588 | } // namespace |
1589 | |
1590 | void SSANameState::numberValuesInRegion(Region ®ion) { |
1591 | // Indicates whether OpAsmOpInterface set a name. |
1592 | bool opAsmOpInterfaceUsed = false; |
1593 | auto setBlockArgNameFn = [&](Value arg, StringRef name) { |
1594 | assert(!valueIDs.count(arg) && "arg numbered multiple times" ); |
1595 | assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion && |
1596 | "arg not defined in current region" ); |
1597 | opAsmOpInterfaceUsed = true; |
1598 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
1599 | name = maybeGetValueNameFromLoc(value: arg, name); |
1600 | setValueName(value: arg, name); |
1601 | }; |
1602 | |
1603 | if (!printerFlags.shouldPrintGenericOpForm()) { |
1604 | if (Operation *op = region.getParentOp()) { |
1605 | if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) |
1606 | asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); |
1607 | // If the OpAsmOpInterface didn't set a name, get name from the type. |
1608 | if (!opAsmOpInterfaceUsed) { |
1609 | for (BlockArgument arg : region.getArguments()) { |
1610 | if (auto interface = dyn_cast<OpAsmTypeInterface>(arg.getType())) { |
1611 | interface.getAsmName( |
1612 | [&](StringRef name) { setBlockArgNameFn(arg, name); }); |
1613 | } |
1614 | } |
1615 | } |
1616 | } |
1617 | } |
1618 | |
1619 | // Number the values within this region in a breadth-first order. |
1620 | unsigned nextBlockID = 0; |
1621 | for (auto &block : region) { |
1622 | // Each block gets a unique ID, and all of the operations within it get |
1623 | // numbered as well. |
1624 | auto blockInfoIt = blockNames.insert(KV: {&block, {.ordering: -1, .name: "" }}); |
1625 | if (blockInfoIt.second) { |
1626 | // This block hasn't been named through `getAsmBlockArgumentNames`, use |
1627 | // default `^bbNNN` format. |
1628 | std::string name; |
1629 | llvm::raw_string_ostream(name) << "^bb" << nextBlockID; |
1630 | blockInfoIt.first->second.name = StringRef(name).copy(A&: usedNameAllocator); |
1631 | } |
1632 | blockInfoIt.first->second.ordering = nextBlockID++; |
1633 | |
1634 | numberValuesInBlock(block); |
1635 | } |
1636 | } |
1637 | |
1638 | void SSANameState::numberValuesInBlock(Block &block) { |
1639 | // Number the block arguments. We give entry block arguments a special name |
1640 | // 'arg'. |
1641 | bool isEntryBlock = block.isEntryBlock(); |
1642 | SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "" ); |
1643 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
1644 | for (auto arg : block.getArguments()) { |
1645 | if (valueIDs.count(Val: arg)) |
1646 | continue; |
1647 | if (isEntryBlock) { |
1648 | specialNameBuffer.resize(N: strlen(s: "arg" )); |
1649 | specialName << nextArgumentID++; |
1650 | } |
1651 | StringRef specialNameStr = specialName.str(); |
1652 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
1653 | specialNameStr = maybeGetValueNameFromLoc(value: arg, name: specialNameStr); |
1654 | setValueName(value: arg, name: specialNameStr); |
1655 | } |
1656 | |
1657 | // Number the operations in this block. |
1658 | for (auto &op : block) |
1659 | numberValuesInOp(op); |
1660 | } |
1661 | |
1662 | void SSANameState::numberValuesInOp(Operation &op) { |
1663 | // Function used to set the special result names for the operation. |
1664 | SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); |
1665 | // Indicates whether OpAsmOpInterface set a name. |
1666 | bool opAsmOpInterfaceUsed = false; |
1667 | auto setResultNameFn = [&](Value result, StringRef name) { |
1668 | assert(!valueIDs.count(result) && "result numbered multiple times" ); |
1669 | assert(result.getDefiningOp() == &op && "result not defined by 'op'" ); |
1670 | opAsmOpInterfaceUsed = true; |
1671 | if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) |
1672 | name = maybeGetValueNameFromLoc(value: result, name); |
1673 | setValueName(value: result, name); |
1674 | |
1675 | // Record the result number for groups not anchored at 0. |
1676 | if (int resultNo = llvm::cast<OpResult>(Val&: result).getResultNumber()) |
1677 | resultGroups.push_back(Elt: resultNo); |
1678 | }; |
1679 | // Operations can customize the printing of block names in OpAsmOpInterface. |
1680 | auto setBlockNameFn = [&](Block *block, StringRef name) { |
1681 | assert(block->getParentOp() == &op && |
1682 | "getAsmBlockArgumentNames callback invoked on a block not directly " |
1683 | "nested under the current operation" ); |
1684 | assert(!blockNames.count(block) && "block numbered multiple times" ); |
1685 | SmallString<16> tmpBuffer{"^" }; |
1686 | name = sanitizeIdentifier(name, buffer&: tmpBuffer); |
1687 | if (name.data() != tmpBuffer.data()) { |
1688 | tmpBuffer.append(RHS: name); |
1689 | name = tmpBuffer.str(); |
1690 | } |
1691 | name = name.copy(A&: usedNameAllocator); |
1692 | blockNames[block] = {.ordering: -1, .name: name}; |
1693 | }; |
1694 | |
1695 | if (!printerFlags.shouldPrintGenericOpForm()) { |
1696 | if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { |
1697 | asmInterface.getAsmBlockNames(setBlockNameFn); |
1698 | asmInterface.getAsmResultNames(setResultNameFn); |
1699 | } |
1700 | if (!opAsmOpInterfaceUsed) { |
1701 | // If the OpAsmOpInterface didn't set a name, and all results have |
1702 | // OpAsmTypeInterface, get names from types. |
1703 | bool allHaveOpAsmTypeInterface = |
1704 | llvm::all_of(Range: op.getResultTypes(), P: [&](Type type) { |
1705 | return isa<OpAsmTypeInterface>(type); |
1706 | }); |
1707 | if (allHaveOpAsmTypeInterface) { |
1708 | for (OpResult result : op.getResults()) { |
1709 | auto interface = cast<OpAsmTypeInterface>(result.getType()); |
1710 | interface.getAsmName( |
1711 | [&](StringRef name) { setResultNameFn(result, name); }); |
1712 | } |
1713 | } |
1714 | } |
1715 | } |
1716 | |
1717 | unsigned numResults = op.getNumResults(); |
1718 | if (numResults == 0) { |
1719 | // If value users should be printed, operations with no result need an id. |
1720 | if (printerFlags.shouldPrintValueUsers()) { |
1721 | if (operationIDs.try_emplace(Key: &op, Args&: nextValueID).second) |
1722 | ++nextValueID; |
1723 | } |
1724 | return; |
1725 | } |
1726 | Value resultBegin = op.getResult(idx: 0); |
1727 | |
1728 | if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(Val: resultBegin)) { |
1729 | if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) { |
1730 | setValueName(value: resultBegin, name: nameLoc.getName()); |
1731 | } |
1732 | } |
1733 | |
1734 | // If the first result wasn't numbered, give it a default number. |
1735 | if (valueIDs.try_emplace(Key: resultBegin, Args&: nextValueID).second) |
1736 | ++nextValueID; |
1737 | |
1738 | // If this operation has multiple result groups, mark it. |
1739 | if (resultGroups.size() != 1) { |
1740 | llvm::array_pod_sort(Start: resultGroups.begin(), End: resultGroups.end()); |
1741 | opResultGroups.try_emplace(Key: &op, Args: std::move(resultGroups)); |
1742 | } |
1743 | } |
1744 | |
1745 | void SSANameState::getResultIDAndNumber( |
1746 | OpResult result, Value &lookupValue, |
1747 | std::optional<int> &lookupResultNo) const { |
1748 | Operation *owner = result.getOwner(); |
1749 | if (owner->getNumResults() == 1) |
1750 | return; |
1751 | int resultNo = result.getResultNumber(); |
1752 | |
1753 | // If this operation has multiple result groups, we will need to find the |
1754 | // one corresponding to this result. |
1755 | auto resultGroupIt = opResultGroups.find(Val: owner); |
1756 | if (resultGroupIt == opResultGroups.end()) { |
1757 | // If not, just use the first result. |
1758 | lookupResultNo = resultNo; |
1759 | lookupValue = owner->getResult(idx: 0); |
1760 | return; |
1761 | } |
1762 | |
1763 | // Find the correct index using a binary search, as the groups are ordered. |
1764 | ArrayRef<int> resultGroups = resultGroupIt->second; |
1765 | const auto *it = llvm::upper_bound(Range&: resultGroups, Value&: resultNo); |
1766 | int groupResultNo = 0, groupSize = 0; |
1767 | |
1768 | // If there are no smaller elements, the last result group is the lookup. |
1769 | if (it == resultGroups.end()) { |
1770 | groupResultNo = resultGroups.back(); |
1771 | groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); |
1772 | } else { |
1773 | // Otherwise, the previous element is the lookup. |
1774 | groupResultNo = *std::prev(x: it); |
1775 | groupSize = *it - groupResultNo; |
1776 | } |
1777 | |
1778 | // We only record the result number for a group of size greater than 1. |
1779 | if (groupSize != 1) |
1780 | lookupResultNo = resultNo - groupResultNo; |
1781 | lookupValue = owner->getResult(idx: groupResultNo); |
1782 | } |
1783 | |
1784 | void SSANameState::setValueName(Value value, StringRef name) { |
1785 | // If the name is empty, the value uses the default numbering. |
1786 | if (name.empty()) { |
1787 | valueIDs[value] = nextValueID++; |
1788 | return; |
1789 | } |
1790 | |
1791 | valueIDs[value] = NameSentinel; |
1792 | valueNames[value] = uniqueValueName(name); |
1793 | } |
1794 | |
1795 | StringRef SSANameState::uniqueValueName(StringRef name) { |
1796 | SmallString<16> tmpBuffer; |
1797 | name = sanitizeIdentifier(name, buffer&: tmpBuffer); |
1798 | |
1799 | // Check to see if this name is already unique. |
1800 | if (!usedNames.count(Key: name)) { |
1801 | name = name.copy(A&: usedNameAllocator); |
1802 | } else { |
1803 | // Otherwise, we had a conflict - probe until we find a unique name. This |
1804 | // is guaranteed to terminate (and usually in a single iteration) because it |
1805 | // generates new names by incrementing nextConflictID. |
1806 | SmallString<64> probeName(name); |
1807 | probeName.push_back(Elt: '_'); |
1808 | while (true) { |
1809 | probeName += llvm::utostr(X: nextConflictID++); |
1810 | if (!usedNames.count(Key: probeName)) { |
1811 | name = probeName.str().copy(A&: usedNameAllocator); |
1812 | break; |
1813 | } |
1814 | probeName.resize(N: name.size() + 1); |
1815 | } |
1816 | } |
1817 | |
1818 | usedNames.insert(Key: name, Val: char()); |
1819 | return name; |
1820 | } |
1821 | |
1822 | //===----------------------------------------------------------------------===// |
1823 | // DistinctState |
1824 | //===----------------------------------------------------------------------===// |
1825 | |
1826 | namespace { |
1827 | /// This class manages the state for distinct attributes. |
1828 | class DistinctState { |
1829 | public: |
1830 | /// Returns a unique identifier for the given distinct attribute. |
1831 | uint64_t getId(DistinctAttr distinctAttr); |
1832 | |
1833 | private: |
1834 | uint64_t distinctCounter = 0; |
1835 | DenseMap<DistinctAttr, uint64_t> distinctAttrMap; |
1836 | }; |
1837 | } // namespace |
1838 | |
1839 | uint64_t DistinctState::getId(DistinctAttr distinctAttr) { |
1840 | auto [it, inserted] = |
1841 | distinctAttrMap.try_emplace(Key: distinctAttr, Args&: distinctCounter); |
1842 | if (inserted) |
1843 | distinctCounter++; |
1844 | return it->getSecond(); |
1845 | } |
1846 | |
1847 | //===----------------------------------------------------------------------===// |
1848 | // Resources |
1849 | //===----------------------------------------------------------------------===// |
1850 | |
// Out-of-line defaulted destructors for the resource interfaces. Presumably
// this keeps each class's vtable/key function anchored in this translation
// unit (NOTE(review): confirm the classes are polymorphic in AsmState.h).
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
1855 | |
1856 | StringRef mlir::toString(AsmResourceEntryKind kind) { |
1857 | switch (kind) { |
1858 | case AsmResourceEntryKind::Blob: |
1859 | return "blob" ; |
1860 | case AsmResourceEntryKind::Bool: |
1861 | return "bool" ; |
1862 | case AsmResourceEntryKind::String: |
1863 | return "string" ; |
1864 | } |
1865 | llvm_unreachable("unknown AsmResourceEntryKind" ); |
1866 | } |
1867 | |
1868 | AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { |
1869 | std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()]; |
1870 | if (!collection) |
1871 | collection = std::make_unique<ResourceCollection>(args&: key); |
1872 | return *collection; |
1873 | } |
1874 | |
1875 | std::vector<std::unique_ptr<AsmResourcePrinter>> |
1876 | FallbackAsmResourceMap::getPrinters() { |
1877 | std::vector<std::unique_ptr<AsmResourcePrinter>> printers; |
1878 | for (auto &it : keyToResources) { |
1879 | ResourceCollection *collection = it.second.get(); |
1880 | auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { |
1881 | return collection->buildResources(op, builder); |
1882 | }; |
1883 | printers.emplace_back( |
1884 | args: AsmResourcePrinter::fromCallable(name: collection->getName(), printFn&: buildValues)); |
1885 | } |
1886 | return printers; |
1887 | } |
1888 | |
1889 | LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( |
1890 | AsmParsedResourceEntry &entry) { |
1891 | switch (entry.getKind()) { |
1892 | case AsmResourceEntryKind::Blob: { |
1893 | FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(); |
1894 | if (failed(Result: blob)) |
1895 | return failure(); |
1896 | resources.emplace_back(Args: entry.getKey(), Args: std::move(*blob)); |
1897 | return success(); |
1898 | } |
1899 | case AsmResourceEntryKind::Bool: { |
1900 | FailureOr<bool> value = entry.parseAsBool(); |
1901 | if (failed(Result: value)) |
1902 | return failure(); |
1903 | resources.emplace_back(Args: entry.getKey(), Args&: *value); |
1904 | break; |
1905 | } |
1906 | case AsmResourceEntryKind::String: { |
1907 | FailureOr<std::string> str = entry.parseAsString(); |
1908 | if (failed(Result: str)) |
1909 | return failure(); |
1910 | resources.emplace_back(Args: entry.getKey(), Args: std::move(*str)); |
1911 | break; |
1912 | } |
1913 | } |
1914 | return success(); |
1915 | } |
1916 | |
1917 | void FallbackAsmResourceMap::ResourceCollection::buildResources( |
1918 | Operation *op, AsmResourceBuilder &builder) const { |
1919 | for (const auto &entry : resources) { |
1920 | if (const auto *value = std::get_if<AsmResourceBlob>(ptr: &entry.value)) |
1921 | builder.buildBlob(key: entry.key, blob: *value); |
1922 | else if (const auto *value = std::get_if<bool>(ptr: &entry.value)) |
1923 | builder.buildBool(key: entry.key, data: *value); |
1924 | else if (const auto *value = std::get_if<std::string>(ptr: &entry.value)) |
1925 | builder.buildString(key: entry.key, data: *value); |
1926 | else |
1927 | llvm_unreachable("unknown AsmResourceEntryKind" ); |
1928 | } |
1929 | } |
1930 | |
1931 | //===----------------------------------------------------------------------===// |
1932 | // AsmState |
1933 | //===----------------------------------------------------------------------===// |
1934 | |
1935 | namespace mlir { |
1936 | namespace detail { |
1937 | class AsmStateImpl { |
1938 | public: |
1939 | explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, |
1940 | AsmState::LocationMap *locationMap) |
1941 | : interfaces(op->getContext()), nameState(op, printerFlags), |
1942 | printerFlags(printerFlags), locationMap(locationMap) {} |
1943 | explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
1944 | AsmState::LocationMap *locationMap) |
1945 | : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} |
1946 | |
1947 | /// Initialize the alias state to enable the printing of aliases. |
1948 | void initializeAliases(Operation *op) { |
1949 | aliasState.initialize(op, printerFlags, interfaces); |
1950 | } |
1951 | |
1952 | /// Get the state used for aliases. |
1953 | AliasState &getAliasState() { return aliasState; } |
1954 | |
1955 | /// Get the state used for SSA names. |
1956 | SSANameState &getSSANameState() { return nameState; } |
1957 | |
1958 | /// Get the state used for distinct attribute identifiers. |
1959 | DistinctState &getDistinctState() { return distinctState; } |
1960 | |
1961 | /// Return the dialects within the context that implement |
1962 | /// OpAsmDialectInterface. |
1963 | DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() { |
1964 | return interfaces; |
1965 | } |
1966 | |
1967 | /// Return the non-dialect resource printers. |
1968 | auto getResourcePrinters() { |
1969 | return llvm::make_pointee_range(Range&: externalResourcePrinters); |
1970 | } |
1971 | |
1972 | /// Get the printer flags. |
1973 | const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } |
1974 | |
1975 | /// Register the location, line and column, within the buffer that the given |
1976 | /// operation was printed at. |
1977 | void registerOperationLocation(Operation *op, unsigned line, unsigned col) { |
1978 | if (locationMap) |
1979 | (*locationMap)[op] = std::make_pair(x&: line, y&: col); |
1980 | } |
1981 | |
1982 | /// Return the referenced dialect resources within the printer. |
1983 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
1984 | getDialectResources() { |
1985 | return dialectResources; |
1986 | } |
1987 | |
1988 | LogicalResult pushCyclicPrinting(const void *opaquePointer) { |
1989 | return success(IsSuccess: cyclicPrintingStack.insert(X: opaquePointer)); |
1990 | } |
1991 | |
1992 | void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } |
1993 | |
1994 | private: |
1995 | /// Collection of OpAsm interfaces implemented in the context. |
1996 | DialectInterfaceCollection<OpAsmDialectInterface> interfaces; |
1997 | |
1998 | /// A collection of non-dialect resource printers. |
1999 | SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; |
2000 | |
2001 | /// A set of dialect resources that were referenced during printing. |
2002 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources; |
2003 | |
2004 | /// The state used for attribute and type aliases. |
2005 | AliasState aliasState; |
2006 | |
2007 | /// The state used for SSA value names. |
2008 | SSANameState nameState; |
2009 | |
2010 | /// The state used for distinct attribute identifiers. |
2011 | DistinctState distinctState; |
2012 | |
2013 | /// Flags that control op output. |
2014 | OpPrintingFlags printerFlags; |
2015 | |
2016 | /// An optional location map to be populated. |
2017 | AsmState::LocationMap *locationMap; |
2018 | |
2019 | /// Stack of potentially cyclic mutable attributes or type currently being |
2020 | /// printed. |
2021 | SetVector<const void *> cyclicPrintingStack; |
2022 | |
2023 | // Allow direct access to the impl fields. |
2024 | friend AsmState; |
2025 | }; |
2026 | |
2027 | template <typename Range> |
2028 | void printDimensionList(raw_ostream &stream, Range &&shape) { |
2029 | llvm::interleave( |
2030 | shape, stream, |
2031 | [&stream](const auto &dimSize) { |
2032 | if (ShapedType::isDynamic(dimSize)) |
2033 | stream << "?" ; |
2034 | else |
2035 | stream << dimSize; |
2036 | }, |
2037 | "x" ); |
2038 | } |
2039 | |
2040 | } // namespace detail |
2041 | } // namespace mlir |
2042 | |
2043 | /// Verifies the operation and switches to generic op printing if verification |
2044 | /// fails. We need to do this because custom print functions may fail for |
2045 | /// invalid ops. |
2046 | static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, |
2047 | OpPrintingFlags printerFlags) { |
2048 | if (printerFlags.shouldPrintGenericOpForm() || |
2049 | printerFlags.shouldAssumeVerified()) |
2050 | return printerFlags; |
2051 | |
2052 | // Ignore errors emitted by the verifier. We check the thread id to avoid |
2053 | // consuming other threads' errors. |
2054 | auto parentThreadId = llvm::get_threadid(); |
2055 | ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { |
2056 | if (parentThreadId == llvm::get_threadid()) { |
2057 | LLVM_DEBUG({ |
2058 | diag.print(llvm::dbgs()); |
2059 | llvm::dbgs() << "\n" ; |
2060 | }); |
2061 | return success(); |
2062 | } |
2063 | return failure(); |
2064 | }); |
2065 | if (failed(Result: verify(op))) { |
2066 | LLVM_DEBUG(llvm::dbgs() |
2067 | << DEBUG_TYPE << ": '" << op->getName() |
2068 | << "' failed to verify and will be printed in generic form\n" ); |
2069 | printerFlags.printGenericOpForm(); |
2070 | } |
2071 | |
2072 | return printerFlags; |
2073 | } |
2074 | |
2075 | AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, |
2076 | LocationMap *locationMap, FallbackAsmResourceMap *map) |
2077 | : impl(std::make_unique<AsmStateImpl>( |
2078 | args&: op, args: verifyOpAndAdjustFlags(op, printerFlags), args&: locationMap)) { |
2079 | if (map) |
2080 | attachFallbackResourcePrinter(map&: *map); |
2081 | } |
2082 | AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
2083 | LocationMap *locationMap, FallbackAsmResourceMap *map) |
2084 | : impl(std::make_unique<AsmStateImpl>(args&: ctx, args: printerFlags, args&: locationMap)) { |
2085 | if (map) |
2086 | attachFallbackResourcePrinter(map&: *map); |
2087 | } |
// Defaulted here, where AsmStateImpl is a complete type. Presumably needed
// because `impl` owns an AsmStateImpl through a smart pointer (it is built with
// std::make_unique) — NOTE(review): confirm against AsmState.h.
AsmState::~AsmState() = default;
2089 | |
2090 | const OpPrintingFlags &AsmState::getPrinterFlags() const { |
2091 | return impl->getPrinterFlags(); |
2092 | } |
2093 | |
2094 | void AsmState::attachResourcePrinter( |
2095 | std::unique_ptr<AsmResourcePrinter> printer) { |
2096 | impl->externalResourcePrinters.emplace_back(Args: std::move(printer)); |
2097 | } |
2098 | |
2099 | DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
2100 | AsmState::getDialectResources() const { |
2101 | return impl->getDialectResources(); |
2102 | } |
2103 | |
2104 | //===----------------------------------------------------------------------===// |
2105 | // AsmPrinter::Impl |
2106 | //===----------------------------------------------------------------------===// |
2107 | |
// Binds the output stream and shared printer state; the printing flags are
// taken from `state`.
AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
    : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
2110 | |
2111 | void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { |
2112 | // Check to see if we are printing debug information. |
2113 | if (!printerFlags.shouldPrintDebugInfo()) |
2114 | return; |
2115 | |
2116 | os << " " ; |
2117 | printLocation(loc, /*allowAlias=*/allowAlias); |
2118 | } |
2119 | |
2120 | void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, |
2121 | bool isTopLevel) { |
2122 | // If this isn't a top-level location, check for an alias. |
2123 | if (!isTopLevel && succeeded(Result: state.getAliasState().getAlias(attr: loc, os))) |
2124 | return; |
2125 | |
2126 | TypeSwitch<LocationAttr>(loc) |
2127 | .Case<OpaqueLoc>([&](OpaqueLoc loc) { |
2128 | printLocationInternal(loc.getFallbackLocation(), pretty); |
2129 | }) |
2130 | .Case<UnknownLoc>([&](UnknownLoc loc) { |
2131 | if (pretty) |
2132 | os << "[unknown]" ; |
2133 | else |
2134 | os << "unknown" ; |
2135 | }) |
2136 | .Case<FileLineColRange>([&](FileLineColRange loc) { |
2137 | if (pretty) |
2138 | os << loc.getFilename().getValue(); |
2139 | else |
2140 | printEscapedString(loc.getFilename()); |
2141 | if (loc.getEndColumn() == loc.getStartColumn() && |
2142 | loc.getStartLine() == loc.getEndLine()) { |
2143 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn(); |
2144 | return; |
2145 | } |
2146 | if (loc.getStartLine() == loc.getEndLine()) { |
2147 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() |
2148 | << " to :" << loc.getEndColumn(); |
2149 | return; |
2150 | } |
2151 | os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to " |
2152 | << loc.getEndLine() << ':' << loc.getEndColumn(); |
2153 | }) |
2154 | .Case<NameLoc>([&](NameLoc loc) { |
2155 | printEscapedString(loc.getName()); |
2156 | |
2157 | // Print the child if it isn't unknown. |
2158 | auto childLoc = loc.getChildLoc(); |
2159 | if (!llvm::isa<UnknownLoc>(childLoc)) { |
2160 | os << '('; |
2161 | printLocationInternal(childLoc, pretty); |
2162 | os << ')'; |
2163 | } |
2164 | }) |
2165 | .Case<CallSiteLoc>([&](CallSiteLoc loc) { |
2166 | Location caller = loc.getCaller(); |
2167 | Location callee = loc.getCallee(); |
2168 | if (!pretty) |
2169 | os << "callsite(" ; |
2170 | printLocationInternal(callee, pretty); |
2171 | if (pretty) { |
2172 | if (llvm::isa<NameLoc>(callee)) { |
2173 | if (llvm::isa<FileLineColLoc>(caller)) { |
2174 | os << " at " ; |
2175 | } else { |
2176 | os << newLine << " at " ; |
2177 | } |
2178 | } else { |
2179 | os << newLine << " at " ; |
2180 | } |
2181 | } else { |
2182 | os << " at " ; |
2183 | } |
2184 | printLocationInternal(caller, pretty); |
2185 | if (!pretty) |
2186 | os << ")" ; |
2187 | }) |
2188 | .Case<FusedLoc>([&](FusedLoc loc) { |
2189 | if (!pretty) |
2190 | os << "fused" ; |
2191 | if (Attribute metadata = loc.getMetadata()) { |
2192 | os << '<'; |
2193 | printAttribute(metadata); |
2194 | os << '>'; |
2195 | } |
2196 | os << '['; |
2197 | interleave( |
2198 | loc.getLocations(), |
2199 | [&](Location loc) { printLocationInternal(loc, pretty); }, |
2200 | [&]() { os << ", " ; }); |
2201 | os << ']'; |
2202 | }) |
2203 | .Default([&](LocationAttr loc) { |
2204 | // Assumes that this is a dialect-specific attribute and prints it |
2205 | // directly. |
2206 | printAttribute(loc); |
2207 | }); |
2208 | } |
2209 | |
2210 | /// Print a floating point value in a way that the parser will be able to |
2211 | /// round-trip losslessly. |
2212 | static void printFloatValue(const APFloat &apValue, raw_ostream &os, |
2213 | bool *printedHex = nullptr) { |
2214 | // We would like to output the FP constant value in exponential notation, |
2215 | // but we cannot do this if doing so will lose precision. Check here to |
2216 | // make sure that we only output it in exponential format if we can parse |
2217 | // the value back and get the same value. |
2218 | bool isInf = apValue.isInfinity(); |
2219 | bool isNaN = apValue.isNaN(); |
2220 | if (!isInf && !isNaN) { |
2221 | SmallString<128> strValue; |
2222 | apValue.toString(Str&: strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, |
2223 | /*TruncateZero=*/false); |
2224 | |
2225 | // Check to make sure that the stringized number is not some string like |
2226 | // "Inf" or NaN, that atof will accept, but the lexer will not. Check |
2227 | // that the string matches the "[-+]?[0-9]" regex. |
2228 | assert(((strValue[0] >= '0' && strValue[0] <= '9') || |
2229 | ((strValue[0] == '-' || strValue[0] == '+') && |
2230 | (strValue[1] >= '0' && strValue[1] <= '9'))) && |
2231 | "[-+]?[0-9] regex does not match!" ); |
2232 | |
2233 | // Parse back the stringized version and check that the value is equal |
2234 | // (i.e., there is no precision loss). |
2235 | if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(RHS: apValue)) { |
2236 | os << strValue; |
2237 | return; |
2238 | } |
2239 | |
2240 | // If it is not, use the default format of APFloat instead of the |
2241 | // exponential notation. |
2242 | strValue.clear(); |
2243 | apValue.toString(Str&: strValue); |
2244 | |
2245 | // Make sure that we can parse the default form as a float. |
2246 | if (strValue.str().contains(C: '.')) { |
2247 | os << strValue; |
2248 | return; |
2249 | } |
2250 | } |
2251 | |
2252 | // Print special values in hexadecimal format. The sign bit should be included |
2253 | // in the literal. |
2254 | if (printedHex) |
2255 | *printedHex = true; |
2256 | SmallVector<char, 16> str; |
2257 | APInt apInt = apValue.bitcastToAPInt(); |
2258 | apInt.toString(Str&: str, /*Radix=*/16, /*Signed=*/false, |
2259 | /*formatAsCLiteral=*/true); |
2260 | os << str; |
2261 | } |
2262 | |
2263 | void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { |
2264 | if (printerFlags.shouldPrintDebugInfoPrettyForm()) |
2265 | return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true); |
2266 | |
2267 | os << "loc(" ; |
2268 | if (!allowAlias || failed(Result: printAlias(attr: loc))) |
2269 | printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true); |
2270 | os << ')'; |
2271 | } |
2272 | |
2273 | /// Returns true if the given dialect symbol data is simple enough to print in |
2274 | /// the pretty form. This is essentially when the symbol takes the form: |
2275 | /// identifier (`<` body `>`)? |
2276 | static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { |
2277 | // The name must start with an identifier. |
2278 | if (symName.empty() || !isalpha(symName.front())) |
2279 | return false; |
2280 | |
2281 | // Ignore all the characters that are valid in an identifier in the symbol |
2282 | // name. |
2283 | symName = symName.drop_while( |
2284 | F: [](char c) { return llvm::isAlnum(C: c) || c == '.' || c == '_'; }); |
2285 | if (symName.empty()) |
2286 | return true; |
2287 | |
2288 | // If we got to an unexpected character, then it must be a <>. Check that the |
2289 | // rest of the symbol is wrapped within <>. |
2290 | return symName.front() == '<' && symName.back() == '>'; |
2291 | } |
2292 | |
2293 | /// Print the given dialect symbol to the stream. |
2294 | static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, |
2295 | StringRef dialectName, StringRef symString) { |
2296 | os << symPrefix << dialectName; |
2297 | |
2298 | // If this symbol name is simple enough, print it directly in pretty form, |
2299 | // otherwise, we print it as an escaped string. |
2300 | if (isDialectSymbolSimpleEnoughForPrettyForm(symName: symString)) { |
2301 | os << '.' << symString; |
2302 | return; |
2303 | } |
2304 | |
2305 | os << '<' << symString << '>'; |
2306 | } |
2307 | |
2308 | /// Returns true if the given string can be represented as a bare identifier. |
2309 | static bool isBareIdentifier(StringRef name) { |
2310 | // By making this unsigned, the value passed in to isalnum will always be |
2311 | // in the range 0-255. This is important when building with MSVC because |
2312 | // its implementation will assert. This situation can arise when dealing |
2313 | // with UTF-8 multibyte characters. |
2314 | if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) |
2315 | return false; |
2316 | return llvm::all_of(Range: name.drop_front(), P: [](unsigned char c) { |
2317 | return isalnum(c) || c == '_' || c == '$' || c == '.'; |
2318 | }); |
2319 | } |
2320 | |
2321 | /// Print the given string as a keyword, or a quoted and escaped string if it |
2322 | /// has any special or non-printable characters in it. |
2323 | static void printKeywordOrString(StringRef keyword, raw_ostream &os) { |
2324 | // If it can be represented as a bare identifier, write it directly. |
2325 | if (isBareIdentifier(name: keyword)) { |
2326 | os << keyword; |
2327 | return; |
2328 | } |
2329 | |
2330 | // Otherwise, output the keyword wrapped in quotes with proper escaping. |
2331 | os << "\"" ; |
2332 | printEscapedString(Name: keyword, Out&: os); |
2333 | os << '"'; |
2334 | } |
2335 | |
2336 | /// Print the given string as a symbol reference. A symbol reference is |
2337 | /// represented as a string prefixed with '@'. The reference is surrounded with |
2338 | /// ""'s and escaped if it has any special or non-printable characters in it. |
2339 | static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { |
2340 | if (symbolRef.empty()) { |
2341 | os << "@<<INVALID EMPTY SYMBOL>>" ; |
2342 | return; |
2343 | } |
2344 | os << '@'; |
2345 | printKeywordOrString(keyword: symbolRef, os); |
2346 | } |
2347 | |
2348 | // Print out a valid ElementsAttr that is succinct and can represent any |
2349 | // potential shape/type, for use when eliding a large ElementsAttr. |
2350 | // |
2351 | // We choose to use a dense resource ElementsAttr literal with conspicuous |
2352 | // content to hopefully alert readers to the fact that this has been elided. |
2353 | static void printElidedElementsAttr(raw_ostream &os) { |
2354 | os << R"(dense_resource<__elided__>)" ; |
2355 | } |
2356 | |
2357 | void AsmPrinter::Impl::printResourceHandle( |
2358 | const AsmDialectResourceHandle &resource) { |
2359 | auto *interface = cast<OpAsmDialectInterface>(Val: resource.getDialect()); |
2360 | ::printKeywordOrString(keyword: interface->getResourceKey(handle: resource), os); |
2361 | state.getDialectResources()[resource.getDialect()].insert(X: resource); |
2362 | } |
2363 | |
2364 | LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { |
2365 | return state.getAliasState().getAlias(attr, os); |
2366 | } |
2367 | |
2368 | LogicalResult AsmPrinter::Impl::printAlias(Type type) { |
2369 | return state.getAliasState().getAlias(ty: type, os); |
2370 | } |
2371 | |
2372 | void AsmPrinter::Impl::printAttribute(Attribute attr, |
2373 | AttrTypeElision typeElision) { |
2374 | if (!attr) { |
2375 | os << "<<NULL ATTRIBUTE>>" ; |
2376 | return; |
2377 | } |
2378 | |
2379 | // Try to print an alias for this attribute. |
2380 | if (succeeded(Result: printAlias(attr))) |
2381 | return; |
2382 | return printAttributeImpl(attr, typeElision); |
2383 | } |
2384 | |
2385 | void AsmPrinter::Impl::printAttributeImpl(Attribute attr, |
2386 | AttrTypeElision typeElision) { |
2387 | if (!isa<BuiltinDialect>(Val: attr.getDialect())) { |
2388 | printDialectAttribute(attr); |
2389 | } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) { |
2390 | printDialectSymbol(os, "#" , opaqueAttr.getDialectNamespace(), |
2391 | opaqueAttr.getAttrData()); |
2392 | } else if (llvm::isa<UnitAttr>(Val: attr)) { |
2393 | os << "unit" ; |
2394 | return; |
2395 | } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) { |
2396 | os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<" ; |
2397 | if (!llvm::isa<UnitAttr>(Val: distinctAttr.getReferencedAttr())) { |
2398 | printAttribute(attr: distinctAttr.getReferencedAttr()); |
2399 | } |
2400 | os << '>'; |
2401 | return; |
2402 | } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) { |
2403 | os << '{'; |
2404 | interleaveComma(dictAttr.getValue(), |
2405 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
2406 | os << '}'; |
2407 | |
2408 | } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) { |
2409 | Type intType = intAttr.getType(); |
2410 | if (intType.isSignlessInteger(width: 1)) { |
2411 | os << (intAttr.getValue().getBoolValue() ? "true" : "false" ); |
2412 | |
2413 | // Boolean integer attributes always elides the type. |
2414 | return; |
2415 | } |
2416 | |
2417 | // Only print attributes as unsigned if they are explicitly unsigned or are |
2418 | // signless 1-bit values. Indexes, signed values, and multi-bit signless |
2419 | // values print as signed. |
2420 | bool isUnsigned = |
2421 | intType.isUnsignedInteger() || intType.isSignlessInteger(width: 1); |
2422 | intAttr.getValue().print(os, !isUnsigned); |
2423 | |
2424 | // IntegerAttr elides the type if I64. |
2425 | if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(width: 64)) |
2426 | return; |
2427 | |
2428 | } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) { |
2429 | bool printedHex = false; |
2430 | printFloatValue(floatAttr.getValue(), os, &printedHex); |
2431 | |
2432 | // FloatAttr elides the type if F64. |
2433 | if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() && |
2434 | !printedHex) |
2435 | return; |
2436 | |
2437 | } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) { |
2438 | printEscapedString(str: strAttr.getValue()); |
2439 | |
2440 | } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) { |
2441 | os << '['; |
2442 | interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { |
2443 | printAttribute(attr, typeElision: AttrTypeElision::May); |
2444 | }); |
2445 | os << ']'; |
2446 | |
2447 | } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) { |
2448 | os << "affine_map<" ; |
2449 | affineMapAttr.getValue().print(os); |
2450 | os << '>'; |
2451 | |
2452 | // AffineMap always elides the type. |
2453 | return; |
2454 | |
2455 | } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) { |
2456 | os << "affine_set<" ; |
2457 | integerSetAttr.getValue().print(os); |
2458 | os << '>'; |
2459 | |
2460 | // IntegerSet always elides the type. |
2461 | return; |
2462 | |
2463 | } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) { |
2464 | printType(type: typeAttr.getValue()); |
2465 | |
2466 | } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) { |
2467 | printSymbolReference(refAttr.getRootReference().getValue(), os); |
2468 | for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { |
2469 | os << "::" ; |
2470 | printSymbolReference(nestedRef.getValue(), os); |
2471 | } |
2472 | |
2473 | } else if (auto intOrFpEltAttr = |
2474 | llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) { |
2475 | if (printerFlags.shouldElideElementsAttr(attr: intOrFpEltAttr)) { |
2476 | printElidedElementsAttr(os); |
2477 | } else { |
2478 | os << "dense<" ; |
2479 | printDenseIntOrFPElementsAttr(attr: intOrFpEltAttr, /*allowHex=*/true); |
2480 | os << '>'; |
2481 | } |
2482 | |
2483 | } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) { |
2484 | if (printerFlags.shouldElideElementsAttr(attr: strEltAttr)) { |
2485 | printElidedElementsAttr(os); |
2486 | } else { |
2487 | os << "dense<" ; |
2488 | printDenseStringElementsAttr(attr: strEltAttr); |
2489 | os << '>'; |
2490 | } |
2491 | |
2492 | } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) { |
2493 | if (printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getIndices()) || |
2494 | printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getValues())) { |
2495 | printElidedElementsAttr(os); |
2496 | } else { |
2497 | os << "sparse<" ; |
2498 | DenseIntElementsAttr indices = sparseEltAttr.getIndices(); |
2499 | if (indices.getNumElements() != 0) { |
2500 | printDenseIntOrFPElementsAttr(attr: indices, /*allowHex=*/false); |
2501 | os << ", " ; |
2502 | printDenseElementsAttr(attr: sparseEltAttr.getValues(), /*allowHex=*/true); |
2503 | } |
2504 | os << '>'; |
2505 | } |
2506 | } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) { |
2507 | stridedLayoutAttr.print(os); |
2508 | } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) { |
2509 | os << "array<" ; |
2510 | printType(type: denseArrayAttr.getElementType()); |
2511 | if (!denseArrayAttr.empty()) { |
2512 | os << ": " ; |
2513 | printDenseArrayAttr(attr: denseArrayAttr); |
2514 | } |
2515 | os << ">" ; |
2516 | return; |
2517 | } else if (auto resourceAttr = |
2518 | llvm::dyn_cast<DenseResourceElementsAttr>(attr)) { |
2519 | os << "dense_resource<" ; |
2520 | printResourceHandle(resource: resourceAttr.getRawHandle()); |
2521 | os << ">" ; |
2522 | } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(Val&: attr)) { |
2523 | printLocation(loc: locAttr); |
2524 | } else { |
2525 | llvm::report_fatal_error(reason: "Unknown builtin attribute" ); |
2526 | } |
2527 | // Don't print the type if we must elide it, or if it is a None type. |
2528 | if (typeElision != AttrTypeElision::Must) { |
2529 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { |
2530 | Type attrType = typedAttr.getType(); |
2531 | if (!llvm::isa<NoneType>(Val: attrType)) { |
2532 | os << " : " ; |
2533 | printType(type: attrType); |
2534 | } |
2535 | } |
2536 | } |
2537 | } |
2538 | |
2539 | /// Print the integer element of a DenseElementsAttr. |
2540 | static void printDenseIntElement(const APInt &value, raw_ostream &os, |
2541 | Type type) { |
2542 | if (type.isInteger(width: 1)) |
2543 | os << (value.getBoolValue() ? "true" : "false" ); |
2544 | else |
2545 | value.print(OS&: os, isSigned: !type.isUnsignedInteger()); |
2546 | } |
2547 | |
/// Print the elements of a dense elements attribute of shape `type` as nested
/// bracketed, comma-separated lists, one bracket level per dimension
/// (e.g. `[[1, 2], [3, 4]]`).
///
/// `printEltFn(idx)` is called once per element with flat indices
/// 0..numElements-1 in order. The last dimension varies fastest, so the
/// bracket nesting corresponds to row-major order.
/// When `isSplat` is true, only element 0 is printed, with no brackets at
/// all. Nothing is printed for a shape with zero elements.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  // NOTE(review): the code below indexes counter[rank - 1], so it assumes
  // rank >= 1; presumably 0-d attributes always arrive here as splats —
  // confirm.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  // NOTE(review): `numElements` is narrowed to `unsigned` here; attributes
  // with more than UINT_MAX elements would not be iterated fully.
  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed by the previous bumpCounter(). The
    // post-increment leaves openBrackets at rank + 1 on exit, so it is reset
    // to rank explicitly.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever brackets remain open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
2597 | |
2598 | void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, |
2599 | bool allowHex) { |
2600 | if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) |
2601 | return printDenseStringElementsAttr(attr: stringAttr); |
2602 | |
2603 | printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr), |
2604 | allowHex); |
2605 | } |
2606 | |
2607 | void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( |
2608 | DenseIntOrFPElementsAttr attr, bool allowHex) { |
2609 | auto type = attr.getType(); |
2610 | auto elementType = type.getElementType(); |
2611 | |
2612 | // Check to see if we should format this attribute as a hex string. |
2613 | if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr: attr)) { |
2614 | ArrayRef<char> rawData = attr.getRawData(); |
2615 | if (llvm::endianness::native == llvm::endianness::big) { |
2616 | // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE |
2617 | // machines. It is converted here to print in LE format. |
2618 | SmallVector<char, 64> outDataVec(rawData.size()); |
2619 | MutableArrayRef<char> convRawData(outDataVec); |
2620 | DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
2621 | rawData, convRawData, type); |
2622 | printHexString(data: convRawData); |
2623 | } else { |
2624 | printHexString(data: rawData); |
2625 | } |
2626 | |
2627 | return; |
2628 | } |
2629 | |
2630 | if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) { |
2631 | Type complexElementType = complexTy.getElementType(); |
2632 | // Note: The if and else below had a common lambda function which invoked |
2633 | // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 |
2634 | // and hence was replaced. |
2635 | if (llvm::isa<IntegerType>(Val: complexElementType)) { |
2636 | auto valueIt = attr.value_begin<std::complex<APInt>>(); |
2637 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
2638 | auto complexValue = *(valueIt + index); |
2639 | os << "(" ; |
2640 | printDenseIntElement(complexValue.real(), os, complexElementType); |
2641 | os << "," ; |
2642 | printDenseIntElement(complexValue.imag(), os, complexElementType); |
2643 | os << ")" ; |
2644 | }); |
2645 | } else { |
2646 | auto valueIt = attr.value_begin<std::complex<APFloat>>(); |
2647 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
2648 | auto complexValue = *(valueIt + index); |
2649 | os << "(" ; |
2650 | printFloatValue(complexValue.real(), os); |
2651 | os << "," ; |
2652 | printFloatValue(complexValue.imag(), os); |
2653 | os << ")" ; |
2654 | }); |
2655 | } |
2656 | } else if (elementType.isIntOrIndex()) { |
2657 | auto valueIt = attr.value_begin<APInt>(); |
2658 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
2659 | printDenseIntElement(*(valueIt + index), os, elementType); |
2660 | }); |
2661 | } else { |
2662 | assert(llvm::isa<FloatType>(elementType) && "unexpected element type" ); |
2663 | auto valueIt = attr.value_begin<APFloat>(); |
2664 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
2665 | printFloatValue(*(valueIt + index), os); |
2666 | }); |
2667 | } |
2668 | } |
2669 | |
2670 | void AsmPrinter::Impl::printDenseStringElementsAttr( |
2671 | DenseStringElementsAttr attr) { |
2672 | ArrayRef<StringRef> data = attr.getRawStringData(); |
2673 | auto printFn = [&](unsigned index) { printEscapedString(str: data[index]); }; |
2674 | printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); |
2675 | } |
2676 | |
2677 | void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { |
2678 | Type type = attr.getElementType(); |
2679 | unsigned bitwidth = type.isInteger(width: 1) ? 8 : type.getIntOrFloatBitWidth(); |
2680 | unsigned byteSize = bitwidth / 8; |
2681 | ArrayRef<char> data = attr.getRawData(); |
2682 | |
2683 | auto printElementAt = [&](unsigned i) { |
2684 | APInt value(bitwidth, 0); |
2685 | if (bitwidth) { |
2686 | llvm::LoadIntFromMemory( |
2687 | IntVal&: value, Src: reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i), |
2688 | LoadBytes: byteSize); |
2689 | } |
2690 | // Print the data as-is or as a float. |
2691 | if (type.isIntOrIndex()) { |
2692 | printDenseIntElement(value, os&: getStream(), type); |
2693 | } else { |
2694 | APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value); |
2695 | printFloatValue(apValue: fltVal, os&: getStream()); |
2696 | } |
2697 | }; |
2698 | llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(), |
2699 | printElementAt); |
2700 | } |
2701 | |
2702 | void AsmPrinter::Impl::printType(Type type) { |
2703 | if (!type) { |
2704 | os << "<<NULL TYPE>>" ; |
2705 | return; |
2706 | } |
2707 | |
2708 | // Try to print an alias for this type. |
2709 | if (succeeded(Result: printAlias(type))) |
2710 | return; |
2711 | return printTypeImpl(type); |
2712 | } |
2713 | |
2714 | void AsmPrinter::Impl::printTypeImpl(Type type) { |
2715 | TypeSwitch<Type>(type) |
2716 | .Case<OpaqueType>([&](OpaqueType opaqueTy) { |
2717 | printDialectSymbol(os, "!" , opaqueTy.getDialectNamespace(), |
2718 | opaqueTy.getTypeData()); |
2719 | }) |
2720 | .Case<IndexType>([&](Type) { os << "index" ; }) |
2721 | .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN" ; }) |
2722 | .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN" ; }) |
2723 | .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN" ; }) |
2724 | .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2" ; }) |
2725 | .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3" ; }) |
2726 | .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN" ; }) |
2727 | .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ" ; }) |
2728 | .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ" ; }) |
2729 | .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ" ; }) |
2730 | .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4" ; }) |
2731 | .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU" ; }) |
2732 | .Case<BFloat16Type>([&](Type) { os << "bf16" ; }) |
2733 | .Case<Float16Type>([&](Type) { os << "f16" ; }) |
2734 | .Case<FloatTF32Type>([&](Type) { os << "tf32" ; }) |
2735 | .Case<Float32Type>([&](Type) { os << "f32" ; }) |
2736 | .Case<Float64Type>([&](Type) { os << "f64" ; }) |
2737 | .Case<Float80Type>([&](Type) { os << "f80" ; }) |
2738 | .Case<Float128Type>([&](Type) { os << "f128" ; }) |
2739 | .Case<IntegerType>([&](IntegerType integerTy) { |
2740 | if (integerTy.isSigned()) |
2741 | os << 's'; |
2742 | else if (integerTy.isUnsigned()) |
2743 | os << 'u'; |
2744 | os << 'i' << integerTy.getWidth(); |
2745 | }) |
2746 | .Case<FunctionType>([&](FunctionType funcTy) { |
2747 | os << '('; |
2748 | interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); |
2749 | os << ") -> " ; |
2750 | ArrayRef<Type> results = funcTy.getResults(); |
2751 | if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) { |
2752 | printType(results[0]); |
2753 | } else { |
2754 | os << '('; |
2755 | interleaveComma(results, [&](Type ty) { printType(ty); }); |
2756 | os << ')'; |
2757 | } |
2758 | }) |
2759 | .Case<VectorType>([&](VectorType vectorTy) { |
2760 | auto scalableDims = vectorTy.getScalableDims(); |
2761 | os << "vector<" ; |
2762 | auto vShape = vectorTy.getShape(); |
2763 | unsigned lastDim = vShape.size(); |
2764 | unsigned dimIdx = 0; |
2765 | for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { |
2766 | if (!scalableDims.empty() && scalableDims[dimIdx]) |
2767 | os << '['; |
2768 | os << vShape[dimIdx]; |
2769 | if (!scalableDims.empty() && scalableDims[dimIdx]) |
2770 | os << ']'; |
2771 | os << 'x'; |
2772 | } |
2773 | printType(vectorTy.getElementType()); |
2774 | os << '>'; |
2775 | }) |
2776 | .Case<RankedTensorType>([&](RankedTensorType tensorTy) { |
2777 | os << "tensor<" ; |
2778 | printDimensionList(tensorTy.getShape()); |
2779 | if (!tensorTy.getShape().empty()) |
2780 | os << 'x'; |
2781 | printType(tensorTy.getElementType()); |
2782 | // Only print the encoding attribute value if set. |
2783 | if (tensorTy.getEncoding()) { |
2784 | os << ", " ; |
2785 | printAttribute(tensorTy.getEncoding()); |
2786 | } |
2787 | os << '>'; |
2788 | }) |
2789 | .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { |
2790 | os << "tensor<*x" ; |
2791 | printType(tensorTy.getElementType()); |
2792 | os << '>'; |
2793 | }) |
2794 | .Case<MemRefType>([&](MemRefType memrefTy) { |
2795 | os << "memref<" ; |
2796 | printDimensionList(memrefTy.getShape()); |
2797 | if (!memrefTy.getShape().empty()) |
2798 | os << 'x'; |
2799 | printType(memrefTy.getElementType()); |
2800 | MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
2801 | if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) { |
2802 | os << ", " ; |
2803 | printAttribute(memrefTy.getLayout(), AttrTypeElision::May); |
2804 | } |
2805 | // Only print the memory space if it is the non-default one. |
2806 | if (memrefTy.getMemorySpace()) { |
2807 | os << ", " ; |
2808 | printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); |
2809 | } |
2810 | os << '>'; |
2811 | }) |
2812 | .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { |
2813 | os << "memref<*x" ; |
2814 | printType(memrefTy.getElementType()); |
2815 | // Only print the memory space if it is the non-default one. |
2816 | if (memrefTy.getMemorySpace()) { |
2817 | os << ", " ; |
2818 | printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); |
2819 | } |
2820 | os << '>'; |
2821 | }) |
2822 | .Case<ComplexType>([&](ComplexType complexTy) { |
2823 | os << "complex<" ; |
2824 | printType(complexTy.getElementType()); |
2825 | os << '>'; |
2826 | }) |
2827 | .Case<TupleType>([&](TupleType tupleTy) { |
2828 | os << "tuple<" ; |
2829 | interleaveComma(tupleTy.getTypes(), |
2830 | [&](Type type) { printType(type); }); |
2831 | os << '>'; |
2832 | }) |
2833 | .Case<NoneType>([&](Type) { os << "none" ; }) |
2834 | .Default([&](Type type) { return printDialectType(type); }); |
2835 | } |
2836 | |
2837 | void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
2838 | ArrayRef<StringRef> elidedAttrs, |
2839 | bool withKeyword) { |
2840 | // If there are no attributes, then there is nothing to be done. |
2841 | if (attrs.empty()) |
2842 | return; |
2843 | |
2844 | // Functor used to print a filtered attribute list. |
2845 | auto printFilteredAttributesFn = [&](auto filteredAttrs) { |
2846 | // Print the 'attributes' keyword if necessary. |
2847 | if (withKeyword) |
2848 | os << " attributes" ; |
2849 | |
2850 | // Otherwise, print them all out in braces. |
2851 | os << " {" ; |
2852 | interleaveComma(filteredAttrs, |
2853 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
2854 | os << '}'; |
2855 | }; |
2856 | |
2857 | // If no attributes are elided, we can directly print with no filtering. |
2858 | if (elidedAttrs.empty()) |
2859 | return printFilteredAttributesFn(attrs); |
2860 | |
2861 | // Otherwise, filter out any attributes that shouldn't be included. |
2862 | llvm::SmallDenseSet<StringRef> (elidedAttrs.begin(), |
2863 | elidedAttrs.end()); |
2864 | auto filteredAttrs = llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) { |
2865 | return !elidedAttrsSet.contains(attr.getName().strref()); |
2866 | }); |
2867 | if (!filteredAttrs.empty()) |
2868 | printFilteredAttributesFn(filteredAttrs); |
2869 | } |
2870 | void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { |
2871 | // Print the name without quotes if possible. |
2872 | ::printKeywordOrString(keyword: attr.getName().strref(), os); |
2873 | |
2874 | // Pretty printing elides the attribute value for unit attributes. |
2875 | if (llvm::isa<UnitAttr>(Val: attr.getValue())) |
2876 | return; |
2877 | |
2878 | os << " = " ; |
2879 | printAttribute(attr: attr.getValue()); |
2880 | } |
2881 | |
2882 | void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { |
2883 | auto &dialect = attr.getDialect(); |
2884 | |
2885 | // Ask the dialect to serialize the attribute to a string. |
2886 | std::string attrName; |
2887 | { |
2888 | llvm::raw_string_ostream attrNameStr(attrName); |
2889 | Impl subPrinter(attrNameStr, state); |
2890 | DialectAsmPrinter printer(subPrinter); |
2891 | dialect.printAttribute(attr, printer); |
2892 | } |
2893 | printDialectSymbol(os, symPrefix: "#" , dialectName: dialect.getNamespace(), symString: attrName); |
2894 | } |
2895 | |
2896 | void AsmPrinter::Impl::printDialectType(Type type) { |
2897 | auto &dialect = type.getDialect(); |
2898 | |
2899 | // Ask the dialect to serialize the type to a string. |
2900 | std::string typeName; |
2901 | { |
2902 | llvm::raw_string_ostream typeNameStr(typeName); |
2903 | Impl subPrinter(typeNameStr, state); |
2904 | DialectAsmPrinter printer(subPrinter); |
2905 | dialect.printType(type, printer); |
2906 | } |
2907 | printDialectSymbol(os, symPrefix: "!" , dialectName: dialect.getNamespace(), symString: typeName); |
2908 | } |
2909 | |
2910 | void AsmPrinter::Impl::printEscapedString(StringRef str) { |
2911 | os << "\"" ; |
2912 | llvm::printEscapedString(Name: str, Out&: os); |
2913 | os << "\"" ; |
2914 | } |
2915 | |
2916 | void AsmPrinter::Impl::printHexString(StringRef str) { |
2917 | os << "\"0x" << llvm::toHex(Input: str) << "\"" ; |
2918 | } |
2919 | void AsmPrinter::Impl::printHexString(ArrayRef<char> data) { |
2920 | printHexString(str: StringRef(data.data(), data.size())); |
2921 | } |
2922 | |
2923 | LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { |
2924 | return state.pushCyclicPrinting(opaquePointer); |
2925 | } |
2926 | |
2927 | void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } |
2928 | |
2929 | void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) { |
2930 | detail::printDimensionList(stream&: os, shape); |
2931 | } |
2932 | |
2933 | //===--------------------------------------------------------------------===// |
2934 | // AsmPrinter |
2935 | //===--------------------------------------------------------------------===// |
2936 | |
// Defaulted out of line, presumably so the destructor is emitted in this
// translation unit where the printer implementation is complete — TODO
// confirm against the header.
AsmPrinter::~AsmPrinter() = default;
2938 | |
2939 | raw_ostream &AsmPrinter::getStream() const { |
2940 | assert(impl && "expected AsmPrinter::getStream to be overriden" ); |
2941 | return impl->getStream(); |
2942 | } |
2943 | |
2944 | /// Print the given floating point value in a stablized form. |
2945 | void AsmPrinter::printFloat(const APFloat &value) { |
2946 | assert(impl && "expected AsmPrinter::printFloat to be overriden" ); |
2947 | printFloatValue(apValue: value, os&: impl->getStream()); |
2948 | } |
2949 | |
2950 | void AsmPrinter::printType(Type type) { |
2951 | assert(impl && "expected AsmPrinter::printType to be overriden" ); |
2952 | impl->printType(type); |
2953 | } |
2954 | |
2955 | void AsmPrinter::printAttribute(Attribute attr) { |
2956 | assert(impl && "expected AsmPrinter::printAttribute to be overriden" ); |
2957 | impl->printAttribute(attr); |
2958 | } |
2959 | |
2960 | LogicalResult AsmPrinter::printAlias(Attribute attr) { |
2961 | assert(impl && "expected AsmPrinter::printAlias to be overriden" ); |
2962 | return impl->printAlias(attr); |
2963 | } |
2964 | |
2965 | LogicalResult AsmPrinter::printAlias(Type type) { |
2966 | assert(impl && "expected AsmPrinter::printAlias to be overriden" ); |
2967 | return impl->printAlias(type); |
2968 | } |
2969 | |
2970 | void AsmPrinter::printAttributeWithoutType(Attribute attr) { |
2971 | assert(impl && |
2972 | "expected AsmPrinter::printAttributeWithoutType to be overriden" ); |
2973 | impl->printAttribute(attr, typeElision: Impl::AttrTypeElision::Must); |
2974 | } |
2975 | |
2976 | void AsmPrinter::printKeywordOrString(StringRef keyword) { |
2977 | assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden" ); |
2978 | ::printKeywordOrString(keyword, os&: impl->getStream()); |
2979 | } |
2980 | |
2981 | void AsmPrinter::printString(StringRef keyword) { |
2982 | assert(impl && "expected AsmPrinter::printString to be overriden" ); |
2983 | *this << '"'; |
2984 | printEscapedString(Name: keyword, Out&: getStream()); |
2985 | *this << '"'; |
2986 | } |
2987 | |
2988 | void AsmPrinter::printSymbolName(StringRef symbolRef) { |
2989 | assert(impl && "expected AsmPrinter::printSymbolName to be overriden" ); |
2990 | ::printSymbolReference(symbolRef, os&: impl->getStream()); |
2991 | } |
2992 | |
2993 | void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { |
2994 | assert(impl && "expected AsmPrinter::printResourceHandle to be overriden" ); |
2995 | impl->printResourceHandle(resource); |
2996 | } |
2997 | |
2998 | void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) { |
2999 | detail::printDimensionList(stream&: getStream(), shape); |
3000 | } |
3001 | |
3002 | LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { |
3003 | return impl->pushCyclicPrinting(opaquePointer); |
3004 | } |
3005 | |
3006 | void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } |
3007 | |
3008 | //===----------------------------------------------------------------------===// |
3009 | // Affine expressions and maps |
3010 | //===----------------------------------------------------------------------===// |
3011 | |
3012 | void AsmPrinter::Impl::printAffineExpr( |
3013 | AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { |
3014 | printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak, printValueName); |
3015 | } |
3016 | |
3017 | void AsmPrinter::Impl::printAffineExprInternal( |
3018 | AffineExpr expr, BindingStrength enclosingTightness, |
3019 | function_ref<void(unsigned, bool)> printValueName) { |
3020 | const char *binopSpelling = nullptr; |
3021 | switch (expr.getKind()) { |
3022 | case AffineExprKind::SymbolId: { |
3023 | unsigned pos = cast<AffineSymbolExpr>(Val&: expr).getPosition(); |
3024 | if (printValueName) |
3025 | printValueName(pos, /*isSymbol=*/true); |
3026 | else |
3027 | os << 's' << pos; |
3028 | return; |
3029 | } |
3030 | case AffineExprKind::DimId: { |
3031 | unsigned pos = cast<AffineDimExpr>(Val&: expr).getPosition(); |
3032 | if (printValueName) |
3033 | printValueName(pos, /*isSymbol=*/false); |
3034 | else |
3035 | os << 'd' << pos; |
3036 | return; |
3037 | } |
3038 | case AffineExprKind::Constant: |
3039 | os << cast<AffineConstantExpr>(Val&: expr).getValue(); |
3040 | return; |
3041 | case AffineExprKind::Add: |
3042 | binopSpelling = " + " ; |
3043 | break; |
3044 | case AffineExprKind::Mul: |
3045 | binopSpelling = " * " ; |
3046 | break; |
3047 | case AffineExprKind::FloorDiv: |
3048 | binopSpelling = " floordiv " ; |
3049 | break; |
3050 | case AffineExprKind::CeilDiv: |
3051 | binopSpelling = " ceildiv " ; |
3052 | break; |
3053 | case AffineExprKind::Mod: |
3054 | binopSpelling = " mod " ; |
3055 | break; |
3056 | } |
3057 | |
3058 | auto binOp = cast<AffineBinaryOpExpr>(Val&: expr); |
3059 | AffineExpr lhsExpr = binOp.getLHS(); |
3060 | AffineExpr rhsExpr = binOp.getRHS(); |
3061 | |
3062 | // Handle tightly binding binary operators. |
3063 | if (binOp.getKind() != AffineExprKind::Add) { |
3064 | if (enclosingTightness == BindingStrength::Strong) |
3065 | os << '('; |
3066 | |
3067 | // Pretty print multiplication with -1. |
3068 | auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr); |
3069 | if (rhsConst && binOp.getKind() == AffineExprKind::Mul && |
3070 | rhsConst.getValue() == -1) { |
3071 | os << "-" ; |
3072 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
3073 | if (enclosingTightness == BindingStrength::Strong) |
3074 | os << ')'; |
3075 | return; |
3076 | } |
3077 | |
3078 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
3079 | |
3080 | os << binopSpelling; |
3081 | printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Strong, printValueName); |
3082 | |
3083 | if (enclosingTightness == BindingStrength::Strong) |
3084 | os << ')'; |
3085 | return; |
3086 | } |
3087 | |
3088 | // Print out special "pretty" forms for add. |
3089 | if (enclosingTightness == BindingStrength::Strong) |
3090 | os << '('; |
3091 | |
3092 | // Pretty print addition to a product that has a negative operand as a |
3093 | // subtraction. |
3094 | if (auto rhs = dyn_cast<AffineBinaryOpExpr>(Val&: rhsExpr)) { |
3095 | if (rhs.getKind() == AffineExprKind::Mul) { |
3096 | AffineExpr rrhsExpr = rhs.getRHS(); |
3097 | if (auto rrhs = dyn_cast<AffineConstantExpr>(Val&: rrhsExpr)) { |
3098 | if (rrhs.getValue() == -1) { |
3099 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, |
3100 | printValueName); |
3101 | os << " - " ; |
3102 | if (rhs.getLHS().getKind() == AffineExprKind::Add) { |
3103 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong, |
3104 | printValueName); |
3105 | } else { |
3106 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Weak, |
3107 | printValueName); |
3108 | } |
3109 | |
3110 | if (enclosingTightness == BindingStrength::Strong) |
3111 | os << ')'; |
3112 | return; |
3113 | } |
3114 | |
3115 | if (rrhs.getValue() < -1) { |
3116 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, |
3117 | printValueName); |
3118 | os << " - " ; |
3119 | printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong, |
3120 | printValueName); |
3121 | os << " * " << -rrhs.getValue(); |
3122 | if (enclosingTightness == BindingStrength::Strong) |
3123 | os << ')'; |
3124 | return; |
3125 | } |
3126 | } |
3127 | } |
3128 | } |
3129 | |
3130 | // Pretty print addition to a negative number as a subtraction. |
3131 | if (auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr)) { |
3132 | if (rhsConst.getValue() < 0) { |
3133 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
3134 | os << " - " << -rhsConst.getValue(); |
3135 | if (enclosingTightness == BindingStrength::Strong) |
3136 | os << ')'; |
3137 | return; |
3138 | } |
3139 | } |
3140 | |
3141 | printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
3142 | |
3143 | os << " + " ; |
3144 | printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Weak, printValueName); |
3145 | |
3146 | if (enclosingTightness == BindingStrength::Strong) |
3147 | os << ')'; |
3148 | } |
3149 | |
3150 | void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { |
3151 | printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak); |
3152 | isEq ? os << " == 0" : os << " >= 0" ; |
3153 | } |
3154 | |
3155 | void AsmPrinter::Impl::printAffineMap(AffineMap map) { |
3156 | // Dimension identifiers. |
3157 | os << '('; |
3158 | for (int i = 0; i < (int)map.getNumDims() - 1; ++i) |
3159 | os << 'd' << i << ", " ; |
3160 | if (map.getNumDims() >= 1) |
3161 | os << 'd' << map.getNumDims() - 1; |
3162 | os << ')'; |
3163 | |
3164 | // Symbolic identifiers. |
3165 | if (map.getNumSymbols() != 0) { |
3166 | os << '['; |
3167 | for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) |
3168 | os << 's' << i << ", " ; |
3169 | if (map.getNumSymbols() >= 1) |
3170 | os << 's' << map.getNumSymbols() - 1; |
3171 | os << ']'; |
3172 | } |
3173 | |
3174 | // Result affine expressions. |
3175 | os << " -> (" ; |
3176 | interleaveComma(c: map.getResults(), |
3177 | eachFn: [&](AffineExpr expr) { printAffineExpr(expr); }); |
3178 | os << ')'; |
3179 | } |
3180 | |
3181 | void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { |
3182 | // Dimension identifiers. |
3183 | os << '('; |
3184 | for (unsigned i = 1; i < set.getNumDims(); ++i) |
3185 | os << 'd' << i - 1 << ", " ; |
3186 | if (set.getNumDims() >= 1) |
3187 | os << 'd' << set.getNumDims() - 1; |
3188 | os << ')'; |
3189 | |
3190 | // Symbolic identifiers. |
3191 | if (set.getNumSymbols() != 0) { |
3192 | os << '['; |
3193 | for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) |
3194 | os << 's' << i << ", " ; |
3195 | if (set.getNumSymbols() >= 1) |
3196 | os << 's' << set.getNumSymbols() - 1; |
3197 | os << ']'; |
3198 | } |
3199 | |
3200 | // Print constraints. |
3201 | os << " : (" ; |
3202 | int numConstraints = set.getNumConstraints(); |
3203 | for (int i = 1; i < numConstraints; ++i) { |
3204 | printAffineConstraint(expr: set.getConstraint(idx: i - 1), isEq: set.isEq(idx: i - 1)); |
3205 | os << ", " ; |
3206 | } |
3207 | if (numConstraints >= 1) |
3208 | printAffineConstraint(expr: set.getConstraint(idx: numConstraints - 1), |
3209 | isEq: set.isEq(idx: numConstraints - 1)); |
3210 | os << ')'; |
3211 | } |
3212 | |
3213 | //===----------------------------------------------------------------------===// |
3214 | // OperationPrinter |
3215 | //===----------------------------------------------------------------------===// |
3216 | |
3217 | namespace { |
3218 | /// This class contains the logic for printing operations, regions, and blocks. |
3219 | class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { |
3220 | public: |
3221 | using Impl = AsmPrinter::Impl; |
3222 | using Impl::printType; |
3223 | |
3224 | explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) |
3225 | : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {} |
3226 | |
3227 | /// Print the given top-level operation. |
3228 | void printTopLevelOperation(Operation *op); |
3229 | |
3230 | /// Print the given operation, including its left-hand side and its right-hand |
3231 | /// side, with its indent and location. |
3232 | void printFullOpWithIndentAndLoc(Operation *op); |
3233 | /// Print the given operation, including its left-hand side and its right-hand |
3234 | /// side, but not including indentation and location. |
3235 | void printFullOp(Operation *op); |
3236 | /// Print the right-hand size of the given operation in the custom or generic |
3237 | /// form. |
3238 | void printCustomOrGenericOp(Operation *op) override; |
3239 | /// Print the right-hand side of the given operation in the generic form. |
3240 | void printGenericOp(Operation *op, bool printOpName) override; |
3241 | |
3242 | /// Print the name of the given block. |
3243 | void printBlockName(Block *block); |
3244 | |
3245 | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
3246 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
3247 | /// operation of the block is not printed. |
3248 | void print(Block *block, bool printBlockArgs = true, |
3249 | bool printBlockTerminator = true); |
3250 | |
3251 | /// Print the ID of the given value, optionally with its result number. |
3252 | void printValueID(Value value, bool printResultNo = true, |
3253 | raw_ostream *streamOverride = nullptr) const; |
3254 | |
3255 | /// Print the ID of the given operation. |
3256 | void printOperationID(Operation *op, |
3257 | raw_ostream *streamOverride = nullptr) const; |
3258 | |
3259 | //===--------------------------------------------------------------------===// |
3260 | // OpAsmPrinter methods |
3261 | //===--------------------------------------------------------------------===// |
3262 | |
3263 | /// Print a loc(...) specifier if printing debug info is enabled. Locations |
3264 | /// may be deferred with an alias. |
3265 | void printOptionalLocationSpecifier(Location loc) override { |
3266 | printTrailingLocation(loc); |
3267 | } |
3268 | |
3269 | /// Print a newline and indent the printer to the start of the current |
3270 | /// operation. |
3271 | void printNewline() override { |
3272 | os << newLine; |
3273 | os.indent(NumSpaces: currentIndent); |
3274 | } |
3275 | |
3276 | /// Increase indentation. |
3277 | void increaseIndent() override { currentIndent += indentWidth; } |
3278 | |
3279 | /// Decrease indentation. |
3280 | void decreaseIndent() override { currentIndent -= indentWidth; } |
3281 | |
3282 | /// Print a block argument in the usual format of: |
3283 | /// %ssaName : type {attr1=42} loc("here") |
3284 | /// where location printing is controlled by the standard internal option. |
3285 | /// You may pass omitType=true to not print a type, and pass an empty |
3286 | /// attribute list if you don't care for attributes. |
3287 | void printRegionArgument(BlockArgument arg, |
3288 | ArrayRef<NamedAttribute> argAttrs = {}, |
3289 | bool omitType = false) override; |
3290 | |
3291 | /// Print the ID for the given value. |
3292 | void printOperand(Value value) override { printValueID(value); } |
3293 | void printOperand(Value value, raw_ostream &os) override { |
3294 | printValueID(value, /*printResultNo=*/true, streamOverride: &os); |
3295 | } |
3296 | |
3297 | /// Print an optional attribute dictionary with a given set of elided values. |
3298 | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
3299 | ArrayRef<StringRef> elidedAttrs = {}) override { |
3300 | Impl::printOptionalAttrDict(attrs, elidedAttrs); |
3301 | } |
3302 | void printOptionalAttrDictWithKeyword( |
3303 | ArrayRef<NamedAttribute> attrs, |
3304 | ArrayRef<StringRef> elidedAttrs = {}) override { |
3305 | Impl::printOptionalAttrDict(attrs, elidedAttrs, |
3306 | /*withKeyword=*/true); |
3307 | } |
3308 | |
3309 | /// Print the given successor. |
3310 | void printSuccessor(Block *successor) override; |
3311 | |
3312 | /// Print an operation successor with the operands used for the block |
3313 | /// arguments. |
3314 | void printSuccessorAndUseList(Block *successor, |
3315 | ValueRange succOperands) override; |
3316 | |
3317 | /// Print the given region. |
3318 | void printRegion(Region ®ion, bool printEntryBlockArgs, |
3319 | bool printBlockTerminators, bool printEmptyBlock) override; |
3320 | |
3321 | /// Renumber the arguments for the specified region to the same names as the |
3322 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
3323 | /// operations. If any entry in namesToUse is null, the corresponding |
3324 | /// argument name is left alone. |
3325 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { |
3326 | state.getSSANameState().shadowRegionArgs(region, namesToUse); |
3327 | } |
3328 | |
3329 | /// Print the given affine map with the symbol and dimension operands printed |
3330 | /// inline with the map. |
3331 | void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
3332 | ValueRange operands) override; |
3333 | |
3334 | /// Print the given affine expression with the symbol and dimension operands |
3335 | /// printed inline with the expression. |
3336 | void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
3337 | ValueRange symOperands) override; |
3338 | |
3339 | /// Print users of this operation or id of this operation if it has no result. |
3340 | void printUsersComment(Operation *op); |
3341 | |
3342 | /// Print users of this block arg. |
3343 | void printUsersComment(BlockArgument arg); |
3344 | |
3345 | /// Print the users of a value. |
3346 | void printValueUsers(Value value); |
3347 | |
3348 | /// Print either the ids of the result values or the id of the operation if |
3349 | /// the operation has no results. |
3350 | void printUserIDs(Operation *user, bool prefixComma = false); |
3351 | |
3352 | private: |
3353 | /// This class represents a resource builder implementation for the MLIR |
3354 | /// textual assembly format. |
3355 | class ResourceBuilder : public AsmResourceBuilder { |
3356 | public: |
3357 | using ValueFn = function_ref<void(raw_ostream &)>; |
3358 | using PrintFn = function_ref<void(StringRef, ValueFn)>; |
3359 | |
3360 | ResourceBuilder(PrintFn printFn) : printFn(printFn) {} |
3361 | ~ResourceBuilder() override = default; |
3362 | |
3363 | void buildBool(StringRef key, bool data) final { |
3364 | printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false" ); }); |
3365 | } |
3366 | |
3367 | void buildString(StringRef key, StringRef data) final { |
3368 | printFn(key, [&](raw_ostream &os) { |
3369 | os << "\"" ; |
3370 | llvm::printEscapedString(Name: data, Out&: os); |
3371 | os << "\"" ; |
3372 | }); |
3373 | } |
3374 | |
3375 | void buildBlob(StringRef key, ArrayRef<char> data, |
3376 | uint32_t dataAlignment) final { |
3377 | printFn(key, [&](raw_ostream &os) { |
3378 | // Store the blob in a hex string containing the alignment and the data. |
3379 | llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); |
3380 | os << "\"0x" |
3381 | << llvm::toHex(Input: StringRef(reinterpret_cast<char *>(&dataAlignmentLE), |
3382 | sizeof(dataAlignment))) |
3383 | << llvm::toHex(Input: StringRef(data.data(), data.size())) << "\"" ; |
3384 | }); |
3385 | } |
3386 | |
3387 | private: |
3388 | PrintFn printFn; |
3389 | }; |
3390 | |
3391 | /// Print the metadata dictionary for the file, eliding it if it is empty. |
3392 | void printFileMetadataDictionary(Operation *op); |
3393 | |
3394 | /// Print the resource sections for the file metadata dictionary. |
3395 | /// `checkAddMetadataDict` is used to indicate that metadata is going to be |
3396 | /// added, and the file metadata dictionary should be started if it hasn't |
3397 | /// yet. |
3398 | void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict, |
3399 | Operation *op); |
3400 | |
3401 | // Contains the stack of default dialects to use when printing regions. |
3402 | // A new dialect is pushed to the stack before parsing regions nested under an |
3403 | // operation implementing `OpAsmOpInterface`, and popped when done. At the |
3404 | // top-level we start with "builtin" as the default, so that the top-level |
3405 | // `module` operation prints as-is. |
3406 | SmallVector<StringRef> defaultDialectStack{"builtin" }; |
3407 | |
3408 | /// The number of spaces used for indenting nested operations. |
3409 | const static unsigned indentWidth = 2; |
3410 | |
3411 | // This is the current indentation level for nested structures. |
3412 | unsigned currentIndent = 0; |
3413 | }; |
3414 | } // namespace |
3415 | |
3416 | void OperationPrinter::printTopLevelOperation(Operation *op) { |
3417 | // Output the aliases at the top level that can't be deferred. |
3418 | state.getAliasState().printNonDeferredAliases(p&: *this, newLine); |
3419 | |
3420 | // Print the module. |
3421 | printFullOpWithIndentAndLoc(op); |
3422 | os << newLine; |
3423 | |
3424 | // Output the aliases at the top level that can be deferred. |
3425 | state.getAliasState().printDeferredAliases(p&: *this, newLine); |
3426 | |
3427 | // Output any file level metadata. |
3428 | printFileMetadataDictionary(op); |
3429 | } |
3430 | |
3431 | void OperationPrinter::printFileMetadataDictionary(Operation *op) { |
3432 | bool sawMetadataEntry = false; |
3433 | auto checkAddMetadataDict = [&] { |
3434 | if (!std::exchange(obj&: sawMetadataEntry, new_val: true)) |
3435 | os << newLine << "{-#" << newLine; |
3436 | }; |
3437 | |
3438 | // Add the various types of metadata. |
3439 | printResourceFileMetadata(checkAddMetadataDict, op); |
3440 | |
3441 | // If the file dictionary exists, close it. |
3442 | if (sawMetadataEntry) |
3443 | os << newLine << "#-}" << newLine; |
3444 | } |
3445 | |
3446 | void OperationPrinter::printResourceFileMetadata( |
3447 | function_ref<void()> checkAddMetadataDict, Operation *op) { |
3448 | // Functor used to add data entries to the file metadata dictionary. |
3449 | bool hadResource = false; |
3450 | bool needResourceComma = false; |
3451 | bool needEntryComma = false; |
3452 | auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, |
3453 | auto &&...providerArgs) { |
3454 | bool hadEntry = false; |
3455 | auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { |
3456 | checkAddMetadataDict(); |
3457 | |
3458 | std::string resourceStr; |
3459 | auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; }; |
3460 | std::optional<uint64_t> charLimit = |
3461 | printerFlags.getLargeResourceStringLimit(); |
3462 | if (charLimit.has_value()) { |
3463 | // Don't compute resourceStr when charLimit is 0. |
3464 | if (charLimit.value() == 0) |
3465 | return; |
3466 | |
3467 | llvm::raw_string_ostream ss(resourceStr); |
3468 | valueFn(ss); |
3469 | |
3470 | // Only print entry if its string is small enough. |
3471 | if (resourceStr.size() > charLimit.value()) |
3472 | return; |
3473 | |
3474 | // Don't recompute resourceStr when valueFn is called below. |
3475 | valueFn = printResourceStr; |
3476 | } |
3477 | |
3478 | // Emit the top-level resource entry if we haven't yet. |
3479 | if (!std::exchange(obj&: hadResource, new_val: true)) { |
3480 | if (needResourceComma) |
3481 | os << "," << newLine; |
3482 | os << " " << dictName << "_resources: {" << newLine; |
3483 | } |
3484 | // Emit the parent resource entry if we haven't yet. |
3485 | if (!std::exchange(obj&: hadEntry, new_val: true)) { |
3486 | if (needEntryComma) |
3487 | os << "," << newLine; |
3488 | os << " " << name << ": {" << newLine; |
3489 | } else { |
3490 | os << "," << newLine; |
3491 | } |
3492 | os << " " ; |
3493 | ::printKeywordOrString(keyword: key, os); |
3494 | os << ": " ; |
3495 | // Call printResourceStr or original valueFn, depending on charLimit. |
3496 | valueFn(os); |
3497 | }; |
3498 | ResourceBuilder entryBuilder(printFn); |
3499 | provider.buildResources(op, providerArgs..., entryBuilder); |
3500 | |
3501 | needEntryComma |= hadEntry; |
3502 | if (hadEntry) |
3503 | os << newLine << " }" ; |
3504 | }; |
3505 | |
3506 | // Print the `dialect_resources` section if we have any dialects with |
3507 | // resources. |
3508 | for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { |
3509 | auto &dialectResources = state.getDialectResources(); |
3510 | StringRef name = interface.getDialect()->getNamespace(); |
3511 | auto it = dialectResources.find(Val: interface.getDialect()); |
3512 | if (it != dialectResources.end()) |
3513 | processProvider("dialect" , name, interface, it->second); |
3514 | else |
3515 | processProvider("dialect" , name, interface, |
3516 | SetVector<AsmDialectResourceHandle>()); |
3517 | } |
3518 | if (hadResource) |
3519 | os << newLine << " }" ; |
3520 | |
3521 | // Print the `external_resources` section if we have any external clients with |
3522 | // resources. |
3523 | needEntryComma = false; |
3524 | needResourceComma = hadResource; |
3525 | hadResource = false; |
3526 | for (const auto &printer : state.getResourcePrinters()) |
3527 | processProvider("external" , printer.getName(), printer); |
3528 | if (hadResource) |
3529 | os << newLine << " }" ; |
3530 | } |
3531 | |
3532 | /// Print a block argument in the usual format of: |
3533 | /// %ssaName : type {attr1=42} loc("here") |
3534 | /// where location printing is controlled by the standard internal option. |
3535 | /// You may pass omitType=true to not print a type, and pass an empty |
3536 | /// attribute list if you don't care for attributes. |
3537 | void OperationPrinter::printRegionArgument(BlockArgument arg, |
3538 | ArrayRef<NamedAttribute> argAttrs, |
3539 | bool omitType) { |
3540 | printOperand(value: arg); |
3541 | if (!omitType) { |
3542 | os << ": " ; |
3543 | printType(type: arg.getType()); |
3544 | } |
3545 | printOptionalAttrDict(attrs: argAttrs); |
3546 | // TODO: We should allow location aliases on block arguments. |
3547 | printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false); |
3548 | } |
3549 | |
3550 | void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { |
3551 | // Track the location of this operation. |
3552 | state.registerOperationLocation(op, line: newLine.curLine, col: currentIndent); |
3553 | |
3554 | os.indent(NumSpaces: currentIndent); |
3555 | printFullOp(op); |
3556 | printTrailingLocation(loc: op->getLoc()); |
3557 | if (printerFlags.shouldPrintValueUsers()) |
3558 | printUsersComment(op); |
3559 | } |
3560 | |
3561 | void OperationPrinter::printFullOp(Operation *op) { |
3562 | if (size_t numResults = op->getNumResults()) { |
3563 | auto printResultGroup = [&](size_t resultNo, size_t resultCount) { |
3564 | printValueID(value: op->getResult(idx: resultNo), /*printResultNo=*/false); |
3565 | if (resultCount > 1) |
3566 | os << ':' << resultCount; |
3567 | }; |
3568 | |
3569 | // Check to see if this operation has multiple result groups. |
3570 | ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op); |
3571 | if (!resultGroups.empty()) { |
3572 | // Interleave the groups excluding the last one, this one will be handled |
3573 | // separately. |
3574 | interleaveComma(c: llvm::seq<int>(Begin: 0, End: resultGroups.size() - 1), eachFn: [&](int i) { |
3575 | printResultGroup(resultGroups[i], |
3576 | resultGroups[i + 1] - resultGroups[i]); |
3577 | }); |
3578 | os << ", " ; |
3579 | printResultGroup(resultGroups.back(), numResults - resultGroups.back()); |
3580 | |
3581 | } else { |
3582 | printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); |
3583 | } |
3584 | |
3585 | os << " = " ; |
3586 | } |
3587 | |
3588 | printCustomOrGenericOp(op); |
3589 | } |
3590 | |
3591 | void OperationPrinter::(Operation *op) { |
3592 | unsigned numResults = op->getNumResults(); |
3593 | if (!numResults && op->getNumOperands()) { |
3594 | os << " // id: " ; |
3595 | printOperationID(op); |
3596 | } else if (numResults && op->use_empty()) { |
3597 | os << " // unused" ; |
3598 | } else if (numResults && !op->use_empty()) { |
3599 | // Print "user" if the operation has one result used to compute one other |
3600 | // result, or is used in one operation with no result. |
3601 | unsigned usedInNResults = 0; |
3602 | unsigned usedInNOperations = 0; |
3603 | SmallPtrSet<Operation *, 1> userSet; |
3604 | for (Operation *user : op->getUsers()) { |
3605 | if (userSet.insert(Ptr: user).second) { |
3606 | ++usedInNOperations; |
3607 | usedInNResults += user->getNumResults(); |
3608 | } |
3609 | } |
3610 | |
3611 | // We already know that users is not empty. |
3612 | bool exactlyOneUniqueUse = |
3613 | usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; |
3614 | os << " // " << (exactlyOneUniqueUse ? "user" : "users" ) << ": " ; |
3615 | bool shouldPrintBrackets = numResults > 1; |
3616 | auto printOpResult = [&](OpResult opResult) { |
3617 | if (shouldPrintBrackets) |
3618 | os << "(" ; |
3619 | printValueUsers(value: opResult); |
3620 | if (shouldPrintBrackets) |
3621 | os << ")" ; |
3622 | }; |
3623 | |
3624 | interleaveComma(c: op->getResults(), eachFn: printOpResult); |
3625 | } |
3626 | } |
3627 | |
3628 | void OperationPrinter::(BlockArgument arg) { |
3629 | os << "// " ; |
3630 | printValueID(value: arg); |
3631 | if (arg.use_empty()) { |
3632 | os << " is unused" ; |
3633 | } else { |
3634 | os << " is used by " ; |
3635 | printValueUsers(value: arg); |
3636 | } |
3637 | os << newLine; |
3638 | } |
3639 | |
3640 | void OperationPrinter::printValueUsers(Value value) { |
3641 | if (value.use_empty()) |
3642 | os << "unused" ; |
3643 | |
3644 | // One value might be used as the operand of an operation more than once. |
3645 | // Only print the operations results once in that case. |
3646 | SmallPtrSet<Operation *, 1> userSet; |
3647 | for (auto [index, user] : enumerate(First: value.getUsers())) { |
3648 | if (userSet.insert(Ptr: user).second) |
3649 | printUserIDs(user, prefixComma: index); |
3650 | } |
3651 | } |
3652 | |
3653 | void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { |
3654 | if (prefixComma) |
3655 | os << ", " ; |
3656 | |
3657 | if (!user->getNumResults()) { |
3658 | printOperationID(op: user); |
3659 | } else { |
3660 | interleaveComma(c: user->getResults(), |
3661 | eachFn: [this](Value result) { printValueID(value: result); }); |
3662 | } |
3663 | } |
3664 | |
3665 | void OperationPrinter::printCustomOrGenericOp(Operation *op) { |
3666 | // If requested, always print the generic form. |
3667 | if (!printerFlags.shouldPrintGenericOpForm()) { |
3668 | // Check to see if this is a known operation. If so, use the registered |
3669 | // custom printer hook. |
3670 | if (auto opInfo = op->getRegisteredInfo()) { |
3671 | opInfo->printAssembly(op, p&: *this, defaultDialect: defaultDialectStack.back()); |
3672 | return; |
3673 | } |
3674 | // Otherwise try to dispatch to the dialect, if available. |
3675 | if (Dialect *dialect = op->getDialect()) { |
3676 | if (auto opPrinter = dialect->getOperationPrinter(op)) { |
3677 | // Print the op name first. |
3678 | StringRef name = op->getName().getStringRef(); |
3679 | // Only drop the default dialect prefix when it cannot lead to |
3680 | // ambiguities. |
3681 | if (name.count(C: '.') == 1) |
3682 | name.consume_front(Prefix: (defaultDialectStack.back() + "." ).str()); |
3683 | os << name; |
3684 | |
3685 | // Print the rest of the op now. |
3686 | opPrinter(op, *this); |
3687 | return; |
3688 | } |
3689 | } |
3690 | } |
3691 | |
3692 | // Otherwise print with the generic assembly form. |
3693 | printGenericOp(op, /*printOpName=*/true); |
3694 | } |
3695 | |
3696 | void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { |
3697 | if (printOpName) |
3698 | printEscapedString(str: op->getName().getStringRef()); |
3699 | os << '('; |
3700 | interleaveComma(c: op->getOperands(), eachFn: [&](Value value) { printValueID(value); }); |
3701 | os << ')'; |
3702 | |
3703 | // For terminators, print the list of successors and their operands. |
3704 | if (op->getNumSuccessors() != 0) { |
3705 | os << '['; |
3706 | interleaveComma(c: op->getSuccessors(), |
3707 | eachFn: [&](Block *successor) { printBlockName(block: successor); }); |
3708 | os << ']'; |
3709 | } |
3710 | |
3711 | // Print the properties. |
3712 | if (Attribute prop = op->getPropertiesAsAttribute()) { |
3713 | os << " <" ; |
3714 | Impl::printAttribute(attr: prop); |
3715 | os << '>'; |
3716 | } |
3717 | |
3718 | // Print regions. |
3719 | if (op->getNumRegions() != 0) { |
3720 | os << " (" ; |
3721 | interleaveComma(c: op->getRegions(), eachFn: [&](Region ®ion) { |
3722 | printRegion(region, /*printEntryBlockArgs=*/true, |
3723 | /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); |
3724 | }); |
3725 | os << ')'; |
3726 | } |
3727 | |
3728 | printOptionalAttrDict(attrs: op->getPropertiesStorage() |
3729 | ? llvm::to_vector(op->getDiscardableAttrs()) |
3730 | : op->getAttrs()); |
3731 | |
3732 | // Print the type signature of the operation. |
3733 | os << " : " ; |
3734 | printFunctionalType(op); |
3735 | } |
3736 | |
3737 | void OperationPrinter::printBlockName(Block *block) { |
3738 | os << state.getSSANameState().getBlockInfo(block).name; |
3739 | } |
3740 | |
3741 | void OperationPrinter::print(Block *block, bool printBlockArgs, |
3742 | bool printBlockTerminator) { |
3743 | // Print the block label and argument list if requested. |
3744 | if (printBlockArgs) { |
3745 | os.indent(NumSpaces: currentIndent); |
3746 | printBlockName(block); |
3747 | |
3748 | // Print the argument list if non-empty. |
3749 | if (!block->args_empty()) { |
3750 | os << '('; |
3751 | interleaveComma(c: block->getArguments(), eachFn: [&](BlockArgument arg) { |
3752 | printValueID(value: arg); |
3753 | os << ": " ; |
3754 | printType(type: arg.getType()); |
3755 | // TODO: We should allow location aliases on block arguments. |
3756 | printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false); |
3757 | }); |
3758 | os << ')'; |
3759 | } |
3760 | os << ':'; |
3761 | |
3762 | // Print out some context information about the predecessors of this block. |
3763 | if (!block->getParent()) { |
3764 | os << " // block is not in a region!" ; |
3765 | } else if (block->hasNoPredecessors()) { |
3766 | if (!block->isEntryBlock()) |
3767 | os << " // no predecessors" ; |
3768 | } else if (auto *pred = block->getSinglePredecessor()) { |
3769 | os << " // pred: " ; |
3770 | printBlockName(block: pred); |
3771 | } else { |
3772 | // We want to print the predecessors in a stable order, not in |
3773 | // whatever order the use-list is in, so gather and sort them. |
3774 | SmallVector<BlockInfo, 4> predIDs; |
3775 | for (auto *pred : block->getPredecessors()) |
3776 | predIDs.push_back(Elt: state.getSSANameState().getBlockInfo(block: pred)); |
3777 | llvm::sort(C&: predIDs, Comp: [](BlockInfo lhs, BlockInfo rhs) { |
3778 | return lhs.ordering < rhs.ordering; |
3779 | }); |
3780 | |
3781 | os << " // " << predIDs.size() << " preds: " ; |
3782 | |
3783 | interleaveComma(c: predIDs, eachFn: [&](BlockInfo pred) { os << pred.name; }); |
3784 | } |
3785 | os << newLine; |
3786 | } |
3787 | |
3788 | currentIndent += indentWidth; |
3789 | |
3790 | if (printerFlags.shouldPrintValueUsers()) { |
3791 | for (BlockArgument arg : block->getArguments()) { |
3792 | os.indent(NumSpaces: currentIndent); |
3793 | printUsersComment(arg); |
3794 | } |
3795 | } |
3796 | |
3797 | bool hasTerminator = |
3798 | !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
3799 | auto range = llvm::make_range( |
3800 | x: block->begin(), |
3801 | y: std::prev(x: block->end(), |
3802 | n: (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
3803 | for (auto &op : range) { |
3804 | printFullOpWithIndentAndLoc(op: &op); |
3805 | os << newLine; |
3806 | } |
3807 | currentIndent -= indentWidth; |
3808 | } |
3809 | |
3810 | void OperationPrinter::printValueID(Value value, bool printResultNo, |
3811 | raw_ostream *streamOverride) const { |
3812 | state.getSSANameState().printValueID(value, printResultNo, |
3813 | stream&: streamOverride ? *streamOverride : os); |
3814 | } |
3815 | |
3816 | void OperationPrinter::printOperationID(Operation *op, |
3817 | raw_ostream *streamOverride) const { |
3818 | state.getSSANameState().printOperationID(op, stream&: streamOverride ? *streamOverride |
3819 | : os); |
3820 | } |
3821 | |
3822 | void OperationPrinter::printSuccessor(Block *successor) { |
3823 | printBlockName(block: successor); |
3824 | } |
3825 | |
3826 | void OperationPrinter::printSuccessorAndUseList(Block *successor, |
3827 | ValueRange succOperands) { |
3828 | printBlockName(block: successor); |
3829 | if (succOperands.empty()) |
3830 | return; |
3831 | |
3832 | os << '('; |
3833 | interleaveComma(c: succOperands, |
3834 | eachFn: [this](Value operand) { printValueID(value: operand); }); |
3835 | os << " : " ; |
3836 | interleaveComma(c: succOperands, |
3837 | eachFn: [this](Value operand) { printType(type: operand.getType()); }); |
3838 | os << ')'; |
3839 | } |
3840 | |
3841 | void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, |
3842 | bool printBlockTerminators, |
3843 | bool printEmptyBlock) { |
3844 | if (printerFlags.shouldSkipRegions()) { |
3845 | os << "{...}" ; |
3846 | return; |
3847 | } |
3848 | os << "{" << newLine; |
3849 | if (!region.empty()) { |
3850 | auto restoreDefaultDialect = |
3851 | llvm::make_scope_exit(F: [&]() { defaultDialectStack.pop_back(); }); |
3852 | if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) |
3853 | defaultDialectStack.push_back(Elt: iface.getDefaultDialect()); |
3854 | else |
3855 | defaultDialectStack.push_back(Elt: "" ); |
3856 | |
3857 | auto *entryBlock = ®ion.front(); |
3858 | // Force printing the block header if printEmptyBlock is set and the block |
3859 | // is empty or if printEntryBlockArgs is set and there are arguments to |
3860 | // print. |
3861 | bool = |
3862 | (printEmptyBlock && entryBlock->empty()) || |
3863 | (printEntryBlockArgs && entryBlock->getNumArguments() != 0); |
3864 | print(block: entryBlock, printBlockArgs: shouldAlwaysPrintBlockHeader, printBlockTerminator: printBlockTerminators); |
3865 | for (auto &b : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1)) |
3866 | print(block: &b); |
3867 | } |
3868 | os.indent(NumSpaces: currentIndent) << "}" ; |
3869 | } |
3870 | |
3871 | void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
3872 | ValueRange operands) { |
3873 | if (!mapAttr) { |
3874 | os << "<<NULL AFFINE MAP>>" ; |
3875 | return; |
3876 | } |
3877 | AffineMap map = mapAttr.getValue(); |
3878 | unsigned numDims = map.getNumDims(); |
3879 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
3880 | unsigned index = isSymbol ? numDims + pos : pos; |
3881 | assert(index < operands.size()); |
3882 | if (isSymbol) |
3883 | os << "symbol(" ; |
3884 | printValueID(value: operands[index]); |
3885 | if (isSymbol) |
3886 | os << ')'; |
3887 | }; |
3888 | |
3889 | interleaveComma(c: map.getResults(), eachFn: [&](AffineExpr expr) { |
3890 | printAffineExpr(expr, printValueName); |
3891 | }); |
3892 | } |
3893 | |
3894 | void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, |
3895 | ValueRange dimOperands, |
3896 | ValueRange symOperands) { |
3897 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
3898 | if (!isSymbol) |
3899 | return printValueID(value: dimOperands[pos]); |
3900 | os << "symbol(" ; |
3901 | printValueID(value: symOperands[pos]); |
3902 | os << ')'; |
3903 | }; |
3904 | printAffineExpr(expr, printValueName); |
3905 | } |
3906 | |
3907 | //===----------------------------------------------------------------------===// |
3908 | // print and dump methods |
3909 | //===----------------------------------------------------------------------===// |
3910 | |
3911 | void Attribute::print(raw_ostream &os, bool elideType) const { |
3912 | if (!*this) { |
3913 | os << "<<NULL ATTRIBUTE>>" ; |
3914 | return; |
3915 | } |
3916 | |
3917 | AsmState state(getContext()); |
3918 | print(os, state, elideType); |
3919 | } |
3920 | void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const { |
3921 | using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision; |
3922 | AsmPrinter::Impl(os, state.getImpl()) |
3923 | .printAttribute(attr: *this, typeElision: elideType ? AttrTypeElision::Must |
3924 | : AttrTypeElision::Never); |
3925 | } |
3926 | |
3927 | void Attribute::dump() const { |
3928 | print(os&: llvm::errs()); |
3929 | llvm::errs() << "\n" ; |
3930 | } |
3931 | |
3932 | void Attribute::printStripped(raw_ostream &os, AsmState &state) const { |
3933 | if (!*this) { |
3934 | os << "<<NULL ATTRIBUTE>>" ; |
3935 | return; |
3936 | } |
3937 | |
3938 | AsmPrinter::Impl subPrinter(os, state.getImpl()); |
3939 | if (succeeded(Result: subPrinter.printAlias(attr: *this))) |
3940 | return; |
3941 | |
3942 | auto &dialect = this->getDialect(); |
3943 | uint64_t posPrior = os.tell(); |
3944 | DialectAsmPrinter printer(subPrinter); |
3945 | dialect.printAttribute(*this, printer); |
3946 | if (posPrior != os.tell()) |
3947 | return; |
3948 | |
3949 | // Fallback to printing with prefix if the above failed to write anything |
3950 | // to the output stream. |
3951 | print(os, state); |
3952 | } |
3953 | void Attribute::printStripped(raw_ostream &os) const { |
3954 | if (!*this) { |
3955 | os << "<<NULL ATTRIBUTE>>" ; |
3956 | return; |
3957 | } |
3958 | |
3959 | AsmState state(getContext()); |
3960 | printStripped(os, state); |
3961 | } |
3962 | |
3963 | void Type::print(raw_ostream &os) const { |
3964 | if (!*this) { |
3965 | os << "<<NULL TYPE>>" ; |
3966 | return; |
3967 | } |
3968 | |
3969 | AsmState state(getContext()); |
3970 | print(os, state); |
3971 | } |
3972 | void Type::print(raw_ostream &os, AsmState &state) const { |
3973 | AsmPrinter::Impl(os, state.getImpl()).printType(type: *this); |
3974 | } |
3975 | |
3976 | void Type::dump() const { |
3977 | print(os&: llvm::errs()); |
3978 | llvm::errs() << "\n" ; |
3979 | } |
3980 | |
3981 | void AffineMap::dump() const { |
3982 | print(os&: llvm::errs()); |
3983 | llvm::errs() << "\n" ; |
3984 | } |
3985 | |
3986 | void IntegerSet::dump() const { |
3987 | print(os&: llvm::errs()); |
3988 | llvm::errs() << "\n" ; |
3989 | } |
3990 | |
3991 | void AffineExpr::print(raw_ostream &os) const { |
3992 | if (!expr) { |
3993 | os << "<<NULL AFFINE EXPR>>" ; |
3994 | return; |
3995 | } |
3996 | AsmState state(getContext()); |
3997 | AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(expr: *this); |
3998 | } |
3999 | |
4000 | void AffineExpr::dump() const { |
4001 | print(os&: llvm::errs()); |
4002 | llvm::errs() << "\n" ; |
4003 | } |
4004 | |
4005 | void AffineMap::print(raw_ostream &os) const { |
4006 | if (!map) { |
4007 | os << "<<NULL AFFINE MAP>>" ; |
4008 | return; |
4009 | } |
4010 | AsmState state(getContext()); |
4011 | AsmPrinter::Impl(os, state.getImpl()).printAffineMap(map: *this); |
4012 | } |
4013 | |
4014 | void IntegerSet::print(raw_ostream &os) const { |
4015 | AsmState state(getContext()); |
4016 | AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(set: *this); |
4017 | } |
4018 | |
4019 | void Value::print(raw_ostream &os) const { print(os, flags: OpPrintingFlags()); } |
4020 | void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const { |
4021 | if (!impl) { |
4022 | os << "<<NULL VALUE>>" ; |
4023 | return; |
4024 | } |
4025 | |
4026 | if (auto *op = getDefiningOp()) |
4027 | return op->print(os, flags); |
4028 | // TODO: Improve BlockArgument print'ing. |
4029 | BlockArgument arg = llvm::cast<BlockArgument>(Val: *this); |
4030 | os << "<block argument> of type '" << arg.getType() |
4031 | << "' at index: " << arg.getArgNumber(); |
4032 | } |
4033 | void Value::print(raw_ostream &os, AsmState &state) const { |
4034 | if (!impl) { |
4035 | os << "<<NULL VALUE>>" ; |
4036 | return; |
4037 | } |
4038 | |
4039 | if (auto *op = getDefiningOp()) |
4040 | return op->print(os, state); |
4041 | |
4042 | // TODO: Improve BlockArgument print'ing. |
4043 | BlockArgument arg = llvm::cast<BlockArgument>(Val: *this); |
4044 | os << "<block argument> of type '" << arg.getType() |
4045 | << "' at index: " << arg.getArgNumber(); |
4046 | } |
4047 | |
4048 | void Value::dump() const { |
4049 | print(os&: llvm::errs()); |
4050 | llvm::errs() << "\n" ; |
4051 | } |
4052 | |
4053 | void Value::printAsOperand(raw_ostream &os, AsmState &state) const { |
4054 | // TODO: This doesn't necessarily capture all potential cases. |
4055 | // Currently, region arguments can be shadowed when printing the main |
4056 | // operation. If the IR hasn't been printed, this will produce the old SSA |
4057 | // name and not the shadowed name. |
4058 | state.getImpl().getSSANameState().printValueID(value: *this, /*printResultNo=*/true, |
4059 | stream&: os); |
4060 | } |
4061 | |
4062 | static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
4063 | do { |
4064 | // If we are printing local scope, stop at the first operation that is |
4065 | // isolated from above. |
4066 | if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
4067 | break; |
4068 | |
4069 | // Otherwise, traverse up to the next parent. |
4070 | Operation *parentOp = op->getParentOp(); |
4071 | if (!parentOp) |
4072 | break; |
4073 | op = parentOp; |
4074 | } while (true); |
4075 | return op; |
4076 | } |
4077 | |
4078 | void Value::printAsOperand(raw_ostream &os, |
4079 | const OpPrintingFlags &flags) const { |
4080 | Operation *op; |
4081 | if (auto result = llvm::dyn_cast<OpResult>(Val: *this)) { |
4082 | op = result.getOwner(); |
4083 | } else { |
4084 | op = llvm::cast<BlockArgument>(Val: *this).getOwner()->getParentOp(); |
4085 | if (!op) { |
4086 | os << "<<UNKNOWN SSA VALUE>>" ; |
4087 | return; |
4088 | } |
4089 | } |
4090 | op = findParent(op, shouldUseLocalScope: flags.shouldUseLocalScope()); |
4091 | AsmState state(op, flags); |
4092 | printAsOperand(os, state); |
4093 | } |
4094 | |
4095 | void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { |
4096 | // Find the operation to number from based upon the provided flags. |
4097 | Operation *op = findParent(op: this, shouldUseLocalScope: printerFlags.shouldUseLocalScope()); |
4098 | AsmState state(op, printerFlags); |
4099 | print(os, state); |
4100 | } |
4101 | void Operation::print(raw_ostream &os, AsmState &state) { |
4102 | OperationPrinter printer(os, state.getImpl()); |
4103 | if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { |
4104 | state.getImpl().initializeAliases(op: this); |
4105 | printer.printTopLevelOperation(op: this); |
4106 | } else { |
4107 | printer.printFullOpWithIndentAndLoc(op: this); |
4108 | } |
4109 | } |
4110 | |
4111 | void Operation::dump() { |
4112 | print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope()); |
4113 | llvm::errs() << "\n" ; |
4114 | } |
4115 | |
4116 | void Operation::dumpPretty() { |
4117 | print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope().assumeVerified()); |
4118 | llvm::errs() << "\n" ; |
4119 | } |
4120 | |
4121 | void Block::print(raw_ostream &os) { |
4122 | Operation *parentOp = getParentOp(); |
4123 | if (!parentOp) { |
4124 | os << "<<UNLINKED BLOCK>>\n" ; |
4125 | return; |
4126 | } |
4127 | // Get the top-level op. |
4128 | while (auto *nextOp = parentOp->getParentOp()) |
4129 | parentOp = nextOp; |
4130 | |
4131 | AsmState state(parentOp); |
4132 | print(os, state); |
4133 | } |
4134 | void Block::print(raw_ostream &os, AsmState &state) { |
4135 | OperationPrinter(os, state.getImpl()).print(block: this); |
4136 | } |
4137 | |
4138 | void Block::dump() { print(os&: llvm::errs()); } |
4139 | |
4140 | /// Print out the name of the block without printing its body. |
4141 | void Block::printAsOperand(raw_ostream &os, bool printType) { |
4142 | Operation *parentOp = getParentOp(); |
4143 | if (!parentOp) { |
4144 | os << "<<UNLINKED BLOCK>>\n" ; |
4145 | return; |
4146 | } |
4147 | AsmState state(parentOp); |
4148 | printAsOperand(os, state); |
4149 | } |
4150 | void Block::printAsOperand(raw_ostream &os, AsmState &state) { |
4151 | OperationPrinter printer(os, state.getImpl()); |
4152 | printer.printBlockName(block: this); |
4153 | } |
4154 | |
4155 | raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) { |
4156 | block.print(os); |
4157 | return os; |
4158 | } |
4159 | |
4160 | //===--------------------------------------------------------------------===// |
4161 | // Custom printers |
4162 | //===--------------------------------------------------------------------===// |
4163 | namespace mlir { |
4164 | |
4165 | void printDimensionList(OpAsmPrinter &printer, Operation *op, |
4166 | ArrayRef<int64_t> dimensions) { |
4167 | if (dimensions.empty()) |
4168 | printer << "[" ; |
4169 | printer.printDimensionList(shape: dimensions); |
4170 | if (dimensions.empty()) |
4171 | printer << "]" ; |
4172 | } |
4173 | |
4174 | ParseResult parseDimensionList(OpAsmParser &parser, |
4175 | DenseI64ArrayAttr &dimensions) { |
4176 | // Empty list case denoted by "[]". |
4177 | if (succeeded(Result: parser.parseOptionalLSquare())) { |
4178 | if (failed(Result: parser.parseRSquare())) { |
4179 | return parser.emitError(loc: parser.getCurrentLocation()) |
4180 | << "Failed parsing dimension list." ; |
4181 | } |
4182 | dimensions = |
4183 | DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>()); |
4184 | return success(); |
4185 | } |
4186 | |
4187 | // Non-empty list case. |
4188 | SmallVector<int64_t> shapeArr; |
4189 | if (failed(Result: parser.parseDimensionList(dimensions&: shapeArr, allowDynamic: true, withTrailingX: false))) { |
4190 | return parser.emitError(loc: parser.getCurrentLocation()) |
4191 | << "Failed parsing dimension list." ; |
4192 | } |
4193 | if (shapeArr.empty()) { |
4194 | return parser.emitError(loc: parser.getCurrentLocation()) |
4195 | << "Failed parsing dimension list. Did you mean an empty list? It " |
4196 | "must be denoted by \"[]\"." ; |
4197 | } |
4198 | dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr); |
4199 | return success(); |
4200 | } |
4201 | |
4202 | } // namespace mlir |
4203 | |