1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file contains classes used by the implementation details of Op types.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | #include <optional> |
22 | |
23 | namespace mlir { |
24 | class AsmParsedResourceEntry; |
25 | class AsmResourceBuilder; |
26 | class Builder; |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // AsmDialectResourceHandle |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | /// This class represents an opaque handle to a dialect resource entry. |
33 | class AsmDialectResourceHandle { |
34 | public: |
35 | AsmDialectResourceHandle() = default; |
36 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) |
37 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} |
38 | bool operator==(const AsmDialectResourceHandle &other) const { |
39 | return resource == other.resource; |
40 | } |
41 | |
42 | /// Return an opaque pointer to the referenced resource. |
43 | void *getResource() const { return resource; } |
44 | |
45 | /// Return the type ID of the resource. |
46 | TypeID getTypeID() const { return opaqueID; } |
47 | |
48 | /// Return the dialect that owns the resource. |
49 | Dialect *getDialect() const { return dialect; } |
50 | |
51 | private: |
52 | /// The opaque handle to the dialect resource. |
53 | void *resource = nullptr; |
54 | /// The type of the resource referenced. |
55 | TypeID opaqueID; |
56 | /// The dialect owning the given resource. |
57 | Dialect *dialect; |
58 | }; |
59 | |
60 | /// This class represents a CRTP base class for dialect resource handles. It |
61 | /// abstracts away various utilities necessary for defined derived resource |
62 | /// handles. |
63 | template <typename DerivedT, typename ResourceT, typename DialectT> |
64 | class AsmDialectResourceHandleBase : public AsmDialectResourceHandle { |
65 | public: |
66 | using Dialect = DialectT; |
67 | |
68 | /// Construct a handle from a pointer to the resource. The given pointer |
69 | /// should be guaranteed to live beyond the life of this handle. |
70 | AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect) |
71 | : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {} |
72 | AsmDialectResourceHandleBase(AsmDialectResourceHandle handle) |
73 | : AsmDialectResourceHandle(handle) { |
74 | assert(handle.getTypeID() == TypeID::get<DerivedT>()); |
75 | } |
76 | |
77 | /// Return the resource referenced by this handle. |
78 | ResourceT *getResource() { |
79 | return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource()); |
80 | } |
81 | const ResourceT *getResource() const { |
82 | return const_cast<AsmDialectResourceHandleBase *>(this)->getResource(); |
83 | } |
84 | |
85 | /// Return the dialect that owns the resource. |
86 | DialectT *getDialect() const { |
87 | return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect()); |
88 | } |
89 | |
90 | /// Support llvm style casting. |
91 | static bool classof(const AsmDialectResourceHandle *handle) { |
92 | return handle->getTypeID() == TypeID::get<DerivedT>(); |
93 | } |
94 | }; |
95 | |
96 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { |
97 | return llvm::hash_value(ptr: param.getResource()); |
98 | } |
99 | |
100 | //===----------------------------------------------------------------------===// |
101 | // AsmPrinter |
102 | //===----------------------------------------------------------------------===// |
103 | |
104 | /// This base class exposes generic asm printer hooks, usable across the various |
105 | /// derived printers. |
106 | class AsmPrinter { |
107 | public: |
108 | /// This class contains the internal default implementation of the base |
109 | /// printer methods. |
110 | class Impl; |
111 | |
112 | /// Initialize the printer with the given internal implementation. |
113 | AsmPrinter(Impl &impl) : impl(&impl) {} |
114 | virtual ~AsmPrinter(); |
115 | |
116 | /// Return the raw output stream used by this printer. |
117 | virtual raw_ostream &getStream() const; |
118 | |
119 | /// Print the given floating point value in a stabilized form that can be |
120 | /// roundtripped through the IR. This is the companion to the 'parseFloat' |
121 | /// hook on the AsmParser. |
122 | virtual void printFloat(const APFloat &value); |
123 | |
124 | virtual void printType(Type type); |
125 | virtual void printAttribute(Attribute attr); |
126 | |
127 | /// Trait to check if `AttrType` provides a `print` method. |
128 | template <typename AttrOrType> |
129 | using has_print_method = |
130 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); |
131 | template <typename AttrOrType> |
132 | using detect_has_print_method = |
133 | llvm::is_detected<has_print_method, AttrOrType>; |
134 | |
135 | /// Print the provided attribute in the context of an operation custom |
136 | /// printer/parser: this will invoke directly the print method on the |
137 | /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. |
138 | template <typename AttrOrType, |
139 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
140 | *sfinae = nullptr> |
141 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
142 | if (succeeded(printAlias(attrOrType))) |
143 | return; |
144 | |
145 | raw_ostream &os = getStream(); |
146 | uint64_t posPrior = os.tell(); |
147 | attrOrType.print(*this); |
148 | if (posPrior != os.tell()) |
149 | return; |
150 | |
151 | // Fallback to printing with prefix if the above failed to write anything |
152 | // to the output stream. |
153 | *this << attrOrType; |
154 | } |
155 | |
156 | /// Print the provided array of attributes or types in the context of an |
157 | /// operation custom printer/parser: this will invoke directly the print |
158 | /// method on the attribute class and skip the `#dialect.mnemonic` prefix in |
159 | /// most cases. |
160 | template <typename AttrOrType, |
161 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
162 | *sfinae = nullptr> |
163 | void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) { |
164 | llvm::interleaveComma( |
165 | attrOrTypes, getStream(), |
166 | [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); }); |
167 | } |
168 | |
169 | /// SFINAE for printing the provided attribute in the context of an operation |
170 | /// custom printer in the case where the attribute does not define a print |
171 | /// method. |
172 | template <typename AttrOrType, |
173 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> |
174 | *sfinae = nullptr> |
175 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
176 | *this << attrOrType; |
177 | } |
178 | |
179 | /// Print the given attribute without its type. The corresponding parser must |
180 | /// provide a valid type for the attribute. |
181 | virtual void printAttributeWithoutType(Attribute attr); |
182 | |
183 | /// Print the alias for the given attribute, return failure if no alias could |
184 | /// be printed. |
185 | virtual LogicalResult printAlias(Attribute attr); |
186 | |
187 | /// Print the alias for the given type, return failure if no alias could |
188 | /// be printed. |
189 | virtual LogicalResult printAlias(Type type); |
190 | |
191 | /// Print the given string as a keyword, or a quoted and escaped string if it |
192 | /// has any special or non-printable characters in it. |
193 | virtual void printKeywordOrString(StringRef keyword); |
194 | |
195 | /// Print the given string as a quoted string, escaping any special or |
196 | /// non-printable characters in it. |
197 | virtual void printString(StringRef string); |
198 | |
199 | /// Print the given string as a symbol reference, i.e. a form representable by |
200 | /// a SymbolRefAttr. A symbol reference is represented as a string prefixed |
201 | /// with '@'. The reference is surrounded with ""'s and escaped if it has any |
202 | /// special or non-printable characters in it. |
203 | virtual void printSymbolName(StringRef symbolRef); |
204 | |
205 | /// Print a handle to the given dialect resource. |
206 | virtual void printResourceHandle(const AsmDialectResourceHandle &resource); |
207 | |
208 | /// Print an optional arrow followed by a type list. |
209 | template <typename TypeRange> |
210 | void printOptionalArrowTypeList(TypeRange &&types) { |
211 | if (types.begin() != types.end()) |
212 | printArrowTypeList(types); |
213 | } |
214 | template <typename TypeRange> |
215 | void printArrowTypeList(TypeRange &&types) { |
216 | auto &os = getStream() << " -> " ; |
217 | |
218 | bool wrapped = !llvm::hasSingleElement(types) || |
219 | llvm::isa<FunctionType>((*types.begin())); |
220 | if (wrapped) |
221 | os << '('; |
222 | llvm::interleaveComma(types, *this); |
223 | if (wrapped) |
224 | os << ')'; |
225 | } |
226 | |
227 | /// Print the two given type ranges in a functional form. |
228 | template <typename InputRangeT, typename ResultRangeT> |
229 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { |
230 | auto &os = getStream(); |
231 | os << '('; |
232 | llvm::interleaveComma(inputs, *this); |
233 | os << ')'; |
234 | printArrowTypeList(results); |
235 | } |
236 | |
237 | void printDimensionList(ArrayRef<int64_t> shape); |
238 | |
239 | /// Class used to automatically end a cyclic region on destruction. |
240 | class CyclicPrintReset { |
241 | public: |
242 | explicit CyclicPrintReset(AsmPrinter *printer) : printer(printer) {} |
243 | |
244 | ~CyclicPrintReset() { |
245 | if (printer) |
246 | printer->popCyclicPrinting(); |
247 | } |
248 | |
249 | CyclicPrintReset(const CyclicPrintReset &) = delete; |
250 | |
251 | CyclicPrintReset &operator=(const CyclicPrintReset &) = delete; |
252 | |
253 | CyclicPrintReset(CyclicPrintReset &&rhs) |
254 | : printer(std::exchange(obj&: rhs.printer, new_val: nullptr)) {} |
255 | |
256 | CyclicPrintReset &operator=(CyclicPrintReset &&rhs) { |
257 | printer = std::exchange(obj&: rhs.printer, new_val: nullptr); |
258 | return *this; |
259 | } |
260 | |
261 | private: |
262 | AsmPrinter *printer; |
263 | }; |
264 | |
265 | /// Attempts to start a cyclic printing region for `attrOrType`. |
266 | /// A cyclic printing region starts with this call and ends with the |
267 | /// destruction of the returned `CyclicPrintReset`. During this time, |
268 | /// calling `tryStartCyclicPrint` with the same attribute in any printer |
269 | /// will lead to returning failure. |
270 | /// |
271 | /// This makes it possible to break infinite recursions when trying to print |
272 | /// cyclic attributes or types by printing only immutable parameters if nested |
273 | /// within itself. |
274 | template <class AttrOrTypeT> |
275 | FailureOr<CyclicPrintReset> tryStartCyclicPrint(AttrOrTypeT attrOrType) { |
276 | static_assert( |
277 | std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>, |
278 | AttrOrTypeT> || |
279 | std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>, |
280 | "Only mutable attributes or types can be cyclic" ); |
281 | if (failed(pushCyclicPrinting(opaquePointer: attrOrType.getAsOpaquePointer()))) |
282 | return failure(); |
283 | return CyclicPrintReset(this); |
284 | } |
285 | |
286 | protected: |
287 | /// Initialize the printer with no internal implementation. In this case, all |
288 | /// virtual methods of this class must be overriden. |
289 | AsmPrinter() = default; |
290 | |
291 | /// Pushes a new attribute or type in the form of a type erased pointer |
292 | /// into an internal set. |
293 | /// Returns success if the type or attribute was inserted in the set or |
294 | /// failure if it was already contained. |
295 | virtual LogicalResult pushCyclicPrinting(const void *opaquePointer); |
296 | |
297 | /// Removes the element that was last inserted with a successful call to |
298 | /// `pushCyclicPrinting`. There must be exactly one `popCyclicPrinting` call |
299 | /// in reverse order of all successful `pushCyclicPrinting`. |
300 | virtual void popCyclicPrinting(); |
301 | |
302 | private: |
303 | AsmPrinter(const AsmPrinter &) = delete; |
304 | void operator=(const AsmPrinter &) = delete; |
305 | |
306 | /// The internal implementation of the printer. |
307 | Impl *impl{nullptr}; |
308 | }; |
309 | |
310 | template <typename AsmPrinterT> |
311 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
312 | AsmPrinterT &> |
313 | operator<<(AsmPrinterT &p, Type type) { |
314 | p.printType(type); |
315 | return p; |
316 | } |
317 | |
318 | template <typename AsmPrinterT> |
319 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
320 | AsmPrinterT &> |
321 | operator<<(AsmPrinterT &p, Attribute attr) { |
322 | p.printAttribute(attr); |
323 | return p; |
324 | } |
325 | |
326 | template <typename AsmPrinterT> |
327 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
328 | AsmPrinterT &> |
329 | operator<<(AsmPrinterT &p, const APFloat &value) { |
330 | p.printFloat(value); |
331 | return p; |
332 | } |
333 | template <typename AsmPrinterT> |
334 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
335 | AsmPrinterT &> |
336 | operator<<(AsmPrinterT &p, float value) { |
337 | return p << APFloat(value); |
338 | } |
339 | template <typename AsmPrinterT> |
340 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
341 | AsmPrinterT &> |
342 | operator<<(AsmPrinterT &p, double value) { |
343 | return p << APFloat(value); |
344 | } |
345 | |
346 | // Support printing anything that isn't convertible to one of the other |
347 | // streamable types, even if it isn't exactly one of them. For example, we want |
348 | // to print FunctionType with the Type version above, not have it match this. |
349 | template <typename AsmPrinterT, typename T, |
350 | std::enable_if_t<!std::is_convertible<T &, Value &>::value && |
351 | !std::is_convertible<T &, Type &>::value && |
352 | !std::is_convertible<T &, Attribute &>::value && |
353 | !std::is_convertible<T &, ValueRange>::value && |
354 | !std::is_convertible<T &, APFloat &>::value && |
355 | !llvm::is_one_of<T, bool, float, double>::value, |
356 | T> * = nullptr> |
357 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
358 | AsmPrinterT &> |
359 | operator<<(AsmPrinterT &p, const T &other) { |
360 | p.getStream() << other; |
361 | return p; |
362 | } |
363 | |
364 | template <typename AsmPrinterT> |
365 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
366 | AsmPrinterT &> |
367 | operator<<(AsmPrinterT &p, bool value) { |
368 | return p << (value ? StringRef("true" ) : "false" ); |
369 | } |
370 | |
371 | template <typename AsmPrinterT, typename ValueRangeT> |
372 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
373 | AsmPrinterT &> |
374 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { |
375 | llvm::interleaveComma(types, p); |
376 | return p; |
377 | } |
378 | |
379 | template <typename AsmPrinterT> |
380 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
381 | AsmPrinterT &> |
382 | operator<<(AsmPrinterT &p, const TypeRange &types) { |
383 | llvm::interleaveComma(types, p); |
384 | return p; |
385 | } |
386 | |
387 | // Prevent matching the TypeRange version above for ValueRange |
388 | // printing through base AsmPrinter. This is needed so that the |
389 | // ValueRange printing behaviour does not change from printing |
390 | // the SSA values to printing the types for the operands when |
391 | // using AsmPrinter instead of OpAsmPrinter. |
392 | template <typename AsmPrinterT, typename T> |
393 | inline std::enable_if_t<std::is_same<AsmPrinter, AsmPrinterT>::value && |
394 | std::is_convertible<T &, ValueRange>::value, |
395 | AsmPrinterT &> |
396 | operator<<(AsmPrinterT &p, const T &other) = delete; |
397 | |
398 | template <typename AsmPrinterT, typename ElementT> |
399 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
400 | AsmPrinterT &> |
401 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { |
402 | llvm::interleaveComma(types, p); |
403 | return p; |
404 | } |
405 | |
406 | //===----------------------------------------------------------------------===// |
407 | // OpAsmPrinter |
408 | //===----------------------------------------------------------------------===// |
409 | |
410 | /// This is a pure-virtual base class that exposes the asmprinter hooks |
411 | /// necessary to implement a custom print() method. |
412 | class OpAsmPrinter : public AsmPrinter { |
413 | public: |
414 | using AsmPrinter::AsmPrinter; |
415 | ~OpAsmPrinter() override; |
416 | |
417 | /// Print a loc(...) specifier if printing debug info is enabled. |
418 | virtual void printOptionalLocationSpecifier(Location loc) = 0; |
419 | |
420 | /// Print a newline and indent the printer to the start of the current |
421 | /// operation. |
422 | virtual void printNewline() = 0; |
423 | |
424 | /// Increase indentation. |
425 | virtual void increaseIndent() = 0; |
426 | |
427 | /// Decrease indentation. |
428 | virtual void decreaseIndent() = 0; |
429 | |
430 | /// Print a block argument in the usual format of: |
431 | /// %ssaName : type {attr1=42} loc("here") |
432 | /// where location printing is controlled by the standard internal option. |
433 | /// You may pass omitType=true to not print a type, and pass an empty |
434 | /// attribute list if you don't care for attributes. |
435 | virtual void printRegionArgument(BlockArgument arg, |
436 | ArrayRef<NamedAttribute> argAttrs = {}, |
437 | bool omitType = false) = 0; |
438 | |
439 | /// Print implementations for various things an operation contains. |
440 | virtual void printOperand(Value value) = 0; |
441 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
442 | |
443 | /// Print a comma separated list of operands. |
444 | template <typename ContainerType> |
445 | void printOperands(const ContainerType &container) { |
446 | printOperands(container.begin(), container.end()); |
447 | } |
448 | |
449 | /// Print a comma separated list of operands. |
450 | template <typename IteratorType> |
451 | void printOperands(IteratorType it, IteratorType end) { |
452 | llvm::interleaveComma(llvm::make_range(it, end), getStream(), |
453 | [this](Value value) { printOperand(value); }); |
454 | } |
455 | |
456 | /// Print the given successor. |
457 | virtual void printSuccessor(Block *successor) = 0; |
458 | |
459 | /// Print the successor and its operands. |
460 | virtual void printSuccessorAndUseList(Block *successor, |
461 | ValueRange succOperands) = 0; |
462 | |
463 | /// If the specified operation has attributes, print out an attribute |
464 | /// dictionary with their values. elidedAttrs allows the client to ignore |
465 | /// specific well known attributes, commonly used if the attribute value is |
466 | /// printed some other way (like as a fixed operand). |
467 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
468 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
469 | |
470 | /// If the specified operation has attributes, print out an attribute |
471 | /// dictionary prefixed with 'attributes'. |
472 | virtual void |
473 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, |
474 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
475 | |
476 | /// Prints the entire operation with the custom assembly form, if available, |
477 | /// or the generic assembly form, otherwise. |
478 | virtual void printCustomOrGenericOp(Operation *op) = 0; |
479 | |
480 | /// Print the entire operation with the default generic assembly form. |
481 | /// If `printOpName` is true, then the operation name is printed (the default) |
482 | /// otherwise it is omitted and the print will start with the operand list. |
483 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; |
484 | |
485 | /// Prints a region. |
486 | /// If 'printEntryBlockArgs' is false, the arguments of the |
487 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
488 | /// operation of the block is not printed. If printEmptyBlock is true, then |
489 | /// the block header is printed even if the block is empty. |
490 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, |
491 | bool printBlockTerminators = true, |
492 | bool printEmptyBlock = false) = 0; |
493 | |
494 | /// Renumber the arguments for the specified region to the same names as the |
495 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
496 | /// operations. If any entry in namesToUse is null, the corresponding |
497 | /// argument name is left alone. |
498 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
499 | |
500 | /// Prints an affine map of SSA ids, where SSA id names are used in place |
501 | /// of dims/symbols. |
502 | /// Operand values must come from single-result sources, and be valid |
503 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
504 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
505 | ValueRange operands) = 0; |
506 | |
507 | /// Prints an affine expression of SSA ids with SSA id names used instead of |
508 | /// dims and symbols. |
509 | /// Operand values must come from single-result sources, and be valid |
510 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
511 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
512 | ValueRange symOperands) = 0; |
513 | |
514 | /// Print the complete type of an operation in functional form. |
515 | void printFunctionalType(Operation *op); |
516 | using AsmPrinter::printFunctionalType; |
517 | }; |
518 | |
519 | // Make the implementations convenient to use. |
520 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { |
521 | p.printOperand(value); |
522 | return p; |
523 | } |
524 | |
525 | template <typename T, |
526 | std::enable_if_t<std::is_convertible<T &, ValueRange>::value && |
527 | !std::is_convertible<T &, Value &>::value, |
528 | T> * = nullptr> |
529 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { |
530 | p.printOperands(values); |
531 | return p; |
532 | } |
533 | |
534 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { |
535 | p.printSuccessor(successor: value); |
536 | return p; |
537 | } |
538 | |
539 | //===----------------------------------------------------------------------===// |
540 | // AsmParser |
541 | //===----------------------------------------------------------------------===// |
542 | |
543 | /// This base class exposes generic asm parser hooks, usable across the various |
544 | /// derived parsers. |
545 | class AsmParser { |
546 | public: |
547 | AsmParser() = default; |
548 | virtual ~AsmParser(); |
549 | |
550 | MLIRContext *getContext() const; |
551 | |
552 | /// Return the location of the original name token. |
553 | virtual SMLoc getNameLoc() const = 0; |
554 | |
555 | //===--------------------------------------------------------------------===// |
556 | // Utilities |
557 | //===--------------------------------------------------------------------===// |
558 | |
559 | /// Emit a diagnostic at the specified location and return failure. |
560 | virtual InFlightDiagnostic emitError(SMLoc loc, |
561 | const Twine &message = {}) = 0; |
562 | |
563 | /// Return a builder which provides useful access to MLIRContext, global |
564 | /// objects like types and attributes. |
565 | virtual Builder &getBuilder() const = 0; |
566 | |
567 | /// Get the location of the next token and store it into the argument. This |
568 | /// always succeeds. |
569 | virtual SMLoc getCurrentLocation() = 0; |
570 | ParseResult getCurrentLocation(SMLoc *loc) { |
571 | *loc = getCurrentLocation(); |
572 | return success(); |
573 | } |
574 | |
575 | /// Re-encode the given source location as an MLIR location and return it. |
576 | /// Note: This method should only be used when a `Location` is necessary, as |
577 | /// the encoding process is not efficient. |
578 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
579 | |
580 | //===--------------------------------------------------------------------===// |
581 | // Token Parsing |
582 | //===--------------------------------------------------------------------===// |
583 | |
584 | /// Parse a '->' token. |
585 | virtual ParseResult parseArrow() = 0; |
586 | |
587 | /// Parse a '->' token if present |
588 | virtual ParseResult parseOptionalArrow() = 0; |
589 | |
590 | /// Parse a `{` token. |
591 | virtual ParseResult parseLBrace() = 0; |
592 | |
593 | /// Parse a `{` token if present. |
594 | virtual ParseResult parseOptionalLBrace() = 0; |
595 | |
596 | /// Parse a `}` token. |
597 | virtual ParseResult parseRBrace() = 0; |
598 | |
599 | /// Parse a `}` token if present. |
600 | virtual ParseResult parseOptionalRBrace() = 0; |
601 | |
602 | /// Parse a `:` token. |
603 | virtual ParseResult parseColon() = 0; |
604 | |
605 | /// Parse a `:` token if present. |
606 | virtual ParseResult parseOptionalColon() = 0; |
607 | |
608 | /// Parse a `,` token. |
609 | virtual ParseResult parseComma() = 0; |
610 | |
611 | /// Parse a `,` token if present. |
612 | virtual ParseResult parseOptionalComma() = 0; |
613 | |
614 | /// Parse a `=` token. |
615 | virtual ParseResult parseEqual() = 0; |
616 | |
617 | /// Parse a `=` token if present. |
618 | virtual ParseResult parseOptionalEqual() = 0; |
619 | |
620 | /// Parse a '<' token. |
621 | virtual ParseResult parseLess() = 0; |
622 | |
623 | /// Parse a '<' token if present. |
624 | virtual ParseResult parseOptionalLess() = 0; |
625 | |
626 | /// Parse a '>' token. |
627 | virtual ParseResult parseGreater() = 0; |
628 | |
629 | /// Parse a '>' token if present. |
630 | virtual ParseResult parseOptionalGreater() = 0; |
631 | |
632 | /// Parse a '?' token. |
633 | virtual ParseResult parseQuestion() = 0; |
634 | |
635 | /// Parse a '?' token if present. |
636 | virtual ParseResult parseOptionalQuestion() = 0; |
637 | |
638 | /// Parse a '+' token. |
639 | virtual ParseResult parsePlus() = 0; |
640 | |
641 | /// Parse a '+' token if present. |
642 | virtual ParseResult parseOptionalPlus() = 0; |
643 | |
644 | /// Parse a '*' token. |
645 | virtual ParseResult parseStar() = 0; |
646 | |
647 | /// Parse a '*' token if present. |
648 | virtual ParseResult parseOptionalStar() = 0; |
649 | |
650 | /// Parse a '|' token. |
651 | virtual ParseResult parseVerticalBar() = 0; |
652 | |
653 | /// Parse a '|' token if present. |
654 | virtual ParseResult parseOptionalVerticalBar() = 0; |
655 | |
656 | /// Parse a quoted string token. |
657 | ParseResult parseString(std::string *string) { |
658 | auto loc = getCurrentLocation(); |
659 | if (parseOptionalString(string)) |
660 | return emitError(loc, message: "expected string" ); |
661 | return success(); |
662 | } |
663 | |
664 | /// Parse a quoted string token if present. |
665 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
666 | |
667 | /// Parses a Base64 encoded string of bytes. |
668 | virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0; |
669 | |
670 | /// Parse a `(` token. |
671 | virtual ParseResult parseLParen() = 0; |
672 | |
673 | /// Parse a `(` token if present. |
674 | virtual ParseResult parseOptionalLParen() = 0; |
675 | |
676 | /// Parse a `)` token. |
677 | virtual ParseResult parseRParen() = 0; |
678 | |
679 | /// Parse a `)` token if present. |
680 | virtual ParseResult parseOptionalRParen() = 0; |
681 | |
682 | /// Parse a `[` token. |
683 | virtual ParseResult parseLSquare() = 0; |
684 | |
685 | /// Parse a `[` token if present. |
686 | virtual ParseResult parseOptionalLSquare() = 0; |
687 | |
688 | /// Parse a `]` token. |
689 | virtual ParseResult parseRSquare() = 0; |
690 | |
691 | /// Parse a `]` token if present. |
692 | virtual ParseResult parseOptionalRSquare() = 0; |
693 | |
694 | /// Parse a `...` token. |
695 | virtual ParseResult parseEllipsis() = 0; |
696 | |
/// Parse a `...` token if present.
698 | virtual ParseResult parseOptionalEllipsis() = 0; |
699 | |
700 | /// Parse a floating point value from the stream. |
701 | virtual ParseResult parseFloat(double &result) = 0; |
702 | |
703 | /// Parse an integer value from the stream. |
704 | template <typename IntT> |
705 | ParseResult parseInteger(IntT &result) { |
706 | auto loc = getCurrentLocation(); |
707 | OptionalParseResult parseResult = parseOptionalInteger(result); |
708 | if (!parseResult.has_value()) |
709 | return emitError(loc, message: "expected integer value" ); |
710 | return *parseResult; |
711 | } |
712 | |
713 | /// Parse an optional integer value from the stream. |
714 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
715 | |
716 | template <typename IntT> |
717 | OptionalParseResult parseOptionalInteger(IntT &result) { |
718 | auto loc = getCurrentLocation(); |
719 | |
720 | // Parse the unsigned variant. |
721 | APInt uintResult; |
722 | OptionalParseResult parseResult = parseOptionalInteger(result&: uintResult); |
723 | if (!parseResult.has_value() || failed(result: *parseResult)) |
724 | return parseResult; |
725 | |
726 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
727 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
728 | // zero for non-negated integers. |
729 | result = |
730 | (IntT)uintResult.sextOrTrunc(width: sizeof(IntT) * CHAR_BIT).getLimitedValue(); |
731 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
732 | return emitError(loc, message: "integer value too large" ); |
733 | return success(); |
734 | } |
735 | |
736 | /// These are the supported delimiters around operand lists and region |
737 | /// argument lists, used by parseOperandList. |
738 | enum class Delimiter { |
739 | /// Zero or more operands with no delimiters. |
740 | None, |
741 | /// Parens surrounding zero or more operands. |
742 | Paren, |
743 | /// Square brackets surrounding zero or more operands. |
744 | Square, |
745 | /// <> brackets surrounding zero or more operands. |
746 | LessGreater, |
747 | /// {} brackets surrounding zero or more operands. |
748 | Braces, |
749 | /// Parens supporting zero or more operands, or nothing. |
750 | OptionalParen, |
751 | /// Square brackets supporting zero or more ops, or nothing. |
752 | OptionalSquare, |
753 | /// <> brackets supporting zero or more ops, or nothing. |
754 | OptionalLessGreater, |
755 | /// {} brackets surrounding zero or more operands, or nothing. |
756 | OptionalBraces, |
757 | }; |
758 | |
759 | /// Parse a list of comma-separated items with an optional delimiter. If a |
760 | /// delimiter is provided, then an empty list is allowed. If not, then at |
761 | /// least one element will be parsed. |
762 | /// |
/// `contextMessage` is an optional message appended to diagnostics such as
/// "expected '('" that are emitted while parsing the delimiters.
765 | virtual ParseResult |
766 | parseCommaSeparatedList(Delimiter delimiter, |
767 | function_ref<ParseResult()> parseElementFn, |
768 | StringRef contextMessage = StringRef()) = 0; |
769 | |
770 | /// Parse a comma separated list of elements that must have at least one entry |
771 | /// in it. |
772 | ParseResult |
773 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { |
774 | return parseCommaSeparatedList(delimiter: Delimiter::None, parseElementFn); |
775 | } |
776 | |
777 | //===--------------------------------------------------------------------===// |
778 | // Keyword Parsing |
779 | //===--------------------------------------------------------------------===// |
780 | |
781 | /// This class represents a StringSwitch like class that is useful for parsing |
782 | /// expected keywords. On construction, unless a non-empty keyword is |
783 | /// provided, it invokes `parseKeyword` and processes each of the provided |
784 | /// cases statements until a match is hit. The provided `ResultT` must be |
785 | /// assignable from `failure()`. |
786 | template <typename ResultT = ParseResult> |
787 | class KeywordSwitch { |
788 | public: |
789 | KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr) |
790 | : parser(parser), loc(parser.getCurrentLocation()) { |
791 | if (keyword && !keyword->empty()) |
792 | this->keyword = *keyword; |
793 | else if (failed(parser.parseKeywordOrCompletion(keyword: &this->keyword))) |
794 | result = failure(); |
795 | } |
796 | /// Case that uses the provided value when true. |
797 | KeywordSwitch &Case(StringLiteral str, ResultT value) { |
798 | return Case(str, [&](StringRef, SMLoc) { return std::move(value); }); |
799 | } |
800 | KeywordSwitch &Default(ResultT value) { |
801 | return Default([&](StringRef, SMLoc) { return std::move(value); }); |
802 | } |
803 | /// Case that invokes the provided functor when true. The parameters passed |
804 | /// to the functor are the keyword, and the location of the keyword (in case |
805 | /// any errors need to be emitted). |
806 | template <typename FnT> |
807 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
808 | Case(StringLiteral str, FnT &&fn) { |
809 | if (result) |
810 | return *this; |
811 | |
812 | // If the word was empty, record this as a completion. |
813 | if (keyword.empty()) |
814 | parser.codeCompleteExpectedTokens(tokens: str); |
815 | else if (keyword == str) |
816 | result.emplace(std::move(fn(keyword, loc))); |
817 | return *this; |
818 | } |
819 | template <typename FnT> |
820 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
821 | Default(FnT &&fn) { |
822 | if (!result) |
823 | result.emplace(fn(keyword, loc)); |
824 | return *this; |
825 | } |
826 | |
827 | /// Returns true if this switch has a value yet. |
828 | bool hasValue() const { return result.has_value(); } |
829 | |
830 | /// Return the result of the switch. |
831 | [[nodiscard]] operator ResultT() { |
832 | if (!result) |
833 | return parser.emitError(loc, message: "unexpected keyword: " ) << keyword; |
834 | return std::move(*result); |
835 | } |
836 | |
837 | private: |
838 | /// The parser used to construct this switch. |
839 | AsmParser &parser; |
840 | |
841 | /// The location of the keyword, used to emit errors as necessary. |
842 | SMLoc loc; |
843 | |
844 | /// The parsed keyword itself. |
845 | StringRef keyword; |
846 | |
847 | /// The result of the switch statement or std::nullopt if currently unknown. |
848 | std::optional<ResultT> result; |
849 | }; |
850 | |
851 | /// Parse a given keyword. |
852 | ParseResult parseKeyword(StringRef keyword) { |
853 | return parseKeyword(keyword, msg: "" ); |
854 | } |
855 | virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; |
856 | |
857 | /// Parse a keyword into 'keyword'. |
858 | ParseResult parseKeyword(StringRef *keyword) { |
859 | auto loc = getCurrentLocation(); |
860 | if (parseOptionalKeyword(keyword)) |
861 | return emitError(loc, message: "expected valid keyword" ); |
862 | return success(); |
863 | } |
864 | |
865 | /// Parse the given keyword if present. |
866 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
867 | |
868 | /// Parse a keyword, if present, into 'keyword'. |
869 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
870 | |
871 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
872 | /// into 'keyword' |
873 | virtual ParseResult |
874 | parseOptionalKeyword(StringRef *keyword, |
875 | ArrayRef<StringRef> allowedValues) = 0; |
876 | |
877 | /// Parse a keyword or a quoted string. |
878 | ParseResult parseKeywordOrString(std::string *result) { |
879 | if (failed(result: parseOptionalKeywordOrString(result))) |
880 | return emitError(loc: getCurrentLocation()) |
881 | << "expected valid keyword or string" ; |
882 | return success(); |
883 | } |
884 | |
885 | /// Parse an optional keyword or string. |
886 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
887 | |
888 | //===--------------------------------------------------------------------===// |
889 | // Attribute/Type Parsing |
890 | //===--------------------------------------------------------------------===// |
891 | |
892 | /// Invoke the `getChecked` method of the given Attribute or Type class, using |
893 | /// the provided location to emit errors in the case of failure. Note that |
894 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
895 | /// context parameter. |
896 | template <typename T, typename... ParamsT> |
897 | auto getChecked(SMLoc loc, ParamsT &&...params) { |
898 | return T::getChecked([&] { return emitError(loc); }, |
899 | std::forward<ParamsT>(params)...); |
900 | } |
901 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
902 | /// errors. |
903 | template <typename T, typename... ParamsT> |
904 | auto getChecked(ParamsT &&...params) { |
905 | return T::getChecked([&] { return emitError(loc: getNameLoc()); }, |
906 | std::forward<ParamsT>(params)...); |
907 | } |
908 | |
909 | //===--------------------------------------------------------------------===// |
910 | // Attribute Parsing |
911 | //===--------------------------------------------------------------------===// |
912 | |
913 | /// Parse an arbitrary attribute of a given type and return it in result. |
914 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
915 | |
916 | /// Parse a custom attribute with the provided callback, unless the next |
917 | /// token is `#`, in which case the generic parser is invoked. |
918 | virtual ParseResult parseCustomAttributeWithFallback( |
919 | Attribute &result, Type type, |
920 | function_ref<ParseResult(Attribute &result, Type type)> |
921 | parseAttribute) = 0; |
922 | |
923 | /// Parse an attribute of a specific kind and type. |
924 | template <typename AttrType> |
925 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
926 | SMLoc loc = getCurrentLocation(); |
927 | |
928 | // Parse any kind of attribute. |
929 | Attribute attr; |
930 | if (parseAttribute(result&: attr, type)) |
931 | return failure(); |
932 | |
933 | // Check for the right kind of attribute. |
934 | if (!(result = llvm::dyn_cast<AttrType>(attr))) |
935 | return emitError(loc, message: "invalid kind of attribute specified" ); |
936 | |
937 | return success(); |
938 | } |
939 | |
940 | /// Parse an arbitrary attribute and return it in result. This also adds the |
941 | /// attribute to the specified attribute list with the specified name. |
942 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
943 | NamedAttrList &attrs) { |
944 | return parseAttribute(result, type: Type(), attrName, attrs); |
945 | } |
946 | |
947 | /// Parse an attribute of a specific kind and type. |
948 | template <typename AttrType> |
949 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
950 | NamedAttrList &attrs) { |
951 | return parseAttribute(result, Type(), attrName, attrs); |
952 | } |
953 | |
954 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
955 | /// This also adds the attribute to the specified attribute list with the |
956 | /// specified name. |
957 | template <typename AttrType> |
958 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
959 | NamedAttrList &attrs) { |
960 | SMLoc loc = getCurrentLocation(); |
961 | |
962 | // Parse any kind of attribute. |
963 | Attribute attr; |
964 | if (parseAttribute(result&: attr, type)) |
965 | return failure(); |
966 | |
967 | // Check for the right kind of attribute. |
968 | result = llvm::dyn_cast<AttrType>(attr); |
969 | if (!result) |
970 | return emitError(loc, message: "invalid kind of attribute specified" ); |
971 | |
972 | attrs.append(attrName, result); |
973 | return success(); |
974 | } |
975 | |
976 | /// Trait to check if `AttrType` provides a `parse` method. |
977 | template <typename AttrType> |
978 | using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(), |
979 | std::declval<Type>())); |
980 | template <typename AttrType> |
981 | using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>; |
982 | |
983 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
984 | /// which case the generic parser is invoked. The parsed attribute is |
985 | /// populated in `result` and also added to the specified attribute list with |
986 | /// the specified name. |
987 | template <typename AttrType> |
988 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
989 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
990 | StringRef attrName, NamedAttrList &attrs) { |
991 | SMLoc loc = getCurrentLocation(); |
992 | |
993 | // Parse any kind of attribute. |
994 | Attribute attr; |
995 | if (parseCustomAttributeWithFallback( |
996 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
997 | result = AttrType::parse(*this, type); |
998 | if (!result) |
999 | return failure(); |
1000 | return success(); |
1001 | })) |
1002 | return failure(); |
1003 | |
1004 | // Check for the right kind of attribute. |
1005 | result = llvm::dyn_cast<AttrType>(attr); |
1006 | if (!result) |
1007 | return emitError(loc, message: "invalid kind of attribute specified" ); |
1008 | |
1009 | attrs.append(attrName, result); |
1010 | return success(); |
1011 | } |
1012 | |
1013 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
1014 | template <typename AttrType> |
1015 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
1016 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
1017 | StringRef attrName, NamedAttrList &attrs) { |
1018 | return parseAttribute(result, type, attrName, attrs); |
1019 | } |
1020 | |
1021 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
1022 | /// which case the generic parser is invoked. The parsed attribute is |
1023 | /// populated in `result`. |
1024 | template <typename AttrType> |
1025 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
1026 | parseCustomAttributeWithFallback(AttrType &result, Type type = {}) { |
1027 | SMLoc loc = getCurrentLocation(); |
1028 | |
1029 | // Parse any kind of attribute. |
1030 | Attribute attr; |
1031 | if (parseCustomAttributeWithFallback( |
1032 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
1033 | result = AttrType::parse(*this, type); |
1034 | return success(isSuccess: !!result); |
1035 | })) |
1036 | return failure(); |
1037 | |
1038 | // Check for the right kind of attribute. |
1039 | result = llvm::dyn_cast<AttrType>(attr); |
1040 | if (!result) |
1041 | return emitError(loc, message: "invalid kind of attribute specified" ); |
1042 | return success(); |
1043 | } |
1044 | |
1045 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
1046 | template <typename AttrType> |
1047 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
1048 | parseCustomAttributeWithFallback(AttrType &result, Type type = {}) { |
1049 | return parseAttribute(result, type); |
1050 | } |
1051 | |
1052 | /// Parse an arbitrary optional attribute of a given type and return it in |
1053 | /// result. |
1054 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
1055 | Type type = {}) = 0; |
1056 | |
1057 | /// Parse an optional array attribute and return it in result. |
1058 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
1059 | Type type = {}) = 0; |
1060 | |
1061 | /// Parse an optional string attribute and return it in result. |
1062 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
1063 | Type type = {}) = 0; |
1064 | |
1065 | /// Parse an optional symbol ref attribute and return it in result. |
1066 | virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result, |
1067 | Type type = {}) = 0; |
1068 | |
1069 | /// Parse an optional attribute of a specific type and add it to the list with |
1070 | /// the specified name. |
1071 | template <typename AttrType> |
1072 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
1073 | StringRef attrName, |
1074 | NamedAttrList &attrs) { |
1075 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
1076 | } |
1077 | |
1078 | /// Parse an optional attribute of a specific type and add it to the list with |
1079 | /// the specified name. |
1080 | template <typename AttrType> |
1081 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
1082 | StringRef attrName, |
1083 | NamedAttrList &attrs) { |
1084 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
1085 | if (parseResult.has_value() && succeeded(result: *parseResult)) |
1086 | attrs.append(attrName, result); |
1087 | return parseResult; |
1088 | } |
1089 | |
/// Parse a dictionary of named attributes into 'result' if it is present.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;

/// Parse a dictionary of named attributes into 'result' if the `attributes`
/// keyword is present.
virtual ParseResult
parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;

/// Parse an affine map instance into 'map'.
virtual ParseResult parseAffineMap(AffineMap &map) = 0;

/// Parse an affine expr instance into 'expr' using the already computed
/// mapping from symbols to affine expressions in 'symbolSet'. Each entry of
/// 'symbolSet' pairs a symbol name with the expression it stands for.
virtual ParseResult
parseAffineExpr(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
                AffineExpr &expr) = 0;

/// Parse an integer set instance into 'set'.
virtual ParseResult parseIntegerSet(IntegerSet &set) = 0;
1109 | |
1110 | //===--------------------------------------------------------------------===// |
1111 | // Identifier Parsing |
1112 | //===--------------------------------------------------------------------===// |
1113 | |
1114 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1115 | /// attribute. |
1116 | ParseResult parseSymbolName(StringAttr &result) { |
1117 | if (failed(result: parseOptionalSymbolName(result))) |
1118 | return emitError(loc: getCurrentLocation()) |
1119 | << "expected valid '@'-identifier for symbol name" ; |
1120 | return success(); |
1121 | } |
1122 | |
1123 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1124 | /// attribute named 'attrName'. |
1125 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
1126 | NamedAttrList &attrs) { |
1127 | if (parseSymbolName(result)) |
1128 | return failure(); |
1129 | attrs.append(attrName, result); |
1130 | return success(); |
1131 | } |
1132 | |
1133 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1134 | /// string attribute. |
1135 | virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0; |
1136 | |
1137 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1138 | /// string attribute named 'attrName'. |
1139 | ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, |
1140 | NamedAttrList &attrs) { |
1141 | if (succeeded(result: parseOptionalSymbolName(result))) { |
1142 | attrs.append(attrName, result); |
1143 | return success(); |
1144 | } |
1145 | return failure(); |
1146 | } |
1147 | |
1148 | //===--------------------------------------------------------------------===// |
1149 | // Resource Parsing |
1150 | //===--------------------------------------------------------------------===// |
1151 | |
1152 | /// Parse a handle to a resource within the assembly format. |
1153 | template <typename ResourceT> |
1154 | FailureOr<ResourceT> parseResourceHandle() { |
1155 | SMLoc handleLoc = getCurrentLocation(); |
1156 | |
1157 | // Try to load the dialect that owns the handle. |
1158 | auto *dialect = |
1159 | getContext()->getOrLoadDialect<typename ResourceT::Dialect>(); |
1160 | if (!dialect) { |
1161 | return emitError(loc: handleLoc) |
1162 | << "dialect '" << ResourceT::Dialect::getDialectNamespace() |
1163 | << "' is unknown" ; |
1164 | } |
1165 | |
1166 | FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect); |
1167 | if (failed(result: handle)) |
1168 | return failure(); |
1169 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
1170 | return std::move(*result); |
1171 | return emitError(loc: handleLoc) << "provided resource handle differs from the " |
1172 | "expected resource type" ; |
1173 | } |
1174 | |
1175 | //===--------------------------------------------------------------------===// |
1176 | // Type Parsing |
1177 | //===--------------------------------------------------------------------===// |
1178 | |
1179 | /// Parse a type. |
1180 | virtual ParseResult parseType(Type &result) = 0; |
1181 | |
1182 | /// Parse a custom type with the provided callback, unless the next |
1183 | /// token is `#`, in which case the generic parser is invoked. |
1184 | virtual ParseResult parseCustomTypeWithFallback( |
1185 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
1186 | |
1187 | /// Parse an optional type. |
1188 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
1189 | |
1190 | /// Parse a type of a specific type. |
1191 | template <typename TypeT> |
1192 | ParseResult parseType(TypeT &result) { |
1193 | SMLoc loc = getCurrentLocation(); |
1194 | |
1195 | // Parse any kind of type. |
1196 | Type type; |
1197 | if (parseType(result&: type)) |
1198 | return failure(); |
1199 | |
1200 | // Check for the right kind of type. |
1201 | result = llvm::dyn_cast<TypeT>(type); |
1202 | if (!result) |
1203 | return emitError(loc, message: "invalid kind of type specified" ); |
1204 | |
1205 | return success(); |
1206 | } |
1207 | |
1208 | /// Trait to check if `TypeT` provides a `parse` method. |
1209 | template <typename TypeT> |
1210 | using type_has_parse_method = |
1211 | decltype(TypeT::parse(std::declval<AsmParser &>())); |
1212 | template <typename TypeT> |
1213 | using detect_type_has_parse_method = |
1214 | llvm::is_detected<type_has_parse_method, TypeT>; |
1215 | |
1216 | /// Parse a custom Type of a given type unless the next token is `#`, in |
1217 | /// which case the generic parser is invoked. The parsed Type is |
1218 | /// populated in `result`. |
1219 | template <typename TypeT> |
1220 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
1221 | parseCustomTypeWithFallback(TypeT &result) { |
1222 | SMLoc loc = getCurrentLocation(); |
1223 | |
1224 | // Parse any kind of Type. |
1225 | Type type; |
1226 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
1227 | result = TypeT::parse(*this); |
1228 | return success(isSuccess: !!result); |
1229 | })) |
1230 | return failure(); |
1231 | |
1232 | // Check for the right kind of Type. |
1233 | result = llvm::dyn_cast<TypeT>(type); |
1234 | if (!result) |
1235 | return emitError(loc, message: "invalid kind of Type specified" ); |
1236 | return success(); |
1237 | } |
1238 | |
1239 | /// SFINAE parsing method for Type that don't implement a parse method. |
1240 | template <typename TypeT> |
1241 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
1242 | parseCustomTypeWithFallback(TypeT &result) { |
1243 | return parseType(result); |
1244 | } |
1245 | |
1246 | /// Parse a type list. |
1247 | ParseResult parseTypeList(SmallVectorImpl<Type> &result); |
1248 | |
1249 | /// Parse an arrow followed by a type list. |
1250 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1251 | |
1252 | /// Parse an optional arrow followed by a type list. |
1253 | virtual ParseResult |
1254 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1255 | |
1256 | /// Parse a colon followed by a type. |
1257 | virtual ParseResult parseColonType(Type &result) = 0; |
1258 | |
1259 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
1260 | template <typename TypeType> |
1261 | ParseResult parseColonType(TypeType &result) { |
1262 | SMLoc loc = getCurrentLocation(); |
1263 | |
1264 | // Parse any kind of type. |
1265 | Type type; |
1266 | if (parseColonType(result&: type)) |
1267 | return failure(); |
1268 | |
1269 | // Check for the right kind of type. |
1270 | result = llvm::dyn_cast<TypeType>(type); |
1271 | if (!result) |
1272 | return emitError(loc, message: "invalid kind of type specified" ); |
1273 | |
1274 | return success(); |
1275 | } |
1276 | |
1277 | /// Parse a colon followed by a type list, which must have at least one type. |
1278 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1279 | |
1280 | /// Parse an optional colon followed by a type list, which if present must |
1281 | /// have at least one type. |
1282 | virtual ParseResult |
1283 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1284 | |
1285 | /// Parse a keyword followed by a type. |
1286 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
1287 | return failure(isFailure: parseKeyword(keyword) || parseType(result)); |
1288 | } |
1289 | |
1290 | /// Add the specified type to the end of the specified type list and return |
1291 | /// success. This is a helper designed to allow parse methods to be simple |
1292 | /// and chain through || operators. |
1293 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
1294 | result.push_back(Elt: type); |
1295 | return success(); |
1296 | } |
1297 | |
1298 | /// Add the specified types to the end of the specified type list and return |
1299 | /// success. This is a helper designed to allow parse methods to be simple |
1300 | /// and chain through || operators. |
1301 | ParseResult addTypesToList(ArrayRef<Type> types, |
1302 | SmallVectorImpl<Type> &result) { |
1303 | result.append(in_start: types.begin(), in_end: types.end()); |
1304 | return success(); |
1305 | } |
1306 | |
/// Parse a dimension list of a tensor or memref type. This populates the
/// dimension list, using ShapedType::kDynamic for the `?` dimensions if
/// `allowDynamic` is set and errors out on `?` otherwise. Whether the
/// trailing `x` is parsed is controlled by `withTrailingX`.
///
///   dimension-list ::= eps | dimension (`x` dimension)*
///   dimension-list-with-trailing-x ::= (dimension `x`)*
///   dimension ::= `?` | decimal-literal
///
/// When `allowDynamic` is not set, this is used to parse:
///
///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
///   static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                       bool allowDynamic = true,
                                       bool withTrailingX = true) = 0;

/// Parse an 'x' token in a dimension list, handling the case where the x is
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
/// next token.
virtual ParseResult parseXInDimensionList() = 0;
1328 | |
1329 | /// Class used to automatically end a cyclic region on destruction. |
1330 | class CyclicParseReset { |
1331 | public: |
1332 | explicit CyclicParseReset(AsmParser *parser) : parser(parser) {} |
1333 | |
1334 | ~CyclicParseReset() { |
1335 | if (parser) |
1336 | parser->popCyclicParsing(); |
1337 | } |
1338 | |
1339 | CyclicParseReset(const CyclicParseReset &) = delete; |
1340 | CyclicParseReset &operator=(const CyclicParseReset &) = delete; |
1341 | CyclicParseReset(CyclicParseReset &&rhs) |
1342 | : parser(std::exchange(obj&: rhs.parser, new_val: nullptr)) {} |
1343 | CyclicParseReset &operator=(CyclicParseReset &&rhs) { |
1344 | parser = std::exchange(obj&: rhs.parser, new_val: nullptr); |
1345 | return *this; |
1346 | } |
1347 | |
1348 | private: |
1349 | AsmParser *parser; |
1350 | }; |
1351 | |
1352 | /// Attempts to start a cyclic parsing region for `attrOrType`. |
1353 | /// A cyclic parsing region starts with this call and ends with the |
1354 | /// destruction of the returned `CyclicParseReset`. During this time, |
1355 | /// calling `tryStartCyclicParse` with the same attribute in any parser |
1356 | /// will lead to returning failure. |
1357 | /// |
1358 | /// This makes it possible to parse cyclic attributes or types by parsing a |
1359 | /// short from if nested within itself. |
1360 | template <class AttrOrTypeT> |
1361 | FailureOr<CyclicParseReset> tryStartCyclicParse(AttrOrTypeT attrOrType) { |
1362 | static_assert( |
1363 | std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>, |
1364 | AttrOrTypeT> || |
1365 | std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>, |
1366 | "Only mutable attributes or types can be cyclic" ); |
1367 | if (failed(pushCyclicParsing(opaquePointer: attrOrType.getAsOpaquePointer()))) |
1368 | return failure(); |
1369 | |
1370 | return CyclicParseReset(this); |
1371 | } |
1372 | |
protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect. The templated `parseResourceHandle<ResourceT>()` loads the
/// dialect, calls this, and then checks the kind of the returned handle.
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;

/// Pushes a new attribute or type in the form of a type erased pointer
/// into an internal set.
/// Returns success if the type or attribute was inserted in the set or
/// failure if it was already contained.
virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0;

/// Removes the element that was last inserted with a successful call to
/// `pushCyclicParsing`. There must be exactly one `popCyclicParsing` call,
/// in reverse order, for every successful `pushCyclicParsing`.
/// `CyclicParseReset` issues this call from its destructor.
virtual void popCyclicParsing() = 0;

//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//

/// Parse a keyword, or an empty string if the current location signals a code
/// completion. Used by `KeywordSwitch` to read its keyword.
virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;

/// Signal the code completion of a set of expected tokens.
virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1400 | |
1401 | private: |
1402 | AsmParser(const AsmParser &) = delete; |
1403 | void operator=(const AsmParser &) = delete; |
1404 | }; |
1405 | |
1406 | //===----------------------------------------------------------------------===// |
1407 | // OpAsmParser |
1408 | //===----------------------------------------------------------------------===// |
1409 | |
1410 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1411 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1412 | /// that is designed to reduce/constrain syntax innovation in individual |
1413 | /// operations. |
1414 | /// |
1415 | /// For example, consider an op like this: |
1416 | /// |
1417 | /// %x = load %p[%1, %2] : memref<...> |
1418 | /// |
/// The "%x = load" tokens are already parsed and therefore invisible to the
/// custom op parser. This can be supported by calling `parseOperand` to
/// parse the %p, then calling `parseOperandList` with `Delimiter::Square` to
/// parse the indices, then calling `parseColonType` to parse the trailing
/// memref type.
1424 | /// |
1425 | class OpAsmParser : public AsmParser { |
1426 | public: |
/// Inherit AsmParser's constructors.
using AsmParser::AsmParser;
~OpAsmParser() override;

/// Parse a loc(...) specifier if present, filling in result if so.
/// Location for BlockArgument and Operation may be deferred with an alias, in
/// which case an OpaqueLoc is set and will be resolved when parsing
/// completes.
virtual ParseResult
parseOptionalLocationSpecifier(std::optional<Location> &result) = 0;

/// Return the name of the specified result in the specified syntax, as well
/// as the sub-element in the name. It returns an empty string and ~0U for
/// invalid result numbers. For example, in this operation:
///
///  %x, %y:2, %z = foo.op
///
///    getResultName(0) == {"x", 0 }
///    getResultName(1) == {"y", 0 }
///    getResultName(2) == {"y", 1 }
///    getResultName(3) == {"z", 0 }
///    getResultName(4) == {"", ~0U }
virtual std::pair<StringRef, unsigned>
getResultName(unsigned resultNo) const = 0;

/// Return the number of declared SSA results, counting each sub-element of a
/// grouped name. This returns 4 for the foo.op example in the comment for
/// `getResultName`.
virtual size_t getNumResults() const = 0;
1454 | |
// Methods returning ParseResult emit an error and return failure or success.
// This allows them to be chained together into a linear sequence of ||
// expressions in many cases.

/// Parse an operation in its generic form.
/// The parsed operation is parsed in the current context and inserted in the
/// provided block and insertion point. The results produced by this operation
/// aren't mapped to any named value in the parser. Returns nullptr on
/// failure.
virtual Operation *parseGenericOperation(Block *insertBlock,
                                         Block::iterator insertPt) = 0;

/// Parse the name of an operation, in the custom form. On success, return
/// an object of type 'OperationName'. Otherwise, failure is returned.
virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1470 | |
1471 | //===--------------------------------------------------------------------===// |
1472 | // Operand Parsing |
1473 | //===--------------------------------------------------------------------===// |
1474 | |
1475 | /// This is the representation of an operand reference. |
1476 | struct UnresolvedOperand { |
1477 | SMLoc location; // Location of the token. |
1478 | StringRef name; // Value name, e.g. %42 or %abc |
1479 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1480 | }; |
1481 | |
/// Parse different components, viz., use-info of operand(s), successor(s),
/// region(s), attribute(s), properties and function-type, of the generic form
/// of an operation instance and populate the input operation-state 'result'
/// with those components. If any of the components is explicitly provided,
/// then skip parsing that component.
///
/// Note: despite its name, `parsedOperandType` carries already-parsed
/// operands (a list of `UnresolvedOperand`), not operand types.
virtual ParseResult parseGenericOperationAfterOpName(
    OperationState &result,
    std::optional<ArrayRef<UnresolvedOperand>> parsedOperandType =
        std::nullopt,
    std::optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt,
    std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
        std::nullopt,
    std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
    std::optional<Attribute> parsedPropertiesAttribute = std::nullopt,
    std::optional<FunctionType> parsedFnType = std::nullopt) = 0;
1497 | |
/// Parse a single SSA value operand name along with a result number (the
/// `#12` in `%xyz#12`) if `allowResultNumber` is true.
virtual ParseResult parseOperand(UnresolvedOperand &result,
bool allowResultNumber = true) = 0;

/// Parse a single operand if present. Returns an empty OptionalParseResult
/// when no operand is present, otherwise the success or failure of parsing
/// it.
virtual OptionalParseResult
parseOptionalOperand(UnresolvedOperand &result,
bool allowResultNumber = true) = 0;
1507 | |
/// Parse zero or more comma-separated SSA operand references, optionally
/// wrapped in `delimiter`. Result-number suffixes (e.g. `%xyz#12`) are
/// accepted when `allowResultNumber` is true. `requiredOperandCount` is an
/// optional constraint on the number of operands; the default of -1 requests
/// no particular count.
virtual ParseResult
parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
Delimiter delimiter = Delimiter::None,
bool allowResultNumber = true,
int requiredOperandCount = -1) = 0;
1515 | |
1516 | /// Parse a specified number of comma separated operands. |
1517 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1518 | int requiredOperandCount, |
1519 | Delimiter delimiter = Delimiter::None) { |
1520 | return parseOperandList(result, delimiter, |
1521 | /*allowResultNumber=*/allowResultNumber: true, requiredOperandCount); |
1522 | } |
1523 | |
1524 | /// Parse zero or more trailing SSA comma-separated trailing operand |
1525 | /// references with a specified surrounding delimiter, and an optional |
1526 | /// required operand count. A leading comma is expected before the |
1527 | /// operands. |
1528 | ParseResult |
1529 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1530 | Delimiter delimiter = Delimiter::None) { |
1531 | if (failed(result: parseOptionalComma())) |
1532 | return success(); // The comma is optional. |
1533 | return parseOperandList(result, delimiter); |
1534 | } |
1535 | |
/// Resolve an operand to an SSA value of type `type`, emitting an error on
/// failure. On success the resolved value is appended to `result`.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
Type type,
SmallVectorImpl<Value> &result) = 0;
1540 | |
1541 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1542 | /// appending the results to the list on success. This method should be used |
1543 | /// when all operands have the same type. |
1544 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1545 | ParseResult resolveOperands(Operands &&operands, Type type, |
1546 | SmallVectorImpl<Value> &result) { |
1547 | for (const UnresolvedOperand &operand : operands) |
1548 | if (resolveOperand(operand, type, result)) |
1549 | return failure(); |
1550 | return success(); |
1551 | } |
1552 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1553 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1554 | SmallVectorImpl<Value> &result) { |
1555 | return resolveOperands(std::forward<Operands>(operands), type, result); |
1556 | } |
1557 | |
1558 | /// Resolve a list of operands and a list of operand types to SSA values, |
1559 | /// emitting an error and returning failure, or appending the results |
1560 | /// to the list on success. |
1561 | template <typename Operands = ArrayRef<UnresolvedOperand>, |
1562 | typename Types = ArrayRef<Type>> |
1563 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1564 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1565 | SmallVectorImpl<Value> &result) { |
1566 | size_t operandSize = llvm::range_size(operands); |
1567 | size_t typeSize = llvm::range_size(types); |
1568 | if (operandSize != typeSize) |
1569 | return emitError(loc) |
1570 | << operandSize << " operands present, but expected " << typeSize; |
1571 | |
1572 | for (auto [operand, type] : llvm::zip_equal(operands, types)) |
1573 | if (resolveOperand(operand, type, result)) |
1574 | return failure(); |
1575 | return success(); |
1576 | } |
1577 | |
/// Parses an affine map attribute where dims and symbols are SSA operands.
/// Operand values must come from single-result sources, and be valid
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
/// The dim/symbol operands are collected in `operands` and the map in `map`.
/// NOTE(review): `attrName`/`attrs` suggest the map is also recorded in `attrs`
/// under `attrName` -- confirm against the implementation.
virtual ParseResult
parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
Attribute &map, StringRef attrName,
NamedAttrList &attrs,
Delimiter delimiter = Delimiter::Square) = 0;

/// Parses an affine expression where dims and symbols are SSA operands.
/// Operand values must come from single-result sources, and be valid
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
/// Dimension operands are collected in `dimOperands`, symbol operands in
/// `symbOperands`, and the parsed expression is stored in `expr`.
virtual ParseResult
parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
SmallVectorImpl<UnresolvedOperand> &symbOperands,
AffineExpr &expr) = 0;
1594 | |
1595 | //===--------------------------------------------------------------------===// |
1596 | // Argument Parsing |
1597 | //===--------------------------------------------------------------------===// |
1598 | |
1599 | struct Argument { |
1600 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. |
1601 | Type type; // Type. |
1602 | DictionaryAttr attrs; // Attributes if present. |
1603 | std::optional<Location> sourceLoc; // Source location specifier if present. |
1604 | }; |
1605 | |
/// Parse a single argument with the following syntax:
///
///   `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
///
/// If `allowType` or `allowAttrs` is false, the corresponding part of the
/// grammar is not parsed.
virtual ParseResult parseArgument(Argument &result, bool allowType = false,
bool allowAttrs = false) = 0;

/// Parse a single argument if present. Returns an empty OptionalParseResult
/// when no argument is present, otherwise the success or failure of parsing
/// it.
virtual OptionalParseResult
parseOptionalArgument(Argument &result, bool allowType = false,
bool allowAttrs = false) = 0;

/// Parse zero or more arguments with a specified surrounding delimiter.
/// `allowType` and `allowAttrs` have the same meaning as for `parseArgument`.
virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
Delimiter delimiter = Delimiter::None,
bool allowType = false,
bool allowAttrs = false) = 0;
1625 | |
1626 | //===--------------------------------------------------------------------===// |
1627 | // Region Parsing |
1628 | //===--------------------------------------------------------------------===// |
1629 | |
1630 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1631 | /// moved to the op regions after the op is created. The first block of the |
1632 | /// region takes 'arguments'. |
1633 | /// |
1634 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to |
1635 | /// shadow the names of other existing SSA values defined above the region |
1636 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1637 | /// to operations that are 'IsolatedFromAbove'. |
1638 | virtual ParseResult parseRegion(Region ®ion, |
1639 | ArrayRef<Argument> arguments = {}, |
1640 | bool enableNameShadowing = false) = 0; |
1641 | |
1642 | /// Parses a region if present. |
1643 | virtual OptionalParseResult |
1644 | parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {}, |
1645 | bool enableNameShadowing = false) = 0; |
1646 | |
1647 | /// Parses a region if present. If the region is present, a new region is |
1648 | /// allocated and placed in `region`. If no region is present or on failure, |
1649 | /// `region` remains untouched. |
1650 | virtual OptionalParseResult |
1651 | parseOptionalRegion(std::unique_ptr<Region> ®ion, |
1652 | ArrayRef<Argument> arguments = {}, |
1653 | bool enableNameShadowing = false) = 0; |
1654 | |
1655 | //===--------------------------------------------------------------------===// |
1656 | // Successor Parsing |
1657 | //===--------------------------------------------------------------------===// |
1658 | |
1659 | /// Parse a single operation successor. |
1660 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1661 | |
1662 | /// Parse an optional operation successor. |
1663 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1664 | |
1665 | /// Parse a single operation successor and its operand list. |
1666 | virtual ParseResult |
1667 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1668 | |
1669 | //===--------------------------------------------------------------------===// |
1670 | // Type Parsing |
1671 | //===--------------------------------------------------------------------===// |
1672 | |
1673 | /// Parse a list of assignments of the form |
1674 | /// (%x1 = %y1, %x2 = %y2, ...) |
1675 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, |
1676 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1677 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1678 | if (!result.has_value()) |
1679 | return emitError(loc: getCurrentLocation(), message: "expected '('" ); |
1680 | return result.value(); |
1681 | } |
1682 | |
/// Parse a list of assignments of the form `(%x1 = %y1, %x2 = %y2, ...)` if
/// present. Returns an empty OptionalParseResult when the list is absent
/// (presumably when the next token is not `(`); otherwise returns whether
/// parsing succeeded. Left-hand sides are collected in `lhs`, right-hand
/// sides in `rhs`.
virtual OptionalParseResult
parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1686 | }; |
1687 | |
1688 | //===--------------------------------------------------------------------===// |
1689 | // Dialect OpAsm interface. |
1690 | //===--------------------------------------------------------------------===// |
1691 | |
1692 | /// A functor used to set the name of the start of a result group of an |
1693 | /// operation. See 'getAsmResultNames' below for more details. |
1694 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1695 | |
1696 | /// A functor used to set the name of blocks in regions directly nested under |
1697 | /// an operation. |
1698 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; |
1699 | |
/// Dialect interface that lets a dialect customize parts of the textual
/// assembly format: aliases for attributes and types, and the declaration,
/// parsing and printing of dialect resources. Every hook has a default
/// implementation, so dialects override only what they need.
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}

//===------------------------------------------------------------------===//
// Aliases
//===------------------------------------------------------------------===//

/// Holds the result of a `getAlias` hook call.
enum class AliasResult {
/// The object (type or attribute) is not supported by the hook
/// and an alias was not provided.
NoAlias,
/// An alias was provided, but it might be overridden by another hook.
OverridableAlias,
/// An alias was provided and it should be used
/// (no other hooks will be checked).
FinalAlias
};

/// Hooks for getting an alias identifier for a given attribute or type, which
/// need not be part of this dialect. The alias text is expected to be written
/// to `os`; it is used in place of the symbol when printing textual IR. These
/// aliases must not contain `.` or end with a numeric digit ([0-9]+). The
/// default implementations provide no alias.
virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
return AliasResult::NoAlias;
}
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
return AliasResult::NoAlias;
}

//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//

/// Declare a resource with the given key, returning a handle to use for any
/// references of this resource key within the IR during parsing. The result
/// of `getResourceKey` on the returned handle is permitted to be different
/// than `key`. The default implementation declares nothing and returns
/// failure.
virtual FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const {
return failure();
}

/// Return a key to use for the given resource. This key should uniquely
/// identify this resource within the dialect. The default implementation is
/// unreachable: a dialect that defines resources must override this hook.
virtual std::string
getResourceKey(const AsmDialectResourceHandle &handle) const {
llvm_unreachable(
"Dialect must implement `getResourceKey` when defining resources" );
}

/// Hook for parsing resource entries. Returns failure if the entry was not
/// valid, or could otherwise not be processed correctly. Any necessary errors
/// can be emitted via the provided entry. The default implementation is
/// defined out of line.
virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;

/// Hook for building resources to use during printing. The given `op` may be
/// inspected to help determine what information to include.
/// `referencedResources` contains all of the resources detected when printing
/// `op`. The default implementation builds nothing.
virtual void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &builder) const {}
};
1767 | |
1768 | //===--------------------------------------------------------------------===// |
1769 | // Custom printers and parsers. |
1770 | //===--------------------------------------------------------------------===// |
1771 | |
1772 | // Handles custom<DimensionList>(...) in TableGen. |
1773 | void printDimensionList(OpAsmPrinter &printer, Operation *op, |
1774 | ArrayRef<int64_t> dimensions); |
1775 | ParseResult parseDimensionList(OpAsmParser &parser, |
1776 | DenseI64ArrayAttr &dimensions); |
1777 | |
1778 | } // namespace mlir |
1779 | |
1780 | //===--------------------------------------------------------------------===// |
1781 | // Operation OpAsm interface. |
1782 | //===--------------------------------------------------------------------===// |
1783 | |
1784 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1785 | #include "mlir/IR/OpAsmInterface.h.inc" |
1786 | |
1787 | namespace llvm { |
1788 | template <> |
1789 | struct DenseMapInfo<mlir::AsmDialectResourceHandle> { |
1790 | static inline mlir::AsmDialectResourceHandle getEmptyKey() { |
1791 | return {DenseMapInfo<void *>::getEmptyKey(), |
1792 | DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr}; |
1793 | } |
1794 | static inline mlir::AsmDialectResourceHandle getTombstoneKey() { |
1795 | return {DenseMapInfo<void *>::getTombstoneKey(), |
1796 | DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr}; |
1797 | } |
1798 | static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { |
1799 | return DenseMapInfo<void *>::getHashValue(handle.getResource()); |
1800 | } |
1801 | static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, |
1802 | const mlir::AsmDialectResourceHandle &rhs) { |
1803 | return lhs.getResource() == rhs.getResource(); |
1804 | } |
1805 | }; |
1806 | } // namespace llvm |
1807 | |
1808 | #endif |
1809 | |