1 | //===- OperationSupport.h ---------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines a number of support types that Operation and related |
10 | // classes build on top of. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_IR_OPERATIONSUPPORT_H |
15 | #define MLIR_IR_OPERATIONSUPPORT_H |
16 | |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/BlockSupport.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Diagnostics.h" |
21 | #include "mlir/IR/DialectRegistry.h" |
22 | #include "mlir/IR/Location.h" |
23 | #include "mlir/IR/TypeRange.h" |
24 | #include "mlir/IR/Types.h" |
25 | #include "mlir/IR/Value.h" |
26 | #include "mlir/Support/InterfaceSupport.h" |
27 | #include "llvm/ADT/BitmaskEnum.h" |
28 | #include "llvm/ADT/PointerUnion.h" |
29 | #include "llvm/ADT/STLFunctionalExtras.h" |
30 | #include "llvm/Support/ErrorHandling.h" |
31 | #include "llvm/Support/PointerLikeTypeTraits.h" |
32 | #include "llvm/Support/TrailingObjects.h" |
33 | #include <memory> |
34 | #include <optional> |
35 | |
// Forward declarations of LLVM classes.
namespace llvm {
class BitVector;
} // namespace llvm
39 | |
40 | namespace mlir { |
// Forward declarations of MLIR types, so that declarations in this header can
// name them without requiring their full definitions.
class Dialect;
class DictionaryAttr;
class ElementsAttr;
struct EmptyProperties;
class MutableOperandRangeRange;
class NamedAttrList;
class Operation;
struct OperationState;
class OpAsmParser;
class OpAsmPrinter;
class OperandRange;
class OperandRangeRange;
class OpFoldResult;
class ParseResult;
class Pattern;
class Region;
class ResultRange;
class RewritePattern;
class RewritePatternSet;
class Type;
class Value;
class ValueRange;
template <typename ValueRangeT>
class ValueTypeRange;
65 | |
66 | //===----------------------------------------------------------------------===// |
67 | // OpaqueProperties |
68 | //===----------------------------------------------------------------------===// |
69 | |
/// Type-erased handle to an operation's properties storage.
///
/// It only wraps an untyped `void *`, which lets generic APIs pass op
/// properties around without knowing their concrete C++ type. The typed
/// pointer is recovered with `as<T *>()`. The wrapper does not own the
/// pointee.
class OpaqueProperties {
public:
  /// Wrap `storage`; a null pointer means "no properties".
  OpaqueProperties(void *storage) : properties(storage) {}

  /// True when a (non-null) properties pointer is held.
  operator bool() const { return static_cast<bool>(properties); }

  /// Reinterpret the wrapped pointer as `Dest`, which is expected to be a
  /// pointer type (e.g. `as<MyProps *>()`).
  template <typename Dest>
  Dest as() const {
    void *raw = properties;
    return static_cast<Dest>(raw);
  }

private:
  /// The wrapped, non-owned properties pointer (may be null).
  void *properties;
};
84 | |
85 | //===----------------------------------------------------------------------===// |
86 | // OperationName |
87 | //===----------------------------------------------------------------------===// |
88 | |
89 | class OperationName { |
90 | public: |
91 | using FoldHookFn = llvm::unique_function<LogicalResult( |
92 | Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>; |
93 | using HasTraitFn = llvm::unique_function<bool(TypeID) const>; |
94 | using ParseAssemblyFn = |
95 | llvm::unique_function<ParseResult(OpAsmParser &, OperationState &)>; |
96 | // Note: RegisteredOperationName is passed as reference here as the derived |
97 | // class is defined below. |
98 | using PopulateDefaultAttrsFn = |
99 | llvm::unique_function<void(const OperationName &, NamedAttrList &) const>; |
100 | using PrintAssemblyFn = |
101 | llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>; |
102 | using VerifyInvariantsFn = |
103 | llvm::unique_function<LogicalResult(Operation *) const>; |
104 | using VerifyRegionInvariantsFn = |
105 | llvm::unique_function<LogicalResult(Operation *) const>; |
106 | |
107 | /// This class represents a type erased version of an operation. It contains |
108 | /// all of the components necessary for opaquely interacting with an |
109 | /// operation. If the operation is not registered, some of these components |
110 | /// may not be populated. |
111 | struct InterfaceConcept { |
112 | virtual ~InterfaceConcept() = default; |
113 | virtual LogicalResult foldHook(Operation *, ArrayRef<Attribute>, |
114 | SmallVectorImpl<OpFoldResult> &) = 0; |
115 | virtual void getCanonicalizationPatterns(RewritePatternSet &, |
116 | MLIRContext *) = 0; |
117 | virtual bool hasTrait(TypeID) = 0; |
118 | virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0; |
119 | virtual void populateDefaultAttrs(const OperationName &, |
120 | NamedAttrList &) = 0; |
121 | virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0; |
122 | virtual LogicalResult verifyInvariants(Operation *) = 0; |
123 | virtual LogicalResult verifyRegionInvariants(Operation *) = 0; |
124 | /// Implementation for properties |
125 | virtual std::optional<Attribute> getInherentAttr(Operation *, |
126 | StringRef name) = 0; |
127 | virtual void setInherentAttr(Operation *op, StringAttr name, |
128 | Attribute value) = 0; |
129 | virtual void populateInherentAttrs(Operation *op, NamedAttrList &attrs) = 0; |
130 | virtual LogicalResult |
131 | verifyInherentAttrs(OperationName opName, NamedAttrList &attributes, |
132 | function_ref<InFlightDiagnostic()> emitError) = 0; |
133 | virtual int getOpPropertyByteSize() = 0; |
134 | virtual void initProperties(OperationName opName, OpaqueProperties storage, |
135 | OpaqueProperties init) = 0; |
136 | virtual void deleteProperties(OpaqueProperties) = 0; |
137 | virtual void populateDefaultProperties(OperationName opName, |
138 | OpaqueProperties properties) = 0; |
139 | virtual LogicalResult |
140 | setPropertiesFromAttr(OperationName, OpaqueProperties, Attribute, |
141 | function_ref<InFlightDiagnostic()> emitError) = 0; |
142 | virtual Attribute getPropertiesAsAttr(Operation *) = 0; |
143 | virtual void copyProperties(OpaqueProperties, OpaqueProperties) = 0; |
144 | virtual bool compareProperties(OpaqueProperties, OpaqueProperties) = 0; |
145 | virtual llvm::hash_code hashProperties(OpaqueProperties) = 0; |
146 | }; |
147 | |
148 | public: |
149 | class Impl : public InterfaceConcept { |
150 | public: |
151 | Impl(StringRef, Dialect *dialect, TypeID typeID, |
152 | detail::InterfaceMap interfaceMap); |
153 | Impl(StringAttr name, Dialect *dialect, TypeID typeID, |
154 | detail::InterfaceMap interfaceMap) |
155 | : name(name), typeID(typeID), dialect(dialect), |
156 | interfaceMap(std::move(interfaceMap)) {} |
157 | |
158 | /// Returns true if this is a registered operation. |
159 | bool isRegistered() const { return typeID != TypeID::get<void>(); } |
160 | detail::InterfaceMap &getInterfaceMap() { return interfaceMap; } |
161 | Dialect *getDialect() const { return dialect; } |
162 | StringAttr getName() const { return name; } |
163 | TypeID getTypeID() const { return typeID; } |
164 | ArrayRef<StringAttr> getAttributeNames() const { return attributeNames; } |
165 | |
166 | protected: |
167 | //===------------------------------------------------------------------===// |
168 | // Registered Operation Info |
169 | |
170 | /// The name of the operation. |
171 | StringAttr name; |
172 | |
173 | /// The unique identifier of the derived Op class. |
174 | TypeID typeID; |
175 | |
176 | /// The following fields are only populated when the operation is |
177 | /// registered. |
178 | |
179 | /// This is the dialect that this operation belongs to. |
180 | Dialect *dialect; |
181 | |
182 | /// A map of interfaces that were registered to this operation. |
183 | detail::InterfaceMap interfaceMap; |
184 | |
185 | /// A list of attribute names registered to this operation in StringAttr |
186 | /// form. This allows for operation classes to use StringAttr for attribute |
187 | /// lookup/creation/etc., as opposed to raw strings. |
188 | ArrayRef<StringAttr> attributeNames; |
189 | |
190 | friend class RegisteredOperationName; |
191 | }; |
192 | |
193 | protected: |
194 | /// Default implementation for unregistered operations. |
195 | struct UnregisteredOpModel : public Impl { |
196 | using Impl::Impl; |
197 | LogicalResult foldHook(Operation *, ArrayRef<Attribute>, |
198 | SmallVectorImpl<OpFoldResult> &) final; |
199 | void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final; |
200 | bool hasTrait(TypeID) final; |
201 | OperationName::ParseAssemblyFn getParseAssemblyFn() final; |
202 | void populateDefaultAttrs(const OperationName &, NamedAttrList &) final; |
203 | void printAssembly(Operation *, OpAsmPrinter &, StringRef) final; |
204 | LogicalResult verifyInvariants(Operation *) final; |
205 | LogicalResult verifyRegionInvariants(Operation *) final; |
206 | /// Implementation for properties |
207 | std::optional<Attribute> getInherentAttr(Operation *op, |
208 | StringRef name) final; |
209 | void setInherentAttr(Operation *op, StringAttr name, Attribute value) final; |
210 | void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final; |
211 | LogicalResult |
212 | verifyInherentAttrs(OperationName opName, NamedAttrList &attributes, |
213 | function_ref<InFlightDiagnostic()> emitError) final; |
214 | int getOpPropertyByteSize() final; |
215 | void initProperties(OperationName opName, OpaqueProperties storage, |
216 | OpaqueProperties init) final; |
217 | void deleteProperties(OpaqueProperties) final; |
218 | void populateDefaultProperties(OperationName opName, |
219 | OpaqueProperties properties) final; |
220 | LogicalResult |
221 | setPropertiesFromAttr(OperationName, OpaqueProperties, Attribute, |
222 | function_ref<InFlightDiagnostic()> emitError) final; |
223 | Attribute getPropertiesAsAttr(Operation *) final; |
224 | void copyProperties(OpaqueProperties, OpaqueProperties) final; |
225 | bool compareProperties(OpaqueProperties, OpaqueProperties) final; |
226 | llvm::hash_code hashProperties(OpaqueProperties) final; |
227 | }; |
228 | |
229 | public: |
230 | OperationName(StringRef name, MLIRContext *context); |
231 | |
232 | /// Return if this operation is registered. |
233 | bool isRegistered() const { return getImpl()->isRegistered(); } |
234 | |
235 | /// Return the unique identifier of the derived Op class, or null if not |
236 | /// registered. |
237 | TypeID getTypeID() const { return getImpl()->getTypeID(); } |
238 | |
239 | /// If this operation is registered, returns the registered information, |
240 | /// std::nullopt otherwise. |
241 | std::optional<RegisteredOperationName> getRegisteredInfo() const; |
242 | |
243 | /// This hook implements a generalized folder for this operation. Operations |
244 | /// can implement this to provide simplifications rules that are applied by |
245 | /// the Builder::createOrFold API and the canonicalization pass. |
246 | /// |
247 | /// This is an intentionally limited interface - implementations of this |
248 | /// hook can only perform the following changes to the operation: |
249 | /// |
250 | /// 1. They can leave the operation alone and without changing the IR, and |
251 | /// return failure. |
252 | /// 2. They can mutate the operation in place, without changing anything |
253 | /// else |
254 | /// in the IR. In this case, return success. |
255 | /// 3. They can return a list of existing values that can be used instead |
256 | /// of |
257 | /// the operation. In this case, fill in the results list and return |
258 | /// success. The caller will remove the operation and use those results |
259 | /// instead. |
260 | /// |
261 | /// This allows expression of some simple in-place canonicalizations (e.g. |
262 | /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as |
263 | /// generalized constant folding. |
264 | LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands, |
265 | SmallVectorImpl<OpFoldResult> &results) const { |
266 | return getImpl()->foldHook(op, operands, results); |
267 | } |
268 | |
269 | /// This hook returns any canonicalization pattern rewrites that the |
270 | /// operation supports, for use by the canonicalization pass. |
271 | void getCanonicalizationPatterns(RewritePatternSet &results, |
272 | MLIRContext *context) const { |
273 | return getImpl()->getCanonicalizationPatterns(results, context); |
274 | } |
275 | |
276 | /// Returns true if the operation was registered with a particular trait, e.g. |
277 | /// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation |
278 | /// is unregistered. |
279 | template <template <typename T> class Trait> |
280 | bool hasTrait() const { |
281 | return hasTrait(TypeID::get<Trait>()); |
282 | } |
283 | bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); } |
284 | |
285 | /// Returns true if the operation *might* have the provided trait. This |
286 | /// means that either the operation is unregistered, or it was registered with |
287 | /// the provide trait. |
288 | template <template <typename T> class Trait> |
289 | bool mightHaveTrait() const { |
290 | return mightHaveTrait(TypeID::get<Trait>()); |
291 | } |
292 | bool mightHaveTrait(TypeID traitID) const { |
293 | return !isRegistered() || getImpl()->hasTrait(traitID); |
294 | } |
295 | |
296 | /// Return the static hook for parsing this operation assembly. |
297 | ParseAssemblyFn getParseAssemblyFn() const { |
298 | return getImpl()->getParseAssemblyFn(); |
299 | } |
300 | |
301 | /// This hook implements the method to populate defaults attributes that are |
302 | /// unset. |
303 | void populateDefaultAttrs(NamedAttrList &attrs) const { |
304 | getImpl()->populateDefaultAttrs(*this, attrs); |
305 | } |
306 | |
307 | /// This hook implements the AsmPrinter for this operation. |
308 | void printAssembly(Operation *op, OpAsmPrinter &p, |
309 | StringRef defaultDialect) const { |
310 | return getImpl()->printAssembly(op, p, defaultDialect); |
311 | } |
312 | |
313 | /// These hooks implement the verifiers for this operation. It should emits |
314 | /// an error message and returns failure if a problem is detected, or |
315 | /// returns success if everything is ok. |
316 | LogicalResult verifyInvariants(Operation *op) const { |
317 | return getImpl()->verifyInvariants(op); |
318 | } |
319 | LogicalResult verifyRegionInvariants(Operation *op) const { |
320 | return getImpl()->verifyRegionInvariants(op); |
321 | } |
322 | |
323 | /// Return the list of cached attribute names registered to this operation. |
324 | /// The order of attributes cached here is unique to each type of operation, |
325 | /// and the interpretation of this attribute list should generally be driven |
326 | /// by the respective operation. In many cases, this caching removes the |
327 | /// need to use the raw string name of a known attribute. |
328 | /// |
329 | /// For example the ODS generator, with an op defining the following |
330 | /// attributes: |
331 | /// |
332 | /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2); |
333 | /// |
334 | /// ... may produce an order here of ["attr1", "attr2"]. This allows for the |
335 | /// ODS generator to directly access the cached name for a known attribute, |
336 | /// greatly simplifying the cost and complexity of attribute usage produced |
337 | /// by the generator. |
338 | /// |
339 | ArrayRef<StringAttr> getAttributeNames() const { |
340 | return getImpl()->getAttributeNames(); |
341 | } |
342 | |
343 | /// Returns an instance of the concept object for the given interface if it |
344 | /// was registered to this operation, null otherwise. This should not be used |
345 | /// directly. |
346 | template <typename T> |
347 | typename T::Concept *getInterface() const { |
348 | return getImpl()->getInterfaceMap().lookup<T>(); |
349 | } |
350 | |
351 | /// Attach the given models as implementations of the corresponding |
352 | /// interfaces for the concrete operation. |
353 | template <typename... Models> |
354 | void attachInterface() { |
355 | // Handle the case where the models resolve a promised interface. |
356 | (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( |
357 | dialect&: *getDialect(), interfaceRequestorID: getTypeID(), interfaceID: Models::Interface::getInterfaceID()), |
358 | ...); |
359 | |
360 | getImpl()->getInterfaceMap().insertModels<Models...>(); |
361 | } |
362 | |
363 | /// Returns true if `InterfaceT` has been promised by the dialect or |
364 | /// implemented. |
365 | template <typename InterfaceT> |
366 | bool hasPromiseOrImplementsInterface() const { |
367 | return dialect_extension_detail::hasPromisedInterface( |
368 | getDialect(), getTypeID(), InterfaceT::getInterfaceID()) || |
369 | hasInterface<InterfaceT>(); |
370 | } |
371 | |
372 | /// Returns true if this operation has the given interface registered to it. |
373 | template <typename T> |
374 | bool hasInterface() const { |
375 | return hasInterface(TypeID::get<T>()); |
376 | } |
377 | bool hasInterface(TypeID interfaceID) const { |
378 | return getImpl()->getInterfaceMap().contains(interfaceID); |
379 | } |
380 | |
381 | /// Returns true if the operation *might* have the provided interface. This |
382 | /// means that either the operation is unregistered, or it was registered with |
383 | /// the provide interface. |
384 | template <typename T> |
385 | bool mightHaveInterface() const { |
386 | return mightHaveInterface(TypeID::get<T>()); |
387 | } |
388 | bool mightHaveInterface(TypeID interfaceID) const { |
389 | return !isRegistered() || hasInterface(interfaceID); |
390 | } |
391 | |
392 | /// Lookup an inherent attribute by name, this method isn't recommended |
393 | /// and may be removed in the future. |
394 | std::optional<Attribute> getInherentAttr(Operation *op, |
395 | StringRef name) const { |
396 | return getImpl()->getInherentAttr(op, name); |
397 | } |
398 | |
399 | void setInherentAttr(Operation *op, StringAttr name, Attribute value) const { |
400 | return getImpl()->setInherentAttr(op, name: name, value); |
401 | } |
402 | |
403 | void populateInherentAttrs(Operation *op, NamedAttrList &attrs) const { |
404 | return getImpl()->populateInherentAttrs(op, attrs); |
405 | } |
406 | /// This method exists for backward compatibility purpose when using |
407 | /// properties to store inherent attributes, it enables validating the |
408 | /// attributes when parsed from the older generic syntax pre-Properties. |
409 | LogicalResult |
410 | verifyInherentAttrs(NamedAttrList &attributes, |
411 | function_ref<InFlightDiagnostic()> emitError) const { |
412 | return getImpl()->verifyInherentAttrs(*this, attributes, emitError); |
413 | } |
414 | /// This hooks return the number of bytes to allocate for the op properties. |
415 | int getOpPropertyByteSize() const { |
416 | return getImpl()->getOpPropertyByteSize(); |
417 | } |
418 | |
419 | /// This hooks destroy the op properties. |
420 | void destroyOpProperties(OpaqueProperties properties) const { |
421 | getImpl()->deleteProperties(properties); |
422 | } |
423 | |
424 | /// Initialize the op properties. |
425 | void initOpProperties(OpaqueProperties storage, OpaqueProperties init) const { |
426 | getImpl()->initProperties(*this, storage, init); |
427 | } |
428 | |
429 | /// Set the default values on the ODS attribute in the properties. |
430 | void populateDefaultProperties(OpaqueProperties properties) const { |
431 | getImpl()->populateDefaultProperties(*this, properties); |
432 | } |
433 | |
434 | /// Return the op properties converted to an Attribute. |
435 | Attribute getOpPropertiesAsAttribute(Operation *op) const { |
436 | return getImpl()->getPropertiesAsAttr(op); |
437 | } |
438 | |
439 | /// Define the op properties from the provided Attribute. |
440 | LogicalResult setOpPropertiesFromAttribute( |
441 | OperationName opName, OpaqueProperties properties, Attribute attr, |
442 | function_ref<InFlightDiagnostic()> emitError) const { |
443 | return getImpl()->setPropertiesFromAttr(opName, properties, attr, |
444 | emitError); |
445 | } |
446 | |
447 | void copyOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const { |
448 | return getImpl()->copyProperties(lhs, rhs); |
449 | } |
450 | |
451 | bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const { |
452 | return getImpl()->compareProperties(lhs, rhs); |
453 | } |
454 | |
455 | llvm::hash_code hashOpProperties(OpaqueProperties properties) const { |
456 | return getImpl()->hashProperties(properties); |
457 | } |
458 | |
459 | /// Return the dialect this operation is registered to if the dialect is |
460 | /// loaded in the context, or nullptr if the dialect isn't loaded. |
461 | Dialect *getDialect() const { |
462 | return isRegistered() ? getImpl()->getDialect() |
463 | : getImpl()->getName().getReferencedDialect(); |
464 | } |
465 | |
466 | /// Return the name of the dialect this operation is registered to. |
467 | StringRef getDialectNamespace() const; |
468 | |
469 | /// Return the operation name with dialect name stripped, if it has one. |
470 | StringRef stripDialect() const { return getStringRef().split(Separator: '.').second; } |
471 | |
472 | /// Return the context this operation is associated with. |
473 | MLIRContext *getContext() { return getIdentifier().getContext(); } |
474 | |
475 | /// Return the name of this operation. This always succeeds. |
476 | StringRef getStringRef() const { return getIdentifier(); } |
477 | |
478 | /// Return the name of this operation as a StringAttr. |
479 | StringAttr getIdentifier() const { return getImpl()->getName(); } |
480 | |
481 | void print(raw_ostream &os) const; |
482 | void dump() const; |
483 | |
484 | /// Represent the operation name as an opaque pointer. (Used to support |
485 | /// PointerLikeTypeTraits). |
486 | void *getAsOpaquePointer() const { return const_cast<Impl *>(impl); } |
487 | static OperationName getFromOpaquePointer(const void *pointer) { |
488 | return OperationName( |
489 | const_cast<Impl *>(reinterpret_cast<const Impl *>(pointer))); |
490 | } |
491 | |
492 | bool operator==(const OperationName &rhs) const { return impl == rhs.impl; } |
493 | bool operator!=(const OperationName &rhs) const { return !(*this == rhs); } |
494 | |
495 | protected: |
496 | OperationName(Impl *impl) : impl(impl) {} |
497 | Impl *getImpl() const { return impl; } |
498 | void setImpl(Impl *rhs) { impl = rhs; } |
499 | |
500 | private: |
501 | /// The internal implementation of the operation name. |
502 | Impl *impl = nullptr; |
503 | |
504 | /// Allow access to the Impl struct. |
505 | friend MLIRContextImpl; |
506 | friend DenseMapInfo<mlir::OperationName>; |
507 | friend DenseMapInfo<mlir::RegisteredOperationName>; |
508 | }; |
509 | |
510 | inline raw_ostream &operator<<(raw_ostream &os, OperationName info) { |
511 | info.print(os); |
512 | return os; |
513 | } |
514 | |
515 | // Make operation names hashable. |
516 | inline llvm::hash_code hash_value(OperationName arg) { |
517 | return llvm::hash_value(ptr: arg.getAsOpaquePointer()); |
518 | } |
519 | |
520 | //===----------------------------------------------------------------------===// |
521 | // RegisteredOperationName |
522 | //===----------------------------------------------------------------------===// |
523 | |
524 | /// This is a "type erased" representation of a registered operation. This |
525 | /// should only be used by things like the AsmPrinter and other things that need |
526 | /// to be parameterized by generic operation hooks. Most user code should use |
527 | /// the concrete operation types. |
528 | class RegisteredOperationName : public OperationName { |
529 | public: |
530 | /// Implementation of the InterfaceConcept for operation APIs that forwarded |
531 | /// to a concrete op implementation. |
532 | template <typename ConcreteOp> |
533 | struct Model : public Impl { |
534 | Model(Dialect *dialect) |
535 | : Impl(ConcreteOp::getOperationName(), dialect, |
536 | TypeID::get<ConcreteOp>(), ConcreteOp::getInterfaceMap()) {} |
537 | LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs, |
538 | SmallVectorImpl<OpFoldResult> &results) final { |
539 | return ConcreteOp::getFoldHookFn()(op, attrs, results); |
540 | } |
541 | void getCanonicalizationPatterns(RewritePatternSet &set, |
542 | MLIRContext *context) final { |
543 | ConcreteOp::getCanonicalizationPatterns(set, context); |
544 | } |
545 | bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); } |
546 | OperationName::ParseAssemblyFn getParseAssemblyFn() final { |
547 | return ConcreteOp::parse; |
548 | } |
549 | void populateDefaultAttrs(const OperationName &name, |
550 | NamedAttrList &attrs) final { |
551 | ConcreteOp::populateDefaultAttrs(name, attrs); |
552 | } |
553 | void printAssembly(Operation *op, OpAsmPrinter &printer, |
554 | StringRef name) final { |
555 | ConcreteOp::getPrintAssemblyFn()(op, printer, name); |
556 | } |
557 | LogicalResult verifyInvariants(Operation *op) final { |
558 | return ConcreteOp::getVerifyInvariantsFn()(op); |
559 | } |
560 | LogicalResult verifyRegionInvariants(Operation *op) final { |
561 | return ConcreteOp::getVerifyRegionInvariantsFn()(op); |
562 | } |
563 | |
564 | /// Implementation for "Properties" |
565 | |
566 | using Properties = std::remove_reference_t< |
567 | decltype(std::declval<ConcreteOp>().getProperties())>; |
568 | |
569 | std::optional<Attribute> getInherentAttr(Operation *op, |
570 | StringRef name) final { |
571 | if constexpr (hasProperties) { |
572 | auto concreteOp = cast<ConcreteOp>(op); |
573 | return ConcreteOp::getInherentAttr(concreteOp->getContext(), |
574 | concreteOp.getProperties(), name); |
575 | } |
576 | // If the op does not have support for properties, we dispatch back to the |
577 | // dictionnary of discardable attributes for now. |
578 | return cast<ConcreteOp>(op)->getDiscardableAttr(name); |
579 | } |
580 | void setInherentAttr(Operation *op, StringAttr name, |
581 | Attribute value) final { |
582 | if constexpr (hasProperties) { |
583 | auto concreteOp = cast<ConcreteOp>(op); |
584 | return ConcreteOp::setInherentAttr(concreteOp.getProperties(), name, |
585 | value); |
586 | } |
587 | // If the op does not have support for properties, we dispatch back to the |
588 | // dictionnary of discardable attributes for now. |
589 | return cast<ConcreteOp>(op)->setDiscardableAttr(name, value); |
590 | } |
591 | void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final { |
592 | if constexpr (hasProperties) { |
593 | auto concreteOp = cast<ConcreteOp>(op); |
594 | ConcreteOp::populateInherentAttrs(concreteOp->getContext(), |
595 | concreteOp.getProperties(), attrs); |
596 | } |
597 | } |
598 | LogicalResult |
599 | verifyInherentAttrs(OperationName opName, NamedAttrList &attributes, |
600 | function_ref<InFlightDiagnostic()> emitError) final { |
601 | if constexpr (hasProperties) |
602 | return ConcreteOp::verifyInherentAttrs(opName, attributes, emitError); |
603 | return success(); |
604 | } |
605 | // Detect if the concrete operation defined properties. |
606 | static constexpr bool hasProperties = !std::is_same_v< |
607 | typename ConcreteOp::template InferredProperties<ConcreteOp>, |
608 | EmptyProperties>; |
609 | |
610 | int getOpPropertyByteSize() final { |
611 | if constexpr (hasProperties) |
612 | return sizeof(Properties); |
613 | return 0; |
614 | } |
615 | void initProperties(OperationName opName, OpaqueProperties storage, |
616 | OpaqueProperties init) final { |
617 | using Properties = |
618 | typename ConcreteOp::template InferredProperties<ConcreteOp>; |
619 | if (init) |
620 | new (storage.as<Properties *>()) Properties(*init.as<Properties *>()); |
621 | else |
622 | new (storage.as<Properties *>()) Properties(); |
623 | if constexpr (hasProperties) |
624 | ConcreteOp::populateDefaultProperties(opName, |
625 | *storage.as<Properties *>()); |
626 | } |
627 | void deleteProperties(OpaqueProperties prop) final { |
628 | prop.as<Properties *>()->~Properties(); |
629 | } |
630 | void populateDefaultProperties(OperationName opName, |
631 | OpaqueProperties properties) final { |
632 | if constexpr (hasProperties) |
633 | ConcreteOp::populateDefaultProperties(opName, |
634 | *properties.as<Properties *>()); |
635 | } |
636 | |
637 | LogicalResult |
638 | setPropertiesFromAttr(OperationName opName, OpaqueProperties properties, |
639 | Attribute attr, |
640 | function_ref<InFlightDiagnostic()> emitError) final { |
641 | if constexpr (hasProperties) { |
642 | auto p = properties.as<Properties *>(); |
643 | return ConcreteOp::setPropertiesFromAttr(*p, attr, emitError); |
644 | } |
645 | emitError() << "this operation does not support properties" ; |
646 | return failure(); |
647 | } |
648 | Attribute getPropertiesAsAttr(Operation *op) final { |
649 | if constexpr (hasProperties) { |
650 | auto concreteOp = cast<ConcreteOp>(op); |
651 | return ConcreteOp::getPropertiesAsAttr(concreteOp->getContext(), |
652 | concreteOp.getProperties()); |
653 | } |
654 | return {}; |
655 | } |
656 | bool compareProperties(OpaqueProperties lhs, OpaqueProperties rhs) final { |
657 | if constexpr (hasProperties) { |
658 | return *lhs.as<Properties *>() == *rhs.as<Properties *>(); |
659 | } else { |
660 | return true; |
661 | } |
662 | } |
663 | void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final { |
664 | *lhs.as<Properties *>() = *rhs.as<Properties *>(); |
665 | } |
666 | llvm::hash_code hashProperties(OpaqueProperties prop) final { |
667 | if constexpr (hasProperties) |
668 | return ConcreteOp::computePropertiesHash(*prop.as<Properties *>()); |
669 | |
670 | return {}; |
671 | } |
672 | }; |
673 | |
674 | /// Lookup the registered operation information for the given operation. |
675 | /// Returns std::nullopt if the operation isn't registered. |
676 | static std::optional<RegisteredOperationName> lookup(StringRef name, |
677 | MLIRContext *ctx); |
678 | |
679 | /// Lookup the registered operation information for the given operation. |
680 | /// Returns std::nullopt if the operation isn't registered. |
681 | static std::optional<RegisteredOperationName> lookup(TypeID typeID, |
682 | MLIRContext *ctx); |
683 | |
684 | /// Register a new operation in a Dialect object. |
685 | /// This constructor is used by Dialect objects when they register the list |
686 | /// of operations they contain. |
687 | template <typename T> |
688 | static void insert(Dialect &dialect) { |
689 | insert(std::make_unique<Model<T>>(&dialect), T::getAttributeNames()); |
690 | } |
691 | /// The use of this method is in general discouraged in favor of |
692 | /// 'insert<CustomOp>(dialect)'. |
693 | static void insert(std::unique_ptr<OperationName::Impl> ownedImpl, |
694 | ArrayRef<StringRef> attrNames); |
695 | |
696 | /// Return the dialect this operation is registered to. |
697 | Dialect &getDialect() const { return *getImpl()->getDialect(); } |
698 | |
699 | /// Use the specified object to parse this ops custom assembly format. |
700 | ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const; |
701 | |
702 | /// Represent the operation name as an opaque pointer. (Used to support |
703 | /// PointerLikeTypeTraits). |
704 | static RegisteredOperationName getFromOpaquePointer(const void *pointer) { |
705 | return RegisteredOperationName( |
706 | const_cast<Impl *>(reinterpret_cast<const Impl *>(pointer))); |
707 | } |
708 | |
709 | private: |
710 | RegisteredOperationName(Impl *impl) : OperationName(impl) {} |
711 | |
712 | /// Allow access to the constructor. |
713 | friend OperationName; |
714 | }; |
715 | |
716 | inline std::optional<RegisteredOperationName> |
717 | OperationName::getRegisteredInfo() const { |
718 | return isRegistered() ? RegisteredOperationName(impl) |
719 | : std::optional<RegisteredOperationName>(); |
720 | } |
721 | |
722 | //===----------------------------------------------------------------------===// |
723 | // Attribute Dictionary-Like Interface |
724 | //===----------------------------------------------------------------------===// |
725 | |
726 | /// Attribute collections provide a dictionary-like interface. Define common |
727 | /// lookup functions. |
728 | namespace impl { |
729 | |
730 | /// Unsorted string search or identifier lookups are linear scans. |
731 | template <typename IteratorT, typename NameT> |
732 | std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last, |
733 | NameT name) { |
734 | for (auto it = first; it != last; ++it) |
735 | if (it->getName() == name) |
736 | return {it, true}; |
737 | return {last, false}; |
738 | } |
739 | |
740 | /// Using llvm::lower_bound requires an extra string comparison to check whether |
741 | /// the returned iterator points to the found element or whether it indicates |
742 | /// the lower bound. Skip this redundant comparison by checking if `compare == |
743 | /// 0` during the binary search. |
744 | template <typename IteratorT> |
745 | std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last, |
746 | StringRef name) { |
747 | ptrdiff_t length = std::distance(first, last); |
748 | |
749 | while (length > 0) { |
750 | ptrdiff_t half = length / 2; |
751 | IteratorT mid = first + half; |
752 | int compare = mid->getName().strref().compare(name); |
753 | if (compare < 0) { |
754 | first = mid + 1; |
755 | length = length - half - 1; |
756 | } else if (compare > 0) { |
757 | length = half; |
758 | } else { |
759 | return {mid, true}; |
760 | } |
761 | } |
762 | return {first, false}; |
763 | } |
764 | |
765 | /// StringAttr lookups on large attribute lists will switch to string binary |
766 | /// search. String binary searches become significantly faster than linear scans |
767 | /// with the identifier when the attribute list becomes very large. |
768 | template <typename IteratorT> |
769 | std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last, |
770 | StringAttr name) { |
771 | constexpr unsigned kSmallAttributeList = 16; |
772 | if (std::distance(first, last) > kSmallAttributeList) |
773 | return findAttrSorted(first, last, name.strref()); |
774 | return findAttrUnsorted(first, last, name); |
775 | } |
776 | |
777 | /// Get an attribute from a sorted range of named attributes. Returns null if |
778 | /// the attribute was not found. |
779 | template <typename IteratorT, typename NameT> |
780 | Attribute getAttrFromSortedRange(IteratorT first, IteratorT last, NameT name) { |
781 | std::pair<IteratorT, bool> result = findAttrSorted(first, last, name); |
782 | return result.second ? result.first->getValue() : Attribute(); |
783 | } |
784 | |
785 | /// Get an attribute from a sorted range of named attributes. Returns |
786 | /// std::nullopt if the attribute was not found. |
787 | template <typename IteratorT, typename NameT> |
788 | std::optional<NamedAttribute> |
789 | getNamedAttrFromSortedRange(IteratorT first, IteratorT last, NameT name) { |
790 | std::pair<IteratorT, bool> result = findAttrSorted(first, last, name); |
791 | return result.second ? *result.first : std::optional<NamedAttribute>(); |
792 | } |
793 | |
794 | } // namespace impl |
795 | |
796 | //===----------------------------------------------------------------------===// |
797 | // NamedAttrList |
798 | //===----------------------------------------------------------------------===// |
799 | |
800 | /// NamedAttrList is array of NamedAttributes that tracks whether it is sorted |
801 | /// and does some basic work to remain sorted. |
802 | class NamedAttrList { |
803 | public: |
804 | using iterator = SmallVectorImpl<NamedAttribute>::iterator; |
805 | using const_iterator = SmallVectorImpl<NamedAttribute>::const_iterator; |
806 | using reference = NamedAttribute &; |
807 | using const_reference = const NamedAttribute &; |
808 | using size_type = size_t; |
809 | |
810 | NamedAttrList() : dictionarySorted({}, true) {} |
811 | NamedAttrList(std::nullopt_t none) : NamedAttrList() {} |
812 | NamedAttrList(ArrayRef<NamedAttribute> attributes); |
813 | NamedAttrList(DictionaryAttr attributes); |
814 | NamedAttrList(const_iterator inStart, const_iterator inEnd); |
815 | |
816 | template <typename Container> |
817 | NamedAttrList(const Container &vec) |
818 | : NamedAttrList(ArrayRef<NamedAttribute>(vec)) {} |
819 | |
820 | bool operator!=(const NamedAttrList &other) const { |
821 | return !(*this == other); |
822 | } |
823 | bool operator==(const NamedAttrList &other) const { |
824 | return attrs == other.attrs; |
825 | } |
826 | |
827 | /// Add an attribute with the specified name. |
828 | void append(StringRef name, Attribute attr); |
829 | |
830 | /// Add an attribute with the specified name. |
831 | void append(StringAttr name, Attribute attr) { |
832 | append(attr: NamedAttribute(name, attr)); |
833 | } |
834 | |
835 | /// Append the given named attribute. |
836 | void append(NamedAttribute attr) { push_back(newAttribute: attr); } |
837 | |
838 | /// Add an array of named attributes. |
839 | template <typename RangeT> |
840 | void append(RangeT &&newAttributes) { |
841 | append(std::begin(newAttributes), std::end(newAttributes)); |
842 | } |
843 | |
844 | /// Add a range of named attributes. |
845 | template <typename IteratorT, |
846 | typename = std::enable_if_t<std::is_convertible< |
847 | typename std::iterator_traits<IteratorT>::iterator_category, |
848 | std::input_iterator_tag>::value>> |
849 | void append(IteratorT inStart, IteratorT inEnd) { |
850 | // TODO: expand to handle case where values appended are in order & after |
851 | // end of current list. |
852 | dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: false); |
853 | attrs.append(inStart, inEnd); |
854 | } |
855 | |
856 | /// Replaces the attributes with new list of attributes. |
857 | void assign(const_iterator inStart, const_iterator inEnd); |
858 | |
859 | /// Replaces the attributes with new list of attributes. |
860 | void assign(ArrayRef<NamedAttribute> range) { |
861 | assign(inStart: range.begin(), inEnd: range.end()); |
862 | } |
863 | |
864 | void clear() { |
865 | attrs.clear(); |
866 | dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: false); |
867 | } |
868 | |
869 | bool empty() const { return attrs.empty(); } |
870 | |
871 | void reserve(size_type N) { attrs.reserve(N); } |
872 | |
873 | /// Add an attribute with the specified name. |
874 | void push_back(NamedAttribute newAttribute); |
875 | |
876 | /// Pop last element from list. |
877 | void pop_back() { attrs.pop_back(); } |
878 | |
879 | /// Returns an entry with a duplicate name the list, if it exists, else |
880 | /// returns std::nullopt. |
881 | std::optional<NamedAttribute> findDuplicate() const; |
882 | |
883 | /// Return a dictionary attribute for the underlying dictionary. This will |
884 | /// return an empty dictionary attribute if empty rather than null. |
885 | DictionaryAttr getDictionary(MLIRContext *context) const; |
886 | |
887 | /// Return all of the attributes on this operation. |
888 | ArrayRef<NamedAttribute> getAttrs() const; |
889 | |
890 | /// Return the specified attribute if present, null otherwise. |
891 | Attribute get(StringAttr name) const; |
892 | Attribute get(StringRef name) const; |
893 | |
894 | /// Return the specified named attribute if present, std::nullopt otherwise. |
895 | std::optional<NamedAttribute> getNamed(StringRef name) const; |
896 | std::optional<NamedAttribute> getNamed(StringAttr name) const; |
897 | |
898 | /// If the an attribute exists with the specified name, change it to the new |
899 | /// value. Otherwise, add a new attribute with the specified name/value. |
900 | /// Returns the previous attribute value of `name`, or null if no |
901 | /// attribute previously existed with `name`. |
902 | Attribute set(StringAttr name, Attribute value); |
903 | Attribute set(StringRef name, Attribute value); |
904 | |
905 | /// Erase the attribute with the given name from the list. Return the |
906 | /// attribute that was erased, or nullptr if there was no attribute with such |
907 | /// name. |
908 | Attribute erase(StringAttr name); |
909 | Attribute erase(StringRef name); |
910 | |
911 | iterator begin() { return attrs.begin(); } |
912 | iterator end() { return attrs.end(); } |
913 | const_iterator begin() const { return attrs.begin(); } |
914 | const_iterator end() const { return attrs.end(); } |
915 | |
916 | NamedAttrList &operator=(const SmallVectorImpl<NamedAttribute> &rhs); |
917 | operator ArrayRef<NamedAttribute>() const; |
918 | |
919 | private: |
920 | /// Return whether the attributes are sorted. |
921 | bool isSorted() const { return dictionarySorted.getInt(); } |
922 | |
923 | /// Erase the attribute at the given iterator position. |
924 | Attribute eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it); |
925 | |
926 | /// Lookup an attribute in the list. |
927 | template <typename AttrListT, typename NameT> |
928 | static auto findAttr(AttrListT &attrs, NameT name) { |
929 | return attrs.isSorted() |
930 | ? impl::findAttrSorted(attrs.begin(), attrs.end(), name) |
931 | : impl::findAttrUnsorted(attrs.begin(), attrs.end(), name); |
932 | } |
933 | |
934 | // These are marked mutable as they may be modified (e.g., sorted) |
935 | mutable SmallVector<NamedAttribute, 4> attrs; |
936 | // Pair with cached DictionaryAttr and status of whether attrs is sorted. |
937 | // Note: just because sorted does not mean a DictionaryAttr has been created |
938 | // but the case where there is a DictionaryAttr but attrs isn't sorted should |
939 | // not occur. |
940 | mutable llvm::PointerIntPair<Attribute, 1, bool> dictionarySorted; |
941 | }; |
942 | |
943 | //===----------------------------------------------------------------------===// |
944 | // OperationState |
945 | //===----------------------------------------------------------------------===// |
946 | |
947 | /// This represents an operation in an abstracted form, suitable for use with |
948 | /// the builder APIs. This object is a large and heavy weight object meant to |
949 | /// be used as a temporary object on the stack. It is generally unwise to put |
950 | /// this in a collection. |
951 | struct OperationState { |
952 | Location location; |
953 | OperationName name; |
954 | SmallVector<Value, 4> operands; |
955 | /// Types of the results of this operation. |
956 | SmallVector<Type, 4> types; |
957 | NamedAttrList attributes; |
958 | /// Successors of this operation and their respective operands. |
959 | SmallVector<Block *, 1> successors; |
960 | /// Regions that the op will hold. |
961 | SmallVector<std::unique_ptr<Region>, 1> regions; |
962 | |
963 | /// This Attribute is used to opaquely construct the properties of the |
964 | /// operation. If we're creating an unregistered operation, the Attribute is |
965 | /// used as-is as the Properties storage of the operation. Otherwise, the |
966 | /// operation properties are constructed opaquely using its |
967 | /// `setPropertiesFromAttr` hook. Note that `getOrAddProperties` is the |
968 | /// preferred method to construct properties from C++. |
969 | Attribute propertiesAttr; |
970 | |
971 | private: |
972 | OpaqueProperties properties = nullptr; |
973 | TypeID propertiesId; |
974 | llvm::function_ref<void(OpaqueProperties)> propertiesDeleter; |
975 | llvm::function_ref<void(OpaqueProperties, const OpaqueProperties)> |
976 | propertiesSetter; |
977 | friend class Operation; |
978 | |
979 | public: |
980 | OperationState(Location location, StringRef name); |
981 | OperationState(Location location, OperationName name); |
982 | |
983 | OperationState(Location location, OperationName name, ValueRange operands, |
984 | TypeRange types, ArrayRef<NamedAttribute> attributes = {}, |
985 | BlockRange successors = {}, |
986 | MutableArrayRef<std::unique_ptr<Region>> regions = {}); |
987 | OperationState(Location location, StringRef name, ValueRange operands, |
988 | TypeRange types, ArrayRef<NamedAttribute> attributes = {}, |
989 | BlockRange successors = {}, |
990 | MutableArrayRef<std::unique_ptr<Region>> regions = {}); |
991 | OperationState(OperationState &&other) = default; |
992 | OperationState(const OperationState &other) = default; |
993 | OperationState &operator=(OperationState &&other) = default; |
994 | OperationState &operator=(const OperationState &other) = default; |
995 | ~OperationState(); |
996 | |
997 | /// Get (or create) a properties of the provided type to be set on the |
998 | /// operation on creation. |
999 | template <typename T> |
1000 | T &getOrAddProperties() { |
1001 | if (!properties) { |
1002 | T *p = new T{}; |
1003 | properties = p; |
1004 | propertiesDeleter = [](OpaqueProperties prop) { |
1005 | delete prop.as<const T *>(); |
1006 | }; |
1007 | propertiesSetter = [](OpaqueProperties new_prop, |
1008 | const OpaqueProperties prop) { |
1009 | *new_prop.as<T *>() = *prop.as<const T *>(); |
1010 | }; |
1011 | propertiesId = TypeID::get<T>(); |
1012 | } |
1013 | assert(propertiesId == TypeID::get<T>() && "Inconsistent properties" ); |
1014 | return *properties.as<T *>(); |
1015 | } |
1016 | OpaqueProperties getRawProperties() { return properties; } |
1017 | |
1018 | // Set the properties defined on this OpState on the given operation, |
1019 | // optionally emit diagnostics on error through the provided diagnostic. |
1020 | LogicalResult |
1021 | setProperties(Operation *op, |
1022 | function_ref<InFlightDiagnostic()> emitError) const; |
1023 | |
1024 | void addOperands(ValueRange newOperands); |
1025 | |
1026 | void addTypes(ArrayRef<Type> newTypes) { |
1027 | types.append(in_start: newTypes.begin(), in_end: newTypes.end()); |
1028 | } |
1029 | template <typename RangeT> |
1030 | std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value> |
1031 | addTypes(RangeT &&newTypes) { |
1032 | types.append(newTypes.begin(), newTypes.end()); |
1033 | } |
1034 | |
1035 | /// Add an attribute with the specified name. |
1036 | void addAttribute(StringRef name, Attribute attr) { |
1037 | addAttribute(StringAttr::get(getContext(), name), attr); |
1038 | } |
1039 | |
1040 | /// Add an attribute with the specified name. |
1041 | void addAttribute(StringAttr name, Attribute attr) { |
1042 | attributes.append(name, attr); |
1043 | } |
1044 | |
1045 | /// Add an array of named attributes. |
1046 | void addAttributes(ArrayRef<NamedAttribute> newAttributes) { |
1047 | attributes.append(newAttributes); |
1048 | } |
1049 | |
1050 | void addSuccessors(Block *successor) { successors.push_back(Elt: successor); } |
1051 | void addSuccessors(BlockRange newSuccessors); |
1052 | |
1053 | /// Create a region that should be attached to the operation. These regions |
1054 | /// can be filled in immediately without waiting for Operation to be |
1055 | /// created. When it is, the region bodies will be transferred. |
1056 | Region *addRegion(); |
1057 | |
1058 | /// Take a region that should be attached to the Operation. The body of the |
1059 | /// region will be transferred when the Operation is constructed. If the |
1060 | /// region is null, a new empty region will be attached to the Operation. |
1061 | void addRegion(std::unique_ptr<Region> &®ion); |
1062 | |
1063 | /// Take ownership of a set of regions that should be attached to the |
1064 | /// Operation. |
1065 | void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions); |
1066 | |
1067 | /// Get the context held by this operation state. |
1068 | MLIRContext *getContext() const { return location->getContext(); } |
1069 | }; |
1070 | |
1071 | //===----------------------------------------------------------------------===// |
1072 | // OperandStorage |
1073 | //===----------------------------------------------------------------------===// |
1074 | |
1075 | namespace detail { |
1076 | /// This class handles the management of operation operands. Operands are |
1077 | /// stored either in a trailing array, or a dynamically resizable vector. |
1078 | class alignas(8) OperandStorage { |
1079 | public: |
1080 | OperandStorage(Operation *owner, OpOperand *trailingOperands, |
1081 | ValueRange values); |
1082 | ~OperandStorage(); |
1083 | |
1084 | /// Replace the operands contained in the storage with the ones provided in |
1085 | /// 'values'. |
1086 | void setOperands(Operation *owner, ValueRange values); |
1087 | |
1088 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' |
1089 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger |
1090 | /// than the range pointed to by 'start'+'length'. |
1091 | void setOperands(Operation *owner, unsigned start, unsigned length, |
1092 | ValueRange operands); |
1093 | |
1094 | /// Erase the operands held by the storage within the given range. |
1095 | void eraseOperands(unsigned start, unsigned length); |
1096 | |
1097 | /// Erase the operands held by the storage that have their corresponding bit |
1098 | /// set in `eraseIndices`. |
1099 | void eraseOperands(const BitVector &eraseIndices); |
1100 | |
1101 | /// Get the operation operands held by the storage. |
1102 | MutableArrayRef<OpOperand> getOperands() { return {operandStorage, size()}; } |
1103 | |
1104 | /// Return the number of operands held in the storage. |
1105 | unsigned size() { return numOperands; } |
1106 | |
1107 | private: |
1108 | /// Resize the storage to the given size. Returns the array containing the new |
1109 | /// operands. |
1110 | MutableArrayRef<OpOperand> resize(Operation *owner, unsigned newSize); |
1111 | |
1112 | /// The total capacity number of operands that the storage can hold. |
1113 | unsigned capacity : 31; |
1114 | /// A flag indicating if the operand storage was dynamically allocated, as |
1115 | /// opposed to inlined into the owning operation. |
1116 | unsigned isStorageDynamic : 1; |
1117 | /// The number of operands within the storage. |
1118 | unsigned numOperands; |
1119 | /// A pointer to the operand storage. |
1120 | OpOperand *operandStorage; |
1121 | }; |
1122 | } // namespace detail |
1123 | |
1124 | //===----------------------------------------------------------------------===// |
1125 | // OpPrintingFlags |
1126 | //===----------------------------------------------------------------------===// |
1127 | |
/// Set of flags used to control the behavior of the various IR print methods
/// (e.g. Operation::Print).
///
/// The setters return `OpPrintingFlags &` (presumably `*this`; they are
/// defined out of line), which allows calls to be chained. The `should*` /
/// `get*` methods are const queries of the private state below.
class OpPrintingFlags {
public:
  OpPrintingFlags();
  /// Accept `std::nullopt` as "default flags"; delegates to the default
  /// constructor.
  OpPrintingFlags(std::nullopt_t) : OpPrintingFlags() {}

  /// Enables the elision of large elements attributes by printing a lexically
  /// valid but otherwise meaningless form instead of the element data. The
  /// `largeElementLimit` is used to configure what is considered to be a
  /// "large" ElementsAttr by providing an upper limit to the number of
  /// elements.
  OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);

  /// Enables the printing of large element attributes with a hex string. The
  /// `largeElementLimit` is used to configure what is considered to be a
  /// "large" ElementsAttr by providing an upper limit to the number of
  /// elements. Use -1 to disable the hex printing.
  OpPrintingFlags &
  printLargeElementsAttrWithHex(int64_t largeElementLimit = 100);

  /// Enables the elision of large resources strings by omitting them from the
  /// `dialect_resources` section. The `largeResourceLimit` is used to configure
  /// what is considered to be a "large" resource by providing an upper limit to
  /// the string size.
  OpPrintingFlags &elideLargeResourceString(int64_t largeResourceLimit = 64);

  /// Enable or disable printing of debug information (based on `enable`). If
  /// 'prettyForm' is set to true, debug information is printed in a more
  /// readable 'pretty' form. Note: The IR generated with 'prettyForm' is not
  /// parsable.
  OpPrintingFlags &enableDebugInfo(bool enable = true, bool prettyForm = false);

  /// Always print operations in the generic form.
  OpPrintingFlags &printGenericOpForm(bool enable = true);

  /// Skip printing regions.
  OpPrintingFlags &skipRegions(bool skip = true);

  /// Do not verify the operation when using custom operation printers.
  OpPrintingFlags &assumeVerified();

  /// Use local scope when printing the operation. This allows for using the
  /// printer in a more localized and thread-safe setting, but may not
  /// necessarily be identical to what the IR will look like when dumping
  /// the full module.
  OpPrintingFlags &useLocalScope();

  /// Print users of values as comments.
  OpPrintingFlags &printValueUsers();

  /// Return if the given ElementsAttr should be elided.
  bool shouldElideElementsAttr(ElementsAttr attr) const;

  /// Return if the given ElementsAttr should be printed as hex string.
  bool shouldPrintElementsAttrWithHex(ElementsAttr attr) const;

  /// Return the size limit for printing large ElementsAttr.
  std::optional<int64_t> getLargeElementsAttrLimit() const;

  /// Return the size limit for printing large ElementsAttr as hex string.
  int64_t getLargeElementsAttrHexLimit() const;

  /// Return the size limit in chars for printing large resources.
  std::optional<uint64_t> getLargeResourceStringLimit() const;

  /// Return if debug information should be printed.
  bool shouldPrintDebugInfo() const;

  /// Return if debug information should be printed in the pretty form.
  bool shouldPrintDebugInfoPrettyForm() const;

  /// Return if operations should be printed in the generic form.
  bool shouldPrintGenericOpForm() const;

  /// Return if regions should be skipped.
  bool shouldSkipRegions() const;

  /// Return if operation verification should be skipped.
  bool shouldAssumeVerified() const;

  /// Return if the printer should use local scope when dumping the IR.
  bool shouldUseLocalScope() const;

  /// Return if the printer should print users of values.
  bool shouldPrintValueUsers() const;

private:
  /// Elide large elements attributes if the number of elements is larger than
  /// the upper limit. An empty optional means no limit was configured.
  std::optional<int64_t> elementsAttrElementLimit;

  /// Elide printing large resources based on size of string. An empty optional
  /// means no limit was configured.
  std::optional<uint64_t> resourceStringCharLimit;

  /// Print large element attributes with hex strings if the number of elements
  /// is larger than the upper limit.
  int64_t elementsAttrHexElementLimit = 100;

  // NOTE(review): the 1-bit flags below have no in-class initializers (C++17
  // does not allow them on bit-fields); confirm that the out-of-line default
  // constructor initializes every one of them.

  /// Print debug information.
  bool printDebugInfoFlag : 1;
  bool printDebugInfoPrettyFormFlag : 1;

  /// Print operations in the generic form.
  bool printGenericOpFormFlag : 1;

  /// Always skip Regions.
  bool skipRegionsFlag : 1;

  /// Skip operation verification.
  bool assumeVerifiedFlag : 1;

  /// Print operations with numberings local to the current operation.
  bool printLocalScope : 1;

  /// Print users of values.
  bool printValueUsersFlag : 1;
};
1246 | |
1247 | //===----------------------------------------------------------------------===// |
1248 | // Operation Equivalency |
1249 | //===----------------------------------------------------------------------===// |
1250 | |
1251 | /// This class provides utilities for computing if two operations are |
1252 | /// equivalent. |
1253 | struct OperationEquivalence { |
1254 | enum Flags { |
1255 | None = 0, |
1256 | |
1257 | // When provided, the location attached to the operation are ignored. |
1258 | IgnoreLocations = 1, |
1259 | |
1260 | LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations) |
1261 | }; |
1262 | |
1263 | /// Compute a hash for the given operation. |
1264 | /// The `hashOperands` and `hashResults` callbacks are expected to return a |
1265 | /// unique hash_code for a given Value. |
1266 | static llvm::hash_code computeHash( |
1267 | Operation *op, |
1268 | function_ref<llvm::hash_code(Value)> hashOperands = |
1269 | [](Value v) { return hash_value(arg: v); }, |
1270 | function_ref<llvm::hash_code(Value)> hashResults = |
1271 | [](Value v) { return hash_value(arg: v); }, |
1272 | Flags flags = Flags::None); |
1273 | |
1274 | /// Helper that can be used with `computeHash` above to ignore operation |
1275 | /// operands/result mapping. |
1276 | static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; } |
1277 | /// Helper that can be used with `computeHash` above to ignore operation |
1278 | /// operands/result mapping. |
1279 | static llvm::hash_code directHashValue(Value v) { return hash_value(arg: v); } |
1280 | |
1281 | /// Compare two operations (including their regions) and return if they are |
1282 | /// equivalent. |
1283 | /// |
1284 | /// * `checkEquivalent` is a callback to check if two values are equivalent. |
1285 | /// For two operations to be equivalent, their operands must be the same SSA |
1286 | /// value or this callback must return `success`. |
1287 | /// * `markEquivalent` is a callback to inform the caller that the analysis |
1288 | /// determined that two values are equivalent. |
1289 | /// * `checkCommutativeEquivalent` is an optional callback to check for |
1290 | /// equivalence across two ranges for a commutative operation. If not passed |
1291 | /// in, then equivalence is checked pairwise. This callback is needed to be |
1292 | /// able to query the optional equivalence classes. |
1293 | /// |
1294 | /// Note: Additional information regarding value equivalence can be injected |
1295 | /// into the analysis via `checkEquivalent`. Typically, callers may want |
1296 | /// values that were determined to be equivalent as per `markEquivalent` to be |
1297 | /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different |
1298 | /// equivalence relationship is desired. |
1299 | static bool |
1300 | isEquivalentTo(Operation *lhs, Operation *rhs, |
1301 | function_ref<LogicalResult(Value, Value)> checkEquivalent, |
1302 | function_ref<void(Value, Value)> markEquivalent = nullptr, |
1303 | Flags flags = Flags::None, |
1304 | function_ref<LogicalResult(ValueRange, ValueRange)> |
1305 | checkCommutativeEquivalent = nullptr); |
1306 | |
1307 | /// Compare two operations and return if they are equivalent. |
1308 | static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags); |
1309 | |
1310 | /// Compare two regions (including their subregions) and return if they are |
1311 | /// equivalent. See also `isEquivalentTo` for details. |
1312 | static bool isRegionEquivalentTo( |
1313 | Region *lhs, Region *rhs, |
1314 | function_ref<LogicalResult(Value, Value)> checkEquivalent, |
1315 | function_ref<void(Value, Value)> markEquivalent, |
1316 | OperationEquivalence::Flags flags, |
1317 | function_ref<LogicalResult(ValueRange, ValueRange)> |
1318 | checkCommutativeEquivalent = nullptr); |
1319 | |
1320 | /// Compare two regions and return if they are equivalent. |
1321 | static bool isRegionEquivalentTo(Region *lhs, Region *rhs, |
1322 | OperationEquivalence::Flags flags); |
1323 | |
1324 | /// Helper that can be used with `isEquivalentTo` above to consider ops |
1325 | /// equivalent even if their operands are not equivalent. |
1326 | static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) { |
1327 | return success(); |
1328 | } |
1329 | /// Helper that can be used with `isEquivalentTo` above to consider ops |
1330 | /// equivalent only if their operands are the exact same SSA values. |
1331 | static LogicalResult exactValueMatch(Value lhs, Value rhs) { |
1332 | return success(isSuccess: lhs == rhs); |
1333 | } |
1334 | }; |
1335 | |
1336 | /// Enable Bitmask enums for OperationEquivalence::Flags. |
1337 | LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); |
1338 | |
1339 | //===----------------------------------------------------------------------===// |
1340 | // OperationFingerPrint |
1341 | //===----------------------------------------------------------------------===// |
1342 | |
1343 | /// A unique fingerprint for a specific operation, and all of it's internal |
1344 | /// operations (if `includeNested` is set). |
1345 | class OperationFingerPrint { |
1346 | public: |
1347 | OperationFingerPrint(Operation *topOp, bool includeNested = true); |
1348 | OperationFingerPrint(const OperationFingerPrint &) = default; |
1349 | OperationFingerPrint &operator=(const OperationFingerPrint &) = default; |
1350 | |
1351 | bool operator==(const OperationFingerPrint &other) const { |
1352 | return hash == other.hash; |
1353 | } |
1354 | bool operator!=(const OperationFingerPrint &other) const { |
1355 | return !(*this == other); |
1356 | } |
1357 | |
1358 | private: |
1359 | std::array<uint8_t, 20> hash; |
1360 | }; |
1361 | |
1362 | } // namespace mlir |
1363 | |
1364 | namespace llvm { |
1365 | template <> |
1366 | struct DenseMapInfo<mlir::OperationName> { |
1367 | static mlir::OperationName getEmptyKey() { |
1368 | void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
1369 | return mlir::OperationName::getFromOpaquePointer(pointer); |
1370 | } |
1371 | static mlir::OperationName getTombstoneKey() { |
1372 | void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
1373 | return mlir::OperationName::getFromOpaquePointer(pointer); |
1374 | } |
1375 | static unsigned getHashValue(mlir::OperationName val) { |
1376 | return DenseMapInfo<void *>::getHashValue(PtrVal: val.getAsOpaquePointer()); |
1377 | } |
1378 | static bool isEqual(mlir::OperationName lhs, mlir::OperationName rhs) { |
1379 | return lhs == rhs; |
1380 | } |
1381 | }; |
1382 | template <> |
1383 | struct DenseMapInfo<mlir::RegisteredOperationName> |
1384 | : public DenseMapInfo<mlir::OperationName> { |
1385 | static mlir::RegisteredOperationName getEmptyKey() { |
1386 | void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
1387 | return mlir::RegisteredOperationName::getFromOpaquePointer(pointer); |
1388 | } |
1389 | static mlir::RegisteredOperationName getTombstoneKey() { |
1390 | void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
1391 | return mlir::RegisteredOperationName::getFromOpaquePointer(pointer); |
1392 | } |
1393 | }; |
1394 | |
1395 | template <> |
1396 | struct PointerLikeTypeTraits<mlir::OperationName> { |
1397 | static inline void *getAsVoidPointer(mlir::OperationName I) { |
1398 | return const_cast<void *>(I.getAsOpaquePointer()); |
1399 | } |
1400 | static inline mlir::OperationName getFromVoidPointer(void *P) { |
1401 | return mlir::OperationName::getFromOpaquePointer(pointer: P); |
1402 | } |
1403 | static constexpr int NumLowBitsAvailable = |
1404 | PointerLikeTypeTraits<void *>::NumLowBitsAvailable; |
1405 | }; |
1406 | template <> |
1407 | struct PointerLikeTypeTraits<mlir::RegisteredOperationName> |
1408 | : public PointerLikeTypeTraits<mlir::OperationName> { |
1409 | static inline mlir::RegisteredOperationName getFromVoidPointer(void *P) { |
1410 | return mlir::RegisteredOperationName::getFromOpaquePointer(pointer: P); |
1411 | } |
1412 | }; |
1413 | |
1414 | } // namespace llvm |
1415 | |
1416 | #endif |
1417 | |