| 1 | //===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_IR_PDLPATTERNMATCH_H |
| 10 | #define MLIR_IR_PDLPATTERNMATCH_H |
| 11 | |
| 12 | #include "mlir/Config/mlir-config.h" |
| 13 | |
| 14 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| 15 | #include "mlir/IR/Builders.h" |
| 16 | #include "mlir/IR/BuiltinOps.h" |
| 17 | |
| 18 | namespace mlir { |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | // PDL Patterns |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // PDLValue |
| 25 | |
| 26 | /// Storage type of byte-code interpreter values. These are passed to constraint |
| 27 | /// functions as arguments. |
| 28 | class PDLValue { |
| 29 | public: |
| 30 | /// The underlying kind of a PDL value. |
| 31 | enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange }; |
| 32 | |
| 33 | /// Construct a new PDL value. |
| 34 | PDLValue(const PDLValue &other) = default; |
| 35 | PDLValue(std::nullptr_t = nullptr) {} |
| 36 | PDLValue(Attribute value) |
| 37 | : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {} |
| 38 | PDLValue(Operation *value) : value(value), kind(Kind::Operation) {} |
| 39 | PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {} |
| 40 | PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {} |
| 41 | PDLValue(Value value) |
| 42 | : value(value.getAsOpaquePointer()), kind(Kind::Value) {} |
| 43 | PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {} |
| 44 | |
| 45 | /// Returns true if the type of the held value is `T`. |
| 46 | template <typename T> |
| 47 | bool isa() const { |
| 48 | assert(value && "isa<> used on a null value" ); |
| 49 | return kind == getKindOf<T>(); |
| 50 | } |
| 51 | |
| 52 | /// Attempt to dynamically cast this value to type `T`, returns null if this |
| 53 | /// value is not an instance of `T`. |
| 54 | template <typename T, |
| 55 | typename ResultT = std::conditional_t< |
| 56 | std::is_convertible<T, bool>::value, T, std::optional<T>>> |
| 57 | ResultT dyn_cast() const { |
| 58 | return isa<T>() ? castImpl<T>() : ResultT(); |
| 59 | } |
| 60 | |
| 61 | /// Cast this value to type `T`, asserts if this value is not an instance of |
| 62 | /// `T`. |
| 63 | template <typename T> |
| 64 | T cast() const { |
| 65 | assert(isa<T>() && "expected value to be of type `T`" ); |
| 66 | return castImpl<T>(); |
| 67 | } |
| 68 | |
| 69 | /// Get an opaque pointer to the value. |
| 70 | const void *getAsOpaquePointer() const { return value; } |
| 71 | |
| 72 | /// Return if this value is null or not. |
| 73 | explicit operator bool() const { return value; } |
| 74 | |
| 75 | /// Return the kind of this value. |
| 76 | Kind getKind() const { return kind; } |
| 77 | |
| 78 | /// Print this value to the provided output stream. |
| 79 | void print(raw_ostream &os) const; |
| 80 | |
| 81 | /// Print the specified value kind to an output stream. |
| 82 | static void print(raw_ostream &os, Kind kind); |
| 83 | |
| 84 | private: |
| 85 | /// Find the index of a given type in a range of other types. |
| 86 | template <typename...> |
| 87 | struct index_of_t; |
| 88 | template <typename T, typename... R> |
| 89 | struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {}; |
| 90 | template <typename T, typename F, typename... R> |
| 91 | struct index_of_t<T, F, R...> |
| 92 | : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {}; |
| 93 | |
| 94 | /// Return the kind used for the given T. |
| 95 | template <typename T> |
| 96 | static Kind getKindOf() { |
| 97 | return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type, |
| 98 | TypeRange, Value, ValueRange>::value); |
| 99 | } |
| 100 | |
| 101 | /// The internal implementation of `cast`, that returns the underlying value |
| 102 | /// as the given type `T`. |
| 103 | template <typename T> |
| 104 | std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T> |
| 105 | castImpl() const { |
| 106 | return T::getFromOpaquePointer(value); |
| 107 | } |
| 108 | template <typename T> |
| 109 | std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T> |
| 110 | castImpl() const { |
| 111 | return *reinterpret_cast<T *>(const_cast<void *>(value)); |
| 112 | } |
| 113 | template <typename T> |
| 114 | std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const { |
| 115 | return reinterpret_cast<T>(const_cast<void *>(value)); |
| 116 | } |
| 117 | |
| 118 | /// The internal opaque representation of a PDLValue. |
| 119 | const void *value{nullptr}; |
| 120 | /// The kind of the opaque value. |
| 121 | Kind kind{Kind::Attribute}; |
| 122 | }; |
| 123 | |
| 124 | inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { |
| 125 | value.print(os); |
| 126 | return os; |
| 127 | } |
| 128 | |
| 129 | inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { |
| 130 | PDLValue::print(os, kind); |
| 131 | return os; |
| 132 | } |
| 133 | |
| 134 | //===----------------------------------------------------------------------===// |
| 135 | // PDLResultList |
| 136 | |
| 137 | /// The class represents a list of PDL results, returned by a native rewrite |
| 138 | /// method. It provides the mechanism with which to pass PDLValues back to the |
| 139 | /// PDL bytecode. |
| 140 | class PDLResultList { |
| 141 | public: |
| 142 | /// Push a new Attribute value onto the result list. |
| 143 | void push_back(Attribute value) { results.push_back(Elt: value); } |
| 144 | |
| 145 | /// Push a new Operation onto the result list. |
| 146 | void push_back(Operation *value) { results.push_back(Elt: value); } |
| 147 | |
| 148 | /// Push a new Type onto the result list. |
| 149 | void push_back(Type value) { results.push_back(Elt: value); } |
| 150 | |
| 151 | /// Push a new TypeRange onto the result list. |
| 152 | void push_back(TypeRange value) { |
| 153 | // The lifetime of a TypeRange can't be guaranteed, so we'll need to |
| 154 | // allocate a storage for it. |
| 155 | llvm::OwningArrayRef<Type> storage(value.size()); |
| 156 | llvm::copy(Range&: value, Out: storage.begin()); |
| 157 | allocatedTypeRanges.emplace_back(Args: std::move(storage)); |
| 158 | typeRanges.push_back(Elt: allocatedTypeRanges.back()); |
| 159 | results.push_back(Elt: &typeRanges.back()); |
| 160 | } |
| 161 | void push_back(ValueTypeRange<OperandRange> value) { |
| 162 | typeRanges.push_back(Elt: value); |
| 163 | results.push_back(Elt: &typeRanges.back()); |
| 164 | } |
| 165 | void push_back(ValueTypeRange<ResultRange> value) { |
| 166 | typeRanges.push_back(Elt: value); |
| 167 | results.push_back(Elt: &typeRanges.back()); |
| 168 | } |
| 169 | |
| 170 | /// Push a new Value onto the result list. |
| 171 | void push_back(Value value) { results.push_back(Elt: value); } |
| 172 | |
| 173 | /// Push a new ValueRange onto the result list. |
| 174 | void push_back(ValueRange value) { |
| 175 | // The lifetime of a ValueRange can't be guaranteed, so we'll need to |
| 176 | // allocate a storage for it. |
| 177 | llvm::OwningArrayRef<Value> storage(value.size()); |
| 178 | llvm::copy(Range&: value, Out: storage.begin()); |
| 179 | allocatedValueRanges.emplace_back(Args: std::move(storage)); |
| 180 | valueRanges.push_back(Elt: allocatedValueRanges.back()); |
| 181 | results.push_back(Elt: &valueRanges.back()); |
| 182 | } |
| 183 | void push_back(OperandRange value) { |
| 184 | valueRanges.push_back(Elt: value); |
| 185 | results.push_back(Elt: &valueRanges.back()); |
| 186 | } |
| 187 | void push_back(ResultRange value) { |
| 188 | valueRanges.push_back(Elt: value); |
| 189 | results.push_back(Elt: &valueRanges.back()); |
| 190 | } |
| 191 | |
| 192 | protected: |
| 193 | /// Create a new result list with the expected number of results. |
| 194 | PDLResultList(unsigned maxNumResults) { |
| 195 | // For now just reserve enough space for all of the results. We could do |
| 196 | // separate counts per range type, but it isn't really worth it unless there |
| 197 | // are a "large" number of results. |
| 198 | typeRanges.reserve(N: maxNumResults); |
| 199 | valueRanges.reserve(N: maxNumResults); |
| 200 | } |
| 201 | |
| 202 | /// The PDL results held by this list. |
| 203 | SmallVector<PDLValue> results; |
| 204 | /// Memory used to store ranges held by the list. |
| 205 | SmallVector<TypeRange> typeRanges; |
| 206 | SmallVector<ValueRange> valueRanges; |
| 207 | /// Memory allocated to store ranges in the result list whose lifetime was |
| 208 | /// generated in the native function. |
| 209 | SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges; |
| 210 | SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges; |
| 211 | }; |
| 212 | |
| 213 | //===----------------------------------------------------------------------===// |
| 214 | // PDLPatternConfig |
| 215 | |
| 216 | /// An individual configuration for a pattern, which can be accessed by native |
| 217 | /// functions via the PDLPatternConfigSet. This allows for injecting additional |
| 218 | /// configuration into PDL patterns that is specific to certain compilation |
| 219 | /// flows. |
| 220 | class PDLPatternConfig { |
| 221 | public: |
| 222 | virtual ~PDLPatternConfig() = default; |
| 223 | |
| 224 | /// Hooks that are invoked at the beginning and end of a rewrite of a matched |
| 225 | /// pattern. These can be used to setup any specific state necessary for the |
| 226 | /// rewrite. |
| 227 | virtual void notifyRewriteBegin(PatternRewriter &rewriter) {} |
| 228 | virtual void notifyRewriteEnd(PatternRewriter &rewriter) {} |
| 229 | |
| 230 | /// Return the TypeID that represents this configuration. |
| 231 | TypeID getTypeID() const { return id; } |
| 232 | |
| 233 | protected: |
| 234 | PDLPatternConfig(TypeID id) : id(id) {} |
| 235 | |
| 236 | private: |
| 237 | TypeID id; |
| 238 | }; |
| 239 | |
| 240 | /// This class provides a base class for users implementing a type of pattern |
| 241 | /// configuration. |
| 242 | template <typename T> |
| 243 | class PDLPatternConfigBase : public PDLPatternConfig { |
| 244 | public: |
| 245 | /// Support LLVM style casting. |
| 246 | static bool classof(const PDLPatternConfig *config) { |
| 247 | return config->getTypeID() == getConfigID(); |
| 248 | } |
| 249 | |
| 250 | /// Return the type id used for this configuration. |
| 251 | static TypeID getConfigID() { return TypeID::get<T>(); } |
| 252 | |
| 253 | protected: |
| 254 | PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {} |
| 255 | }; |
| 256 | |
| 257 | /// This class contains a set of configurations for a specific pattern. |
| 258 | /// Configurations are uniqued by TypeID, meaning that only one configuration of |
| 259 | /// each type is allowed. |
| 260 | class PDLPatternConfigSet { |
| 261 | public: |
| 262 | PDLPatternConfigSet() = default; |
| 263 | |
| 264 | /// Construct a set with the given configurations. |
| 265 | template <typename... ConfigsT> |
| 266 | PDLPatternConfigSet(ConfigsT &&...configs) { |
| 267 | (addConfig(std::forward<ConfigsT>(configs)), ...); |
| 268 | } |
| 269 | |
| 270 | /// Get the configuration defined by the given type. Asserts that the |
| 271 | /// configuration of the provided type exists. |
| 272 | template <typename T> |
| 273 | const T &get() const { |
| 274 | const T *config = tryGet<T>(); |
| 275 | assert(config && "configuration not found" ); |
| 276 | return *config; |
| 277 | } |
| 278 | |
| 279 | /// Get the configuration defined by the given type, returns nullptr if the |
| 280 | /// configuration does not exist. |
| 281 | template <typename T> |
| 282 | const T *tryGet() const { |
| 283 | for (const auto &configIt : configs) |
| 284 | if (const T *config = dyn_cast<T>(configIt.get())) |
| 285 | return config; |
| 286 | return nullptr; |
| 287 | } |
| 288 | |
| 289 | /// Notify the configurations within this set at the beginning or end of a |
| 290 | /// rewrite of a matched pattern. |
| 291 | void notifyRewriteBegin(PatternRewriter &rewriter) { |
| 292 | for (const auto &config : configs) |
| 293 | config->notifyRewriteBegin(rewriter); |
| 294 | } |
| 295 | void notifyRewriteEnd(PatternRewriter &rewriter) { |
| 296 | for (const auto &config : configs) |
| 297 | config->notifyRewriteEnd(rewriter); |
| 298 | } |
| 299 | |
| 300 | protected: |
| 301 | /// Add a configuration to the set. |
| 302 | template <typename T> |
| 303 | void addConfig(T &&config) { |
| 304 | assert(!tryGet<std::decay_t<T>>() && "configuration already exists" ); |
| 305 | configs.emplace_back( |
| 306 | std::make_unique<std::decay_t<T>>(std::forward<T>(config))); |
| 307 | } |
| 308 | |
| 309 | /// The set of configurations for this pattern. This uses a vector instead of |
| 310 | /// a map with the expectation that the number of configurations per set is |
| 311 | /// small (<= 1). |
| 312 | SmallVector<std::unique_ptr<PDLPatternConfig>> configs; |
| 313 | }; |
| 314 | |
| 315 | //===----------------------------------------------------------------------===// |
| 316 | // PDLPatternModule |
| 317 | |
| 318 | /// A generic PDL pattern constraint function. This function applies a |
| 319 | /// constraint to a given set of opaque PDLValue entities. Returns success if |
| 320 | /// the constraint successfully held, failure otherwise. |
| 321 | using PDLConstraintFunction = std::function<LogicalResult( |
| 322 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
| 323 | |
| 324 | /// A native PDL rewrite function. This function performs a rewrite on the |
| 325 | /// given set of values. Any results from this rewrite that should be passed |
| 326 | /// back to PDL should be added to the provided result list. This method is only |
| 327 | /// invoked when the corresponding match was successful. Returns failure if an |
| 328 | /// invariant of the rewrite was broken (certain rewriters may recover from |
| 329 | /// partial pattern application). |
| 330 | using PDLRewriteFunction = std::function<LogicalResult( |
| 331 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
| 332 | |
| 333 | namespace detail { |
| 334 | namespace pdl_function_builder { |
| 335 | /// A utility variable that always resolves to false. This is useful for static |
| 336 | /// asserts that are always false, but only should fire in certain templated |
| 337 | /// constructs. For example, if a templated function should never be called, the |
| 338 | /// function could be defined as: |
| 339 | /// |
| 340 | /// template <typename T> |
| 341 | /// void foo() { |
| 342 | /// static_assert(always_false<T>, "This function should never be called"); |
| 343 | /// } |
| 344 | /// |
| 345 | template <class... T> |
| 346 | constexpr bool always_false = false; |
| 347 | |
| 348 | //===----------------------------------------------------------------------===// |
| 349 | // PDL Function Builder: Type Processing |
| 350 | //===----------------------------------------------------------------------===// |
| 351 | |
/// Customization point describing how a C++ type is exchanged with PDL, either
/// as a parameter of a native function or as one of its results. It lets
/// constraint and rewrite functions use rich C++ types without the user
/// hand-writing the glue code. A specialization for `ValueT` provides whichever
/// of the following it supports:
///
///   static LogicalResult verifyAsArg(
///       function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///       size_t argIdx);
///
///     * Checks that `pdlValue` can be used as a value of `ValueT`.
///
///   static ValueT processAsArg(PDLValue pdlValue);
///
///     * Converts `pdlValue` into a value of `ValueT`.
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               const ValueT &value);
///
///     * Packages `value`, the result of a function invocation, into an
///       appropriate form and appends it to `results`.
///
/// When `ValueT` is built on top of another processable type, deriving the
/// specialization from `ProcessPDLValueBasedOn` simplifies the implementation.
/// `EnableT` exists only to allow SFINAE-constrained partial specializations.
///
template <typename ValueT, typename EnableT = void>
struct ProcessPDLValue;
| 383 | |
| 384 | /// This struct provides a simplified model for processing types that are based |
| 385 | /// on another type, e.g. APInt is based on the handling for IntegerAttr. This |
| 386 | /// allows for building the necessary processing functions on top of the base |
| 387 | /// value instead of a PDLValue. Derived users should implement the following |
| 388 | /// (which subsume the ProcessPDLValue variants): |
| 389 | /// |
| 390 | /// static LogicalResult verifyAsArg( |
| 391 | /// function_ref<LogicalResult(const Twine &)> errorFn, |
| 392 | /// const BaseT &baseValue, size_t argIdx); |
| 393 | /// |
| 394 | /// * This method verifies that the given PDLValue is valid for use as a |
| 395 | /// value of `T`. |
| 396 | /// |
| 397 | /// static T processAsArg(BaseT baseValue); |
| 398 | /// |
| 399 | /// * This method processes the given base value as a value of `T`. |
| 400 | /// |
| 401 | template <typename T, typename BaseT> |
| 402 | struct ProcessPDLValueBasedOn { |
| 403 | static LogicalResult |
| 404 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, |
| 405 | PDLValue pdlValue, size_t argIdx) { |
| 406 | // Verify the base class before continuing. |
| 407 | if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx))) |
| 408 | return failure(); |
| 409 | return ProcessPDLValue<T>::verifyAsArg( |
| 410 | errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx); |
| 411 | } |
| 412 | static T processAsArg(PDLValue pdlValue) { |
| 413 | return ProcessPDLValue<T>::processAsArg( |
| 414 | ProcessPDLValue<BaseT>::processAsArg(pdlValue)); |
| 415 | } |
| 416 | |
| 417 | /// Explicitly add the expected parent API to ensure the parent class |
| 418 | /// implements the necessary API (and doesn't implicitly inherit it from |
| 419 | /// somewhere else). |
| 420 | static LogicalResult |
| 421 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value, |
| 422 | size_t argIdx) { |
| 423 | return success(); |
| 424 | } |
| 425 | static T processAsArg(BaseT baseValue); |
| 426 | }; |
| 427 | |
| 428 | /// This struct provides a simplified model for processing types that have |
| 429 | /// "builtin" PDLValue support: |
| 430 | /// * Attribute, Operation *, Type, TypeRange, ValueRange |
| 431 | template <typename T> |
| 432 | struct ProcessBuiltinPDLValue { |
| 433 | static LogicalResult |
| 434 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, |
| 435 | PDLValue pdlValue, size_t argIdx) { |
| 436 | if (pdlValue) |
| 437 | return success(); |
| 438 | return errorFn("expected a non-null value for argument " + Twine(argIdx) + |
| 439 | " of type: " + llvm::getTypeName<T>()); |
| 440 | } |
| 441 | |
| 442 | static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); } |
| 443 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 444 | T value) { |
| 445 | results.push_back(value); |
| 446 | } |
| 447 | }; |
| 448 | |
| 449 | /// This struct provides a simplified model for processing types that inherit |
| 450 | /// from builtin PDLValue types. For example, derived attributes like |
| 451 | /// IntegerAttr, derived types like IntegerType, derived operations like |
| 452 | /// ModuleOp, Interfaces, etc. |
| 453 | template <typename T, typename BaseT> |
| 454 | struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> { |
| 455 | static LogicalResult |
| 456 | verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, |
| 457 | BaseT baseValue, size_t argIdx) { |
| 458 | return TypeSwitch<BaseT, LogicalResult>(baseValue) |
| 459 | .Case([&](T) { return success(); }) |
| 460 | .Default([&](BaseT) { |
| 461 | return errorFn("expected argument " + Twine(argIdx) + |
| 462 | " to be of type: " + llvm::getTypeName<T>()); |
| 463 | }); |
| 464 | } |
| 465 | using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg; |
| 466 | |
| 467 | static T processAsArg(BaseT baseValue) { |
| 468 | return baseValue.template cast<T>(); |
| 469 | } |
| 470 | using ProcessPDLValueBasedOn<T, BaseT>::processAsArg; |
| 471 | |
| 472 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 473 | T value) { |
| 474 | results.push_back(value); |
| 475 | } |
| 476 | }; |
| 477 | |
| 478 | //===----------------------------------------------------------------------===// |
| 479 | // Attribute |
| 480 | |
| 481 | template <> |
| 482 | struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {}; |
| 483 | template <typename T> |
| 484 | struct ProcessPDLValue<T, |
| 485 | std::enable_if_t<std::is_base_of<Attribute, T>::value>> |
| 486 | : public ProcessDerivedPDLValue<T, Attribute> {}; |
| 487 | |
| 488 | /// Handling for various Attribute value types. |
| 489 | template <> |
| 490 | struct ProcessPDLValue<StringRef> |
| 491 | : public ProcessPDLValueBasedOn<StringRef, StringAttr> { |
| 492 | static StringRef processAsArg(StringAttr value) { return value.getValue(); } |
| 493 | using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg; |
| 494 | |
| 495 | static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, |
| 496 | StringRef value) { |
| 497 | results.push_back(rewriter.getStringAttr(value)); |
| 498 | } |
| 499 | }; |
| 500 | template <> |
| 501 | struct ProcessPDLValue<std::string> |
| 502 | : public ProcessPDLValueBasedOn<std::string, StringAttr> { |
| 503 | template <typename T> |
| 504 | static std::string processAsArg(T value) { |
| 505 | static_assert(always_false<T>, |
| 506 | "`std::string` arguments require a string copy, use " |
| 507 | "`StringRef` for string-like arguments instead" ); |
| 508 | return {}; |
| 509 | } |
| 510 | static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, |
| 511 | StringRef value) { |
| 512 | results.push_back(rewriter.getStringAttr(value)); |
| 513 | } |
| 514 | }; |
| 515 | |
| 516 | //===----------------------------------------------------------------------===// |
| 517 | // Operation |
| 518 | |
| 519 | template <> |
| 520 | struct ProcessPDLValue<Operation *> |
| 521 | : public ProcessBuiltinPDLValue<Operation *> {}; |
| 522 | template <typename T> |
| 523 | struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>> |
| 524 | : public ProcessDerivedPDLValue<T, Operation *> { |
| 525 | static T processAsArg(Operation *value) { return cast<T>(value); } |
| 526 | }; |
| 527 | |
| 528 | //===----------------------------------------------------------------------===// |
| 529 | // Type |
| 530 | |
| 531 | template <> |
| 532 | struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {}; |
| 533 | template <typename T> |
| 534 | struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>> |
| 535 | : public ProcessDerivedPDLValue<T, Type> {}; |
| 536 | |
| 537 | //===----------------------------------------------------------------------===// |
| 538 | // TypeRange |
| 539 | |
| 540 | template <> |
| 541 | struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {}; |
| 542 | template <> |
| 543 | struct ProcessPDLValue<ValueTypeRange<OperandRange>> { |
| 544 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 545 | ValueTypeRange<OperandRange> types) { |
| 546 | results.push_back(value: types); |
| 547 | } |
| 548 | }; |
| 549 | template <> |
| 550 | struct ProcessPDLValue<ValueTypeRange<ResultRange>> { |
| 551 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 552 | ValueTypeRange<ResultRange> types) { |
| 553 | results.push_back(value: types); |
| 554 | } |
| 555 | }; |
| 556 | template <unsigned N> |
| 557 | struct ProcessPDLValue<SmallVector<Type, N>> { |
| 558 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 559 | SmallVector<Type, N> values) { |
| 560 | results.push_back(value: TypeRange(values)); |
| 561 | } |
| 562 | }; |
| 563 | |
| 564 | //===----------------------------------------------------------------------===// |
| 565 | // Value |
| 566 | |
| 567 | template <> |
| 568 | struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {}; |
| 569 | |
| 570 | //===----------------------------------------------------------------------===// |
| 571 | // ValueRange |
| 572 | |
| 573 | template <> |
| 574 | struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> { |
| 575 | }; |
| 576 | template <> |
| 577 | struct ProcessPDLValue<OperandRange> { |
| 578 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 579 | OperandRange values) { |
| 580 | results.push_back(value: values); |
| 581 | } |
| 582 | }; |
| 583 | template <> |
| 584 | struct ProcessPDLValue<ResultRange> { |
| 585 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 586 | ResultRange values) { |
| 587 | results.push_back(value: values); |
| 588 | } |
| 589 | }; |
| 590 | template <unsigned N> |
| 591 | struct ProcessPDLValue<SmallVector<Value, N>> { |
| 592 | static void processAsResult(PatternRewriter &, PDLResultList &results, |
| 593 | SmallVector<Value, N> values) { |
| 594 | results.push_back(value: ValueRange(values)); |
| 595 | } |
| 596 | }; |
| 597 | |
| 598 | //===----------------------------------------------------------------------===// |
| 599 | // PDL Function Builder: Argument Handling |
| 600 | //===----------------------------------------------------------------------===// |
| 601 | |
| 602 | /// Validate the given PDLValues match the constraints defined by the argument |
| 603 | /// types of the given function. In the case of failure, a match failure |
| 604 | /// diagnostic is emitted. |
| 605 | /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL |
| 606 | /// does not currently preserve Constraint application ordering. |
| 607 | template <typename PDLFnT, std::size_t... I> |
| 608 | LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, |
| 609 | std::index_sequence<I...>) { |
| 610 | using FnTraitsT = llvm::function_traits<PDLFnT>; |
| 611 | |
| 612 | auto errorFn = [&](const Twine &msg) { |
| 613 | return rewriter.notifyMatchFailure(arg: rewriter.getUnknownLoc(), msg); |
| 614 | }; |
| 615 | return success( |
| 616 | (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: |
| 617 | verifyAsArg(errorFn, values[I], I)) && |
| 618 | ...)); |
| 619 | } |
| 620 | |
| 621 | /// Assert that the given PDLValues match the constraints defined by the |
| 622 | /// arguments of the given function. In the case of failure, a fatal error |
| 623 | /// is emitted. |
| 624 | template <typename PDLFnT, std::size_t... I> |
| 625 | void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, |
| 626 | std::index_sequence<I...>) { |
| 627 | // We only want to do verification in debug builds, same as with `assert`. |
| 628 | #ifndef NDEBUG |
| 629 | using FnTraitsT = llvm::function_traits<PDLFnT>; |
| 630 | auto errorFn = [&](const Twine &msg) -> LogicalResult { |
| 631 | llvm::report_fatal_error(reason: msg); |
| 632 | }; |
| 633 | (void)errorFn; |
| 634 | assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: |
| 635 | verifyAsArg(errorFn, values[I], I)) && |
| 636 | ...)); |
| 637 | #endif |
| 638 | (void)values; |
| 639 | } |
| 640 | |
| 641 | //===----------------------------------------------------------------------===// |
| 642 | // PDL Function Builder: Results Handling |
| 643 | //===----------------------------------------------------------------------===// |
| 644 | |
| 645 | /// Store a single result within the result list. |
| 646 | template <typename T> |
| 647 | static LogicalResult processResults(PatternRewriter &rewriter, |
| 648 | PDLResultList &results, T &&value) { |
| 649 | ProcessPDLValue<T>::processAsResult(rewriter, results, |
| 650 | std::forward<T>(value)); |
| 651 | return success(); |
| 652 | } |
| 653 | |
| 654 | /// Store a std::pair<> as individual results within the result list. |
| 655 | template <typename T1, typename T2> |
| 656 | static LogicalResult processResults(PatternRewriter &rewriter, |
| 657 | PDLResultList &results, |
| 658 | std::pair<T1, T2> &&pair) { |
| 659 | if (failed(processResults(rewriter, results, std::move(pair.first))) || |
| 660 | failed(processResults(rewriter, results, std::move(pair.second)))) |
| 661 | return failure(); |
| 662 | return success(); |
| 663 | } |
| 664 | |
| 665 | /// Store a std::tuple<> as individual results within the result list. |
| 666 | template <typename... Ts> |
| 667 | static LogicalResult processResults(PatternRewriter &rewriter, |
| 668 | PDLResultList &results, |
| 669 | std::tuple<Ts...> &&tuple) { |
| 670 | auto applyFn = [&](auto &&...args) { |
| 671 | return (succeeded(processResults(rewriter, results, std::move(args))) && |
| 672 | ...); |
| 673 | }; |
| 674 | return success(std::apply(applyFn, std::move(tuple))); |
| 675 | } |
| 676 | |
| 677 | /// Handle LogicalResult propagation. |
| 678 | inline LogicalResult processResults(PatternRewriter &rewriter, |
| 679 | PDLResultList &results, |
| 680 | LogicalResult &&result) { |
| 681 | return result; |
| 682 | } |
| 683 | template <typename T> |
| 684 | static LogicalResult processResults(PatternRewriter &rewriter, |
| 685 | PDLResultList &results, |
| 686 | FailureOr<T> &&result) { |
| 687 | if (failed(result)) |
| 688 | return failure(); |
| 689 | return processResults(rewriter, results, std::move(*result)); |
| 690 | } |
| 691 | |
| 692 | //===----------------------------------------------------------------------===// |
| 693 | // PDL Constraint Builder |
| 694 | //===----------------------------------------------------------------------===// |
| 695 | |
| 696 | /// Process the arguments of a native constraint and invoke it. |
| 697 | template <typename PDLFnT, std::size_t... I, |
| 698 | typename FnTraitsT = llvm::function_traits<PDLFnT>> |
| 699 | typename FnTraitsT::result_t |
| 700 | processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, |
| 701 | ArrayRef<PDLValue> values, |
| 702 | std::index_sequence<I...>) { |
| 703 | return fn( |
| 704 | rewriter, |
| 705 | (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( |
| 706 | values[I]))...); |
| 707 | } |
| 708 | |
| 709 | /// Build a constraint function from the given function `ConstraintFnT`. This |
| 710 | /// allows for enabling the user to define simpler, more direct constraint |
| 711 | /// functions without needing to handle the low-level PDL goop. |
| 712 | /// |
| 713 | /// If the constraint function is already in the correct form, we just forward |
| 714 | /// it directly. |
| 715 | template <typename ConstraintFnT> |
| 716 | std::enable_if_t< |
| 717 | std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, |
| 718 | PDLConstraintFunction> |
| 719 | buildConstraintFn(ConstraintFnT &&constraintFn) { |
| 720 | return std::forward<ConstraintFnT>(constraintFn); |
| 721 | } |
| 722 | /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form |
| 723 | /// we desire. |
| 724 | template <typename ConstraintFnT> |
| 725 | std::enable_if_t< |
| 726 | !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, |
| 727 | PDLConstraintFunction> |
| 728 | buildConstraintFn(ConstraintFnT &&constraintFn) { |
| 729 | return [constraintFn = std::forward<ConstraintFnT>(constraintFn)]( |
| 730 | PatternRewriter &rewriter, PDLResultList &, |
| 731 | ArrayRef<PDLValue> values) -> LogicalResult { |
| 732 | auto argIndices = std::make_index_sequence< |
| 733 | llvm::function_traits<ConstraintFnT>::num_args - 1>(); |
| 734 | if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices))) |
| 735 | return failure(); |
| 736 | return processArgsAndInvokeConstraint(constraintFn, rewriter, values, |
| 737 | argIndices); |
| 738 | }; |
| 739 | } |
| 740 | |
| 741 | //===----------------------------------------------------------------------===// |
| 742 | // PDL Rewrite Builder |
| 743 | //===----------------------------------------------------------------------===// |
| 744 | |
| 745 | /// Process the arguments of a native rewrite and invoke it. |
| 746 | /// This overload handles the case of no return values. |
| 747 | template <typename PDLFnT, std::size_t... I, |
| 748 | typename FnTraitsT = llvm::function_traits<PDLFnT>> |
| 749 | std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value, |
| 750 | LogicalResult> |
| 751 | processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, |
| 752 | PDLResultList &, ArrayRef<PDLValue> values, |
| 753 | std::index_sequence<I...>) { |
| 754 | fn(rewriter, |
| 755 | (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( |
| 756 | values[I]))...); |
| 757 | return success(); |
| 758 | } |
| 759 | /// This overload handles the case of return values, which need to be packaged |
| 760 | /// into the result list. |
| 761 | template <typename PDLFnT, std::size_t... I, |
| 762 | typename FnTraitsT = llvm::function_traits<PDLFnT>> |
| 763 | std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value, |
| 764 | LogicalResult> |
| 765 | processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, |
| 766 | PDLResultList &results, ArrayRef<PDLValue> values, |
| 767 | std::index_sequence<I...>) { |
| 768 | return processResults( |
| 769 | rewriter, results, |
| 770 | fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: |
| 771 | processAsArg(values[I]))...)); |
| 772 | (void)values; |
| 773 | } |
| 774 | |
| 775 | /// Build a rewrite function from the given function `RewriteFnT`. This |
| 776 | /// allows for enabling the user to define simpler, more direct rewrite |
| 777 | /// functions without needing to handle the low-level PDL goop. |
| 778 | /// |
| 779 | /// If the rewrite function is already in the correct form, we just forward |
| 780 | /// it directly. |
| 781 | template <typename RewriteFnT> |
| 782 | std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, |
| 783 | PDLRewriteFunction> |
| 784 | buildRewriteFn(RewriteFnT &&rewriteFn) { |
| 785 | return std::forward<RewriteFnT>(rewriteFn); |
| 786 | } |
| 787 | /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form |
| 788 | /// we desire. |
| 789 | template <typename RewriteFnT> |
| 790 | std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, |
| 791 | PDLRewriteFunction> |
| 792 | buildRewriteFn(RewriteFnT &&rewriteFn) { |
| 793 | return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)]( |
| 794 | PatternRewriter &rewriter, PDLResultList &results, |
| 795 | ArrayRef<PDLValue> values) { |
| 796 | auto argIndices = |
| 797 | std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args - |
| 798 | 1>(); |
| 799 | assertArgs<RewriteFnT>(rewriter, values, argIndices); |
| 800 | return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, |
| 801 | argIndices); |
| 802 | }; |
| 803 | } |
| 804 | |
| 805 | } // namespace pdl_function_builder |
| 806 | } // namespace detail |
| 807 | |
| 808 | //===----------------------------------------------------------------------===// |
| 809 | // PDLPatternModule |
| 810 | |
| 811 | /// This class contains all of the necessary data for a set of PDL patterns, or |
| 812 | /// pattern rewrites specified in the form of the PDL dialect. This PDL module |
| 813 | /// contained by this pattern may contain any number of `pdl.pattern` |
| 814 | /// operations. |
| 815 | class PDLPatternModule { |
| 816 | public: |
| 817 | PDLPatternModule() = default; |
| 818 | |
| 819 | /// Construct a PDL pattern with the given module and configurations. |
| 820 | PDLPatternModule(OwningOpRef<ModuleOp> module) |
| 821 | : pdlModule(std::move(module)) {} |
| 822 | template <typename... ConfigsT> |
| 823 | PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs) |
| 824 | : PDLPatternModule(std::move(module)) { |
| 825 | auto configSet = std::make_unique<PDLPatternConfigSet>( |
| 826 | std::forward<ConfigsT>(patternConfigs)...); |
| 827 | attachConfigToPatterns(*pdlModule, *configSet); |
| 828 | configs.emplace_back(std::move(configSet)); |
| 829 | } |
| 830 | |
| 831 | /// Merge the state in `other` into this pattern module. |
| 832 | void mergeIn(PDLPatternModule &&other); |
| 833 | |
| 834 | /// Return the internal PDL module of this pattern. |
| 835 | ModuleOp getModule() { return pdlModule.get(); } |
| 836 | |
| 837 | /// Return the MLIR context of this pattern. |
| 838 | MLIRContext *getContext() { return getModule()->getContext(); } |
| 839 | |
| 840 | //===--------------------------------------------------------------------===// |
| 841 | // Function Registry |
| 842 | |
| 843 | /// Register a constraint function with PDL. A constraint function may be |
| 844 | /// specified in one of two ways: |
| 845 | /// |
| 846 | /// * `LogicalResult (PatternRewriter &, |
| 847 | /// PDLResultList &, |
| 848 | /// ArrayRef<PDLValue>)` |
| 849 | /// |
| 850 | /// In this overload the arguments of the constraint function are passed via |
| 851 | /// the low-level PDLValue form, and the results are manually appended to |
| 852 | /// the given result list. |
| 853 | /// |
| 854 | /// * `LogicalResult (PatternRewriter &, ValueTs... values)` |
| 855 | /// |
| 856 | /// In this form the arguments of the constraint function are passed via the |
| 857 | /// expected high level C++ type. In this form, the framework will |
| 858 | /// automatically unwrap PDLValues and convert them to the expected ValueTs. |
| 859 | /// For example, if the constraint function accepts a `Operation *`, the |
| 860 | /// framework will automatically cast the input PDLValue. In the case of a |
| 861 | /// `StringRef`, the framework will automatically unwrap the argument as a |
| 862 | /// StringAttr and pass the underlying string value. To see the full list of |
| 863 | /// supported types, or to see how to add handling for custom types, view |
| 864 | /// the definition of `ProcessPDLValue` above. |
| 865 | void registerConstraintFunction(StringRef name, |
| 866 | PDLConstraintFunction constraintFn); |
| 867 | template <typename ConstraintFnT> |
| 868 | void registerConstraintFunction(StringRef name, |
| 869 | ConstraintFnT &&constraintFn) { |
| 870 | registerConstraintFunction(name, |
| 871 | detail::pdl_function_builder::buildConstraintFn( |
| 872 | std::forward<ConstraintFnT>(constraintFn))); |
| 873 | } |
| 874 | |
| 875 | /// Register a rewrite function with PDL. A rewrite function may be specified |
| 876 | /// in one of two ways: |
| 877 | /// |
| 878 | /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)` |
| 879 | /// |
| 880 | /// In this overload the arguments of the constraint function are passed via |
| 881 | /// the low-level PDLValue form, and the results are manually appended to |
| 882 | /// the given result list. |
| 883 | /// |
| 884 | /// * `ResultT (PatternRewriter &, ValueTs... values)` |
| 885 | /// |
| 886 | /// In this form the arguments and result of the rewrite function are passed |
| 887 | /// via the expected high level C++ type. In this form, the framework will |
| 888 | /// automatically unwrap the PDLValues arguments and convert them to the |
| 889 | /// expected ValueTs. It will also automatically handle the processing and |
| 890 | /// packaging of the result value to the result list. For example, if the |
| 891 | /// rewrite function takes a `Operation *`, the framework will automatically |
| 892 | /// cast the input PDLValue. In the case of a `StringRef`, the framework |
| 893 | /// will automatically unwrap the argument as a StringAttr and pass the |
| 894 | /// underlying string value. In the reverse case, if the rewrite returns a |
| 895 | /// StringRef or std::string, it will automatically package this as a |
| 896 | /// StringAttr and append it to the result list. To see the full list of |
| 897 | /// supported types, or to see how to add handling for custom types, view |
| 898 | /// the definition of `ProcessPDLValue` above. |
| 899 | void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); |
| 900 | template <typename RewriteFnT> |
| 901 | void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) { |
| 902 | registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( |
| 903 | std::forward<RewriteFnT>(rewriteFn))); |
| 904 | } |
| 905 | |
| 906 | /// Return the set of the registered constraint functions. |
| 907 | const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { |
| 908 | return constraintFunctions; |
| 909 | } |
| 910 | llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() { |
| 911 | return constraintFunctions; |
| 912 | } |
| 913 | /// Return the set of the registered rewrite functions. |
| 914 | const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const { |
| 915 | return rewriteFunctions; |
| 916 | } |
| 917 | llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() { |
| 918 | return rewriteFunctions; |
| 919 | } |
| 920 | |
| 921 | /// Return the set of the registered pattern configs. |
| 922 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() { |
| 923 | return std::move(configs); |
| 924 | } |
| 925 | DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() { |
| 926 | return std::move(configMap); |
| 927 | } |
| 928 | |
| 929 | /// Clear out the patterns and functions within this module. |
| 930 | void clear() { |
| 931 | pdlModule = nullptr; |
| 932 | constraintFunctions.clear(); |
| 933 | rewriteFunctions.clear(); |
| 934 | } |
| 935 | |
| 936 | private: |
| 937 | /// Attach the given pattern config set to the patterns defined within the |
| 938 | /// given module. |
| 939 | void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet); |
| 940 | |
| 941 | /// The module containing the `pdl.pattern` operations. |
| 942 | OwningOpRef<ModuleOp> pdlModule; |
| 943 | |
| 944 | /// The set of configuration sets referenced by patterns within `pdlModule`. |
| 945 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs; |
| 946 | DenseMap<Operation *, PDLPatternConfigSet *> configMap; |
| 947 | |
| 948 | /// The external functions referenced from within the PDL module. |
| 949 | llvm::StringMap<PDLConstraintFunction> constraintFunctions; |
| 950 | llvm::StringMap<PDLRewriteFunction> rewriteFunctions; |
| 951 | }; |
| 952 | } // namespace mlir |
| 953 | |
| 954 | #else |
| 955 | |
| 956 | namespace mlir { |
| 957 | // Stubs for when PDL in pattern rewrites is not enabled. |
| 958 | |
| 959 | class PDLValue { |
| 960 | public: |
| 961 | template <typename T> |
| 962 | T dyn_cast() const { |
| 963 | return nullptr; |
| 964 | } |
| 965 | }; |
| 966 | class PDLResultList {}; |
| 967 | using PDLConstraintFunction = std::function<LogicalResult( |
| 968 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
| 969 | using PDLRewriteFunction = std::function<LogicalResult( |
| 970 | PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; |
| 971 | |
| 972 | class PDLPatternModule { |
| 973 | public: |
| 974 | PDLPatternModule() = default; |
| 975 | |
| 976 | PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {} |
| 977 | MLIRContext *getContext() { |
| 978 | llvm_unreachable("Error: PDL for rewrites when PDL is not enabled" ); |
| 979 | } |
| 980 | void mergeIn(PDLPatternModule &&other) {} |
| 981 | void clear() {} |
| 982 | template <typename ConstraintFnT> |
| 983 | void registerConstraintFunction(StringRef name, |
| 984 | ConstraintFnT &&constraintFn) {} |
| 985 | void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {} |
| 986 | template <typename RewriteFnT> |
| 987 | void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {} |
| 988 | const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { |
| 989 | return constraintFunctions; |
| 990 | } |
| 991 | |
| 992 | private: |
| 993 | llvm::StringMap<PDLConstraintFunction> constraintFunctions; |
| 994 | }; |
| 995 | |
| 996 | } // namespace mlir |
| 997 | #endif |
| 998 | |
| 999 | #endif // MLIR_IR_PDLPATTERNMATCH_H |
| 1000 | |