//===- ReductionPatternInterface.h - Collecting Reduce Patterns -*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H |
10 | #define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H |
11 | |
12 | #include "mlir/IR/DialectInterface.h" |
13 | |
14 | namespace mlir { |
15 | |
16 | class RewritePatternSet; |
17 | |
18 | /// This is used to report the reduction patterns for a Dialect. While using |
19 | /// mlir-reduce to reduce a module, we may want to transform certain cases into |
20 | /// simpler forms by applying certain rewrite patterns. Implement the |
21 | /// `populateReductionPatterns` to report those patterns by adding them to the |
22 | /// RewritePatternSet. |
23 | /// |
24 | /// Example: |
25 | /// MyDialectReductionPattern::populateReductionPatterns( |
26 | /// RewritePatternSet &patterns) { |
27 | /// patterns.add<TensorOpReduction>(patterns.getContext()); |
28 | /// } |
29 | /// |
30 | /// For DRR, mlir-tblgen will generate a helper function |
31 | /// `populateWithGenerated` which has the same signature therefore you can |
32 | /// delegate to the helper function as well. |
33 | /// |
34 | /// Example: |
35 | /// MyDialectReductionPattern::populateReductionPatterns( |
36 | /// RewritePatternSet &patterns) { |
37 | /// // Include the autogen file somewhere above. |
38 | /// populateWithGenerated(patterns); |
39 | /// } |
40 | class DialectReductionPatternInterface |
41 | : public DialectInterface::Base<DialectReductionPatternInterface> { |
42 | public: |
43 | /// Patterns provided here are intended to transform operations from a complex |
44 | /// form to a simpler form, without breaking the semantics of the program |
45 | /// being reduced. For example, you may want to replace the |
46 | /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or |
47 | /// replacing an operation with a constant. |
48 | virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0; |
49 | |
50 | protected: |
51 | DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {} |
52 | }; |
53 | |
54 | } // namespace mlir |
55 | |
56 | #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H |
57 | |