1 | //===- Passes.h - Sparse tensor pass entry points ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file defines prototypes of all sparse tensor passes. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ |
14 | #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ |
15 | |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/Pass/Pass.h" |
18 | #include "mlir/Transforms/DialectConversion.h" |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // Include the generated pass header (which needs some early definitions). |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | namespace mlir { |
25 | |
// Only a forward declaration is needed in this header: the type appears
// solely in function declarations below (by value as a return type and by
// const reference as a parameter), neither of which requires completeness.
namespace bufferization {
struct OneShotBufferizationOptions;
} // namespace bufferization
29 | |
/// Selects which loops may be turned into parallel loops. Only independent
/// loops are candidates. A candidate is parallelized only when both of these
/// hold:
///   (1) the chosen strategy admits it. For example, kAnyStorageOuterLoop
///       considers just the outermost loop, whether dense or sparse.
///   (2) the generated code is a genuine for-loop rather than a co-iterating
///       while-loop.
enum class SparseParallelizationStrategy {
  kNone = 0,
  kDenseOuterLoop = 1,
  kAnyStorageOuterLoop = 2,
  kDenseAnyLoop = 3,
  kAnyStorageAnyLoop = 4,
};
42 | |
/// Selects which operations the reinterpret map pass rewrites.
enum class ReinterpretMapScope {
  /// Reinterprets all applicable operations.
  kAll = 0,
  /// Reinterprets only linalg.generic operations.
  kGenericOnly = 1,
  /// Reinterprets operations other than linalg.generic.
  kExceptGeneric = 2,
};
49 | |
/// Defines the strategy used to emit sparse iteration code. It is selected
/// through SparsificationOptions::sparseEmitStrategy, where kFunctional is
/// the default.
enum class SparseEmitStrategy {
  kFunctional,     // generate fully inlined (and functional) sparse iteration
  kDebugInterface, // generate only place-holder for sparse iteration
};
55 | |
56 | #define GEN_PASS_DECL |
57 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" |
58 | |
59 | //===----------------------------------------------------------------------===// |
60 | // The SparseAssembler pass. |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | void populateSparseAssembler(RewritePatternSet &patterns, bool directOut); |
64 | |
65 | std::unique_ptr<Pass> createSparseAssembler(); |
66 | std::unique_ptr<Pass> createSparseAssembler(bool directOut); |
67 | |
68 | //===----------------------------------------------------------------------===// |
69 | // The SparseReinterpretMap pass. |
70 | //===----------------------------------------------------------------------===// |
71 | |
72 | void populateSparseReinterpretMap(RewritePatternSet &patterns, |
73 | ReinterpretMapScope scope); |
74 | |
75 | std::unique_ptr<Pass> createSparseReinterpretMapPass(); |
76 | std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope); |
77 | |
78 | //===----------------------------------------------------------------------===// |
79 | // The PreSparsificationRewriting pass. |
80 | //===----------------------------------------------------------------------===// |
81 | |
82 | void populatePreSparsificationRewriting(RewritePatternSet &patterns); |
83 | |
84 | std::unique_ptr<Pass> createPreSparsificationRewritePass(); |
85 | |
86 | //===----------------------------------------------------------------------===// |
87 | // The Sparsification pass. |
88 | //===----------------------------------------------------------------------===// |
89 | |
90 | /// Options for the Sparsification pass. |
91 | struct SparsificationOptions { |
92 | SparsificationOptions(SparseParallelizationStrategy p, SparseEmitStrategy d, |
93 | bool enableRT) |
94 | : parallelizationStrategy(p), sparseEmitStrategy(d), |
95 | enableRuntimeLibrary(enableRT) {} |
96 | |
97 | SparsificationOptions(SparseParallelizationStrategy p, bool enableRT) |
98 | : SparsificationOptions(p, SparseEmitStrategy::kFunctional, enableRT) {} |
99 | |
100 | SparsificationOptions() |
101 | : SparsificationOptions(SparseParallelizationStrategy::kNone, |
102 | SparseEmitStrategy::kFunctional, true) {} |
103 | |
104 | SparseParallelizationStrategy parallelizationStrategy; |
105 | SparseEmitStrategy sparseEmitStrategy; |
106 | bool enableRuntimeLibrary; |
107 | }; |
108 | |
109 | /// Sets up sparsification rewriting rules with the given options. |
110 | void populateSparsificationPatterns( |
111 | RewritePatternSet &patterns, |
112 | const SparsificationOptions &options = SparsificationOptions()); |
113 | |
114 | std::unique_ptr<Pass> createSparsificationPass(); |
115 | std::unique_ptr<Pass> |
116 | createSparsificationPass(const SparsificationOptions &options); |
117 | |
118 | //===----------------------------------------------------------------------===// |
119 | // The StageSparseOperations pass. |
120 | //===----------------------------------------------------------------------===// |
121 | |
122 | /// Sets up StageSparseOperation rewriting rules. |
123 | void populateStageSparseOperationsPatterns(RewritePatternSet &patterns); |
124 | |
125 | std::unique_ptr<Pass> createStageSparseOperationsPass(); |
126 | |
127 | //===----------------------------------------------------------------------===// |
128 | // The LowerSparseOpsToForeach pass. |
129 | //===----------------------------------------------------------------------===// |
130 | |
131 | void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, |
132 | bool enableRT, bool enableConvert); |
133 | |
134 | std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(); |
135 | std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT, |
136 | bool enableConvert); |
137 | |
138 | //===----------------------------------------------------------------------===// |
139 | // The LowerForeachToSCF pass. |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns); |
143 | |
144 | std::unique_ptr<Pass> createLowerForeachToSCFPass(); |
145 | |
146 | //===----------------------------------------------------------------------===// |
147 | // The SparseTensorConversion pass. |
148 | //===----------------------------------------------------------------------===// |
149 | |
150 | /// Sparse tensor type converter into an opaque pointer. |
151 | class SparseTensorTypeToPtrConverter : public TypeConverter { |
152 | public: |
153 | SparseTensorTypeToPtrConverter(); |
154 | }; |
155 | |
156 | /// Sets up sparse tensor conversion rules. |
157 | void populateSparseTensorConversionPatterns(TypeConverter &typeConverter, |
158 | RewritePatternSet &patterns); |
159 | |
160 | std::unique_ptr<Pass> createSparseTensorConversionPass(); |
161 | |
162 | //===----------------------------------------------------------------------===// |
163 | // The SparseTensorCodegen pass. |
164 | //===----------------------------------------------------------------------===// |
165 | |
166 | /// Sparse tensor type converter into an actual buffer. |
167 | class SparseTensorTypeToBufferConverter : public TypeConverter { |
168 | public: |
169 | SparseTensorTypeToBufferConverter(); |
170 | }; |
171 | |
172 | /// Sets up sparse tensor codegen rules. |
173 | void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, |
174 | RewritePatternSet &patterns, |
175 | bool createSparseDeallocs, |
176 | bool enableBufferInitialization); |
177 | |
178 | std::unique_ptr<Pass> createSparseTensorCodegenPass(); |
179 | std::unique_ptr<Pass> |
180 | createSparseTensorCodegenPass(bool createSparseDeallocs, |
181 | bool enableBufferInitialization); |
182 | |
183 | //===----------------------------------------------------------------------===// |
184 | // The SparseBufferRewrite pass. |
185 | //===----------------------------------------------------------------------===// |
186 | |
187 | void populateSparseBufferRewriting(RewritePatternSet &patterns, |
188 | bool enableBufferInitialization); |
189 | |
190 | std::unique_ptr<Pass> createSparseBufferRewritePass(); |
191 | std::unique_ptr<Pass> |
192 | createSparseBufferRewritePass(bool enableBufferInitialization); |
193 | |
194 | //===----------------------------------------------------------------------===// |
195 | // The SparseVectorization pass. |
196 | //===----------------------------------------------------------------------===// |
197 | |
198 | void populateSparseVectorizationPatterns(RewritePatternSet &patterns, |
199 | unsigned vectorLength, |
200 | bool enableVLAVectorization, |
201 | bool enableSIMDIndex32); |
202 | |
203 | std::unique_ptr<Pass> createSparseVectorizationPass(); |
204 | std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength, |
205 | bool enableVLAVectorization, |
206 | bool enableSIMDIndex32); |
207 | |
208 | //===----------------------------------------------------------------------===// |
209 | // The SparseGPU pass. |
210 | //===----------------------------------------------------------------------===// |
211 | |
212 | void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, |
213 | unsigned numThreads); |
214 | |
215 | void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, |
216 | bool enableRT); |
217 | |
218 | std::unique_ptr<Pass> createSparseGPUCodegenPass(); |
219 | std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads, |
220 | bool enableRT); |
221 | |
222 | //===----------------------------------------------------------------------===// |
223 | // The SparseStorageSpecifierToLLVM pass. |
224 | //===----------------------------------------------------------------------===// |
225 | |
226 | class StorageSpecifierToLLVMTypeConverter : public TypeConverter { |
227 | public: |
228 | StorageSpecifierToLLVMTypeConverter(); |
229 | }; |
230 | |
231 | void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter, |
232 | RewritePatternSet &patterns); |
233 | std::unique_ptr<Pass> createStorageSpecifierToLLVMPass(); |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // The mini-pipeline for sparsification and bufferization. |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | bufferization::OneShotBufferizationOptions |
240 | getBufferizationOptionsForSparsification(bool analysisOnly); |
241 | |
242 | std::unique_ptr<Pass> createSparsificationAndBufferizationPass(); |
243 | |
244 | std::unique_ptr<Pass> createSparsificationAndBufferizationPass( |
245 | const bufferization::OneShotBufferizationOptions &bufferizationOptions, |
246 | const SparsificationOptions &sparsificationOptions, |
247 | bool createSparseDeallocs, bool enableRuntimeLibrary, |
248 | bool enableBufferInitialization, unsigned vectorLength, |
249 | bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen); |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // Registration. |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | /// Generate the code for registering passes. |
256 | #define GEN_PASS_REGISTRATION |
257 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" |
258 | |
259 | } // namespace mlir |
260 | |
261 | #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ |
262 | |