//===- TransformsDetail.h - Mesh transform helpers --------------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H |
10 | #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H |
11 | |
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"

#include <utility>
14 | |
15 | namespace mlir { |
16 | namespace mesh { |
17 | |
18 | template <typename Op> |
19 | struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> { |
20 | template <typename... OpRewritePatternArgs> |
21 | OpRewritePatternWithSymbolTableCollection( |
22 | SymbolTableCollection &symbolTableCollection, |
23 | OpRewritePatternArgs &&...opRewritePatternArgs) |
24 | : OpRewritePattern<Op>( |
25 | std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...), |
26 | symbolTableCollection(symbolTableCollection) {} |
27 | |
28 | protected: |
29 | SymbolTableCollection &symbolTableCollection; |
30 | }; |
31 | |
32 | } // namespace mesh |
33 | } // namespace mlir |
34 | |
35 | #endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H |
36 | |