1 | //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the 'dialect' abstraction. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_DIALECT_H |
14 | #define MLIR_IR_DIALECT_H |
15 | |
16 | #include "mlir/IR/DialectRegistry.h" |
17 | #include "mlir/IR/OperationSupport.h" |
18 | #include "mlir/Support/TypeID.h" |
19 | |
20 | #include <map> |
21 | #include <tuple> |
22 | |
23 | namespace mlir { |
24 | class DialectAsmParser; |
25 | class DialectAsmPrinter; |
26 | class DialectInterface; |
27 | class OpBuilder; |
28 | class Type; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // Dialect |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | /// Dialects are groups of MLIR operations, types and attributes, as well as |
35 | /// behavior associated with the entire group. For example, hooks into other |
36 | /// systems for constant folding, interfaces, default named types for asm |
37 | /// printing, etc. |
38 | /// |
39 | /// Instances of the dialect object are loaded in a specific MLIRContext. |
40 | /// |
41 | class Dialect { |
42 | public: |
43 | /// Type for a callback provided by the dialect to parse a custom operation. |
44 | /// This is used for the dialect to provide an alternative way to parse custom |
45 | /// operations, including unregistered ones. |
46 | using ParseOpHook = |
47 | function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>; |
48 | |
49 | virtual ~Dialect(); |
50 | |
51 | /// Utility function that returns if the given string is a valid dialect |
52 | /// namespace |
53 | static bool isValidNamespace(StringRef str); |
54 | |
55 | MLIRContext *getContext() const { return context; } |
56 | |
57 | StringRef getNamespace() const { return name; } |
58 | |
59 | /// Returns the unique identifier that corresponds to this dialect. |
60 | TypeID getTypeID() const { return dialectID; } |
61 | |
62 | /// Returns true if this dialect allows for unregistered operations, i.e. |
63 | /// operations prefixed with the dialect namespace but not registered with |
64 | /// addOperation. |
65 | bool allowsUnknownOperations() const { return unknownOpsAllowed; } |
66 | |
67 | /// Return true if this dialect allows for unregistered types, i.e., types |
68 | /// prefixed with the dialect namespace but not registered with addType. |
69 | /// These are represented with OpaqueType. |
70 | bool allowsUnknownTypes() const { return unknownTypesAllowed; } |
71 | |
72 | /// Register dialect-wide canonicalization patterns. This method should only |
73 | /// be used to register canonicalization patterns that do not conceptually |
74 | /// belong to any single operation in the dialect. (In that case, use the op's |
75 | /// canonicalizer.) E.g., canonicalization patterns for op interfaces should |
76 | /// be registered here. |
77 | virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {} |
78 | |
79 | /// Registered hook to materialize a single constant operation from a given |
80 | /// attribute value with the desired resultant type. This method should use |
81 | /// the provided builder to create the operation without changing the |
82 | /// insertion position. The generated operation is expected to be constant |
83 | /// like, i.e. single result, zero operands, non side-effecting, etc. On |
84 | /// success, this hook should return the value generated to represent the |
85 | /// constant value. Otherwise, it should return null on failure. |
86 | virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, |
87 | Type type, Location loc) { |
88 | return nullptr; |
89 | } |
90 | |
91 | //===--------------------------------------------------------------------===// |
92 | // Parsing Hooks |
93 | //===--------------------------------------------------------------------===// |
94 | |
95 | /// Parse an attribute registered to this dialect. If 'type' is nonnull, it |
96 | /// refers to the expected type of the attribute. |
97 | virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; |
98 | |
99 | /// Print an attribute registered to this dialect. Note: The type of the |
100 | /// attribute need not be printed by this method as it is always printed by |
101 | /// the caller. |
102 | virtual void printAttribute(Attribute, DialectAsmPrinter &) const { |
103 | llvm_unreachable("dialect has no registered attribute printing hook" ); |
104 | } |
105 | |
106 | /// Parse a type registered to this dialect. |
107 | virtual Type parseType(DialectAsmParser &parser) const; |
108 | |
109 | /// Print a type registered to this dialect. |
110 | virtual void printType(Type, DialectAsmPrinter &) const { |
111 | llvm_unreachable("dialect has no registered type printing hook" ); |
112 | } |
113 | |
114 | /// Return the hook to parse an operation registered to this dialect, if any. |
115 | /// By default this will lookup for registered operations and return the |
116 | /// `parse()` method registered on the RegisteredOperationName. Dialects can |
117 | /// override this behavior and handle unregistered operations as well. |
118 | virtual std::optional<ParseOpHook> |
119 | getParseOperationHook(StringRef opName) const; |
120 | |
121 | /// Print an operation registered to this dialect. |
122 | /// This hook is invoked for registered operation which don't override the |
123 | /// `print()` method to define their own custom assembly. |
124 | virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> |
125 | getOperationPrinter(Operation *op) const; |
126 | |
127 | //===--------------------------------------------------------------------===// |
128 | // Verification Hooks |
129 | //===--------------------------------------------------------------------===// |
130 | |
131 | /// Verify an attribute from this dialect on the argument at 'argIndex' for |
132 | /// the region at 'regionIndex' on the given operation. Returns failure if |
133 | /// the verification failed, success otherwise. This hook may optionally be |
134 | /// invoked from any operation containing a region. |
135 | virtual LogicalResult verifyRegionArgAttribute(Operation *, |
136 | unsigned regionIndex, |
137 | unsigned argIndex, |
138 | NamedAttribute); |
139 | |
140 | /// Verify an attribute from this dialect on the result at 'resultIndex' for |
141 | /// the region at 'regionIndex' on the given operation. Returns failure if |
142 | /// the verification failed, success otherwise. This hook may optionally be |
143 | /// invoked from any operation containing a region. |
144 | virtual LogicalResult verifyRegionResultAttribute(Operation *, |
145 | unsigned regionIndex, |
146 | unsigned resultIndex, |
147 | NamedAttribute); |
148 | |
149 | /// Verify an attribute from this dialect on the given operation. Returns |
150 | /// failure if the verification failed, success otherwise. |
151 | virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { |
152 | return success(); |
153 | } |
154 | |
155 | //===--------------------------------------------------------------------===// |
156 | // Interfaces |
157 | //===--------------------------------------------------------------------===// |
158 | |
159 | /// Lookup an interface for the given ID if one is registered, otherwise |
160 | /// nullptr. |
161 | DialectInterface *getRegisteredInterface(TypeID interfaceID) { |
162 | #ifndef NDEBUG |
163 | handleUseOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(), interfaceID); |
164 | #endif |
165 | |
166 | auto it = registeredInterfaces.find(Val: interfaceID); |
167 | return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; |
168 | } |
169 | template <typename InterfaceT> |
170 | InterfaceT *getRegisteredInterface() { |
171 | #ifndef NDEBUG |
172 | handleUseOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(), |
173 | interfaceID: InterfaceT::getInterfaceID(), |
174 | interfaceName: llvm::getTypeName<InterfaceT>()); |
175 | #endif |
176 | |
177 | return static_cast<InterfaceT *>( |
178 | getRegisteredInterface(InterfaceT::getInterfaceID())); |
179 | } |
180 | |
181 | /// Lookup an op interface for the given ID if one is registered, otherwise |
182 | /// nullptr. |
183 | virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, |
184 | OperationName opName) { |
185 | return nullptr; |
186 | } |
187 | template <typename InterfaceT> |
188 | typename InterfaceT::Concept * |
189 | getRegisteredInterfaceForOp(OperationName opName) { |
190 | return static_cast<typename InterfaceT::Concept *>( |
191 | getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName)); |
192 | } |
193 | |
194 | /// Register a dialect interface with this dialect instance. |
195 | void addInterface(std::unique_ptr<DialectInterface> interface); |
196 | |
197 | /// Register a set of dialect interfaces with this dialect instance. |
198 | template <typename... Args> |
199 | void addInterfaces() { |
200 | (addInterface(std::make_unique<Args>(this)), ...); |
201 | } |
202 | template <typename InterfaceT, typename... Args> |
203 | InterfaceT &addInterface(Args &&...args) { |
204 | InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...); |
205 | addInterface(interface: std::unique_ptr<DialectInterface>(interface)); |
206 | return *interface; |
207 | } |
208 | |
209 | /// Declare that the given interface will be implemented, but has a delayed |
210 | /// registration. The promised interface type can be an interface of any type |
211 | /// not just a dialect interface, i.e. it may also be an |
212 | /// AttributeInterface/OpInterface/TypeInterface/etc. |
213 | template <typename InterfaceT, typename ConcreteT> |
214 | void declarePromisedInterface() { |
215 | unresolvedPromisedInterfaces.insert( |
216 | {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()}); |
217 | } |
218 | |
219 | // Declare the same interface for multiple types. |
220 | // Example: |
221 | // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>() |
222 | template <typename InterfaceT, typename... ConcreteT> |
223 | void declarePromisedInterfaces() { |
224 | (declarePromisedInterface<InterfaceT, ConcreteT>(), ...); |
225 | } |
226 | |
227 | /// Checks if the given interface, which is attempting to be used, is a |
228 | /// promised interface of this dialect that has yet to be implemented. If so, |
229 | /// emits a fatal error. `interfaceName` is an optional string that contains a |
230 | /// more user readable name for the interface (such as the class name). |
231 | void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, |
232 | TypeID interfaceID, |
233 | StringRef interfaceName = "" ) { |
234 | if (unresolvedPromisedInterfaces.count( |
235 | V: {interfaceRequestorID, interfaceID})) { |
236 | llvm::report_fatal_error( |
237 | reason: "checking for an interface (`" + interfaceName + |
238 | "`) that was promised by dialect '" + getNamespace() + |
239 | "' but never implemented. This is generally an indication " |
240 | "that the dialect extension implementing the interface was never " |
241 | "registered." ); |
242 | } |
243 | } |
244 | |
245 | /// Checks if the given interface, which is attempting to be attached to a |
246 | /// construct owned by this dialect, is a promised interface of this dialect |
247 | /// that has yet to be implemented. If so, it resolves the interface promise. |
248 | void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, |
249 | TypeID interfaceID) { |
250 | unresolvedPromisedInterfaces.erase(V: {interfaceRequestorID, interfaceID}); |
251 | } |
252 | |
253 | /// Checks if a promise has been made for the interface/requestor pair. |
254 | bool hasPromisedInterface(TypeID interfaceRequestorID, |
255 | TypeID interfaceID) const { |
256 | return unresolvedPromisedInterfaces.count( |
257 | V: {interfaceRequestorID, interfaceID}); |
258 | } |
259 | |
260 | /// Checks if a promise has been made for the interface/requestor pair. |
261 | template <typename ConcreteT, typename InterfaceT> |
262 | bool hasPromisedInterface() const { |
263 | return hasPromisedInterface(TypeID::get<ConcreteT>(), |
264 | InterfaceT::getInterfaceID()); |
265 | } |
266 | |
267 | protected: |
268 | /// The constructor takes a unique namespace for this dialect as well as the |
269 | /// context to bind to. |
270 | /// Note: The namespace must not contain '.' characters. |
271 | /// Note: All operations belonging to this dialect must have names starting |
272 | /// with the namespace followed by '.'. |
273 | /// Example: |
274 | /// - "tf" for the TensorFlow ops like "tf.add". |
275 | Dialect(StringRef name, MLIRContext *context, TypeID id); |
276 | |
277 | /// This method is used by derived classes to add their operations to the set. |
278 | /// |
279 | template <typename... Args> |
280 | void addOperations() { |
281 | // This initializer_list argument pack expansion is essentially equal to |
282 | // using a fold expression with a comma operator. Clang however, refuses |
283 | // to compile a fold expression with a depth of more than 256 by default. |
284 | // There seem to be no such limitations for initializer_list. |
285 | (void)std::initializer_list<int>{ |
286 | 0, (RegisteredOperationName::insert<Args>(*this), 0)...}; |
287 | } |
288 | |
289 | /// Register a set of type classes with this dialect. |
290 | template <typename... Args> |
291 | void addTypes() { |
292 | // This initializer_list argument pack expansion is essentially equal to |
293 | // using a fold expression with a comma operator. Clang however, refuses |
294 | // to compile a fold expression with a depth of more than 256 by default. |
295 | // There seem to be no such limitations for initializer_list. |
296 | (void)std::initializer_list<int>{0, (addType<Args>(), 0)...}; |
297 | } |
298 | |
299 | /// Register a type instance with this dialect. |
300 | /// The use of this method is in general discouraged in favor of |
301 | /// 'addTypes<CustomType>()'. |
302 | void addType(TypeID typeID, AbstractType &&typeInfo); |
303 | |
304 | /// Register a set of attribute classes with this dialect. |
305 | template <typename... Args> |
306 | void addAttributes() { |
307 | // This initializer_list argument pack expansion is essentially equal to |
308 | // using a fold expression with a comma operator. Clang however, refuses |
309 | // to compile a fold expression with a depth of more than 256 by default. |
310 | // There seem to be no such limitations for initializer_list. |
311 | (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...}; |
312 | } |
313 | |
314 | /// Register an attribute instance with this dialect. |
315 | /// The use of this method is in general discouraged in favor of |
316 | /// 'addAttributes<CustomAttr>()'. |
317 | void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); |
318 | |
319 | /// Enable support for unregistered operations. |
320 | void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } |
321 | |
322 | /// Enable support for unregistered types. |
323 | void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } |
324 | |
325 | private: |
326 | Dialect(const Dialect &) = delete; |
327 | void operator=(Dialect &) = delete; |
328 | |
329 | /// Register an attribute instance with this dialect. |
330 | template <typename T> |
331 | void addAttribute() { |
332 | // Add this attribute to the dialect and register it with the uniquer. |
333 | addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this)); |
334 | detail::AttributeUniquer::registerAttribute<T>(context); |
335 | } |
336 | |
337 | /// Register a type instance with this dialect. |
338 | template <typename T> |
339 | void addType() { |
340 | // Add this type to the dialect and register it with the uniquer. |
341 | addType(T::getTypeID(), AbstractType::get<T>(*this)); |
342 | detail::TypeUniquer::registerType<T>(context); |
343 | } |
344 | |
345 | /// The namespace of this dialect. |
346 | StringRef name; |
347 | |
348 | /// The unique identifier of the derived Op class, this is used in the context |
349 | /// to allow registering multiple times the same dialect. |
350 | TypeID dialectID; |
351 | |
352 | /// This is the context that owns this Dialect object. |
353 | MLIRContext *context; |
354 | |
355 | /// Flag that specifies whether this dialect supports unregistered operations, |
356 | /// i.e. operations prefixed with the dialect namespace but not registered |
357 | /// with addOperation. |
358 | bool unknownOpsAllowed = false; |
359 | |
360 | /// Flag that specifies whether this dialect allows unregistered types, i.e. |
361 | /// types prefixed with the dialect namespace but not registered with addType. |
362 | /// These types are represented with OpaqueType. |
363 | bool unknownTypesAllowed = false; |
364 | |
365 | /// A collection of registered dialect interfaces. |
366 | DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; |
367 | |
368 | /// A set of interfaces that the dialect (or its constructs, i.e. |
369 | /// Attributes/Operations/Types/etc.) has promised to implement, but has yet |
370 | /// to provide an implementation for. |
371 | DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces; |
372 | |
373 | friend class DialectRegistry; |
374 | friend void registerDialect(); |
375 | friend class MLIRContext; |
376 | }; |
377 | |
378 | } // namespace mlir |
379 | |
380 | namespace llvm { |
381 | /// Provide isa functionality for Dialects. |
382 | template <typename T> |
383 | struct isa_impl<T, ::mlir::Dialect, |
384 | std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> { |
385 | static inline bool doit(const ::mlir::Dialect &dialect) { |
386 | return mlir::TypeID::get<T>() == dialect.getTypeID(); |
387 | } |
388 | }; |
389 | template <typename T> |
390 | struct isa_impl< |
391 | T, ::mlir::Dialect, |
392 | std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> { |
393 | static inline bool doit(const ::mlir::Dialect &dialect) { |
394 | return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>(); |
395 | } |
396 | }; |
397 | template <typename T> |
398 | struct cast_retty_impl<T, ::mlir::Dialect *> { |
399 | using ret_type = T *; |
400 | }; |
401 | template <typename T> |
402 | struct cast_retty_impl<T, ::mlir::Dialect> { |
403 | using ret_type = T &; |
404 | }; |
405 | |
406 | template <typename T> |
407 | struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> { |
408 | template <typename To> |
409 | static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &> |
410 | doitImpl(::mlir::Dialect &dialect) { |
411 | return static_cast<To &>(dialect); |
412 | } |
413 | template <typename To> |
414 | static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value, |
415 | To &> |
416 | doitImpl(::mlir::Dialect &dialect) { |
417 | return *dialect.getRegisteredInterface<To>(); |
418 | } |
419 | |
420 | static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); } |
421 | }; |
422 | template <class T> |
423 | struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> { |
424 | static auto doit(::mlir::Dialect *dialect) { |
425 | return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit( |
426 | *dialect); |
427 | } |
428 | }; |
429 | |
430 | } // namespace llvm |
431 | |
432 | #endif |
433 | |