1 | //===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file defines prototypes that expose pass constructors in the |
10 | // shape transformation library. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ |
15 | #define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ |
16 | |
17 | #include "mlir/Pass/Pass.h" |
18 | |
// Forward declarations of MLIR types that this header names only by
// reference, pointer, or as a template argument, so the full definitions do
// not need to be included here.
namespace mlir {
class ConversionTarget;
class ModuleOp;
// Named by the populate*Patterns entry points; declared here so the header
// does not rely on a transitive include to provide it.
class RewritePatternSet;
class TypeConverter;
namespace func {
class FuncOp;
} // namespace func
} // namespace mlir
27 | |
28 | namespace mlir { |
29 | |
30 | #define GEN_PASS_DECL |
31 | #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" |
32 | |
33 | /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape |
34 | /// dialect to be convertible to Arith. For example, `shape.num_elements` get |
35 | /// transformed to `shape.reduce`, which can be lowered to SCF and Arith. |
36 | std::unique_ptr<Pass> createShapeToShapeLowering(); |
37 | |
38 | /// Collects a set of patterns to rewrite ops within the Shape dialect. |
39 | void populateShapeRewritePatterns(RewritePatternSet &patterns); |
40 | |
41 | // Collects a set of patterns to replace all constraints with passing witnesses. |
42 | // This is intended to then allow all ShapeConstraint related ops and data to |
43 | // have no effects and allow them to be freely removed such as through |
44 | // canonicalization and dead code elimination. |
45 | // |
46 | // After this pass, no cstr_ operations exist. |
47 | void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns); |
48 | std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass(); |
49 | |
50 | // Bufferizes shape dialect ops. |
51 | // |
52 | // Note that most shape dialect ops must be converted to std before |
53 | // bufferization happens, as they are intended to be bufferized at the std |
54 | // level. |
55 | std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass(); |
56 | |
57 | /// Outline the shape computation part by adding shape.func and populate |
58 | /// conrresponding mapping infomation into ShapeMappingAnalysis. |
59 | std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass(); |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // Registration |
63 | //===----------------------------------------------------------------------===// |
64 | |
65 | /// Generate the code for registering passes. |
66 | #define GEN_PASS_REGISTRATION |
67 | #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" |
68 | |
69 | } // namespace mlir |
70 | |
71 | #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ |
72 | |