1//===- Passes.h - Sparse tensor pass entry points ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines prototypes of all sparse tensor passes.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
14#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
15
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20//===----------------------------------------------------------------------===//
21// Include the generated pass header (which needs some early definitions).
22//===----------------------------------------------------------------------===//
23
24namespace mlir {
25
// Forward declaration only: below, this type appears solely as a declared
// return type and as a const-reference parameter, so its full definition is
// not required by this header.
namespace bufferization { struct OneShotBufferizationOptions; }
29
/// Defines a parallelization strategy. Any independent loop is a candidate
/// for parallelization. The loop is made parallel if (1) allowed by the
/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
/// outermost loop only), and (2) the generated code is an actual for-loop
/// (and not a co-iterating while-loop).
///
/// Each name pairs a storage filter (Dense vs. AnyStorage) with a loop filter
/// (OuterLoop vs. AnyLoop). The per-enumerator notes below extend the
/// AnyStorageOuterLoop example above by analogy; confirm them against the
/// pass option descriptions.
enum class SparseParallelizationStrategy : int {
  kNone = 0,                // no loop is made parallel
  kDenseOuterLoop = 1,      // dense outermost loop only
  kAnyStorageOuterLoop = 2, // dense or sparse outermost loop only
  kDenseAnyLoop = 3,        // dense loops at any depth
  kAnyStorageAnyLoop = 4    // dense or sparse loops at any depth
};
42
/// Defines which operations the SparseReinterpretMap pass (and its pattern
/// population entry point) applies to.
enum class ReinterpretMapScope {
  kAll,           // reinterprets all applicable operations
  kGenericOnly,   // reinterprets only linalg.generic
  kExceptGeneric, // reinterprets operations other than linalg.generic
};
49
/// Defines the strategy used when emitting code for sparse iteration. It is
/// selected through SparsificationOptions::sparseEmitStrategy.
enum class SparseEmitStrategy {
  kFunctional,     // generate fully inlined (and functional) sparse iteration
  kDebugInterface, // generate only a placeholder for sparse iteration
};
55
56#define GEN_PASS_DECL
57#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
58
59//===----------------------------------------------------------------------===//
60// The SparseAssembler pass.
61//===----------------------------------------------------------------------===//
62
63void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
64
65std::unique_ptr<Pass> createSparseAssembler();
66std::unique_ptr<Pass> createSparseAssembler(bool directOut);
67
68//===----------------------------------------------------------------------===//
69// The SparseReinterpretMap pass.
70//===----------------------------------------------------------------------===//
71
72void populateSparseReinterpretMap(RewritePatternSet &patterns,
73 ReinterpretMapScope scope);
74
75std::unique_ptr<Pass> createSparseReinterpretMapPass();
76std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
77
78//===----------------------------------------------------------------------===//
79// The PreSparsificationRewriting pass.
80//===----------------------------------------------------------------------===//
81
82void populatePreSparsificationRewriting(RewritePatternSet &patterns);
83
84std::unique_ptr<Pass> createPreSparsificationRewritePass();
85
86//===----------------------------------------------------------------------===//
87// The Sparsification pass.
88//===----------------------------------------------------------------------===//
89
90/// Options for the Sparsification pass.
91struct SparsificationOptions {
92 SparsificationOptions(SparseParallelizationStrategy p, SparseEmitStrategy d,
93 bool enableRT)
94 : parallelizationStrategy(p), sparseEmitStrategy(d),
95 enableRuntimeLibrary(enableRT) {}
96
97 SparsificationOptions(SparseParallelizationStrategy p, bool enableRT)
98 : SparsificationOptions(p, SparseEmitStrategy::kFunctional, enableRT) {}
99
100 SparsificationOptions()
101 : SparsificationOptions(SparseParallelizationStrategy::kNone,
102 SparseEmitStrategy::kFunctional, true) {}
103
104 SparseParallelizationStrategy parallelizationStrategy;
105 SparseEmitStrategy sparseEmitStrategy;
106 bool enableRuntimeLibrary;
107};
108
109/// Sets up sparsification rewriting rules with the given options.
110void populateSparsificationPatterns(
111 RewritePatternSet &patterns,
112 const SparsificationOptions &options = SparsificationOptions());
113
114std::unique_ptr<Pass> createSparsificationPass();
115std::unique_ptr<Pass>
116createSparsificationPass(const SparsificationOptions &options);
117
118//===----------------------------------------------------------------------===//
119// The StageSparseOperations pass.
120//===----------------------------------------------------------------------===//
121
122/// Sets up StageSparseOperation rewriting rules.
123void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
124
125std::unique_ptr<Pass> createStageSparseOperationsPass();
126
127//===----------------------------------------------------------------------===//
128// The LowerSparseOpsToForeach pass.
129//===----------------------------------------------------------------------===//
130
131void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
132 bool enableRT, bool enableConvert);
133
134std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
135std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
136 bool enableConvert);
137
138//===----------------------------------------------------------------------===//
139// The LowerForeachToSCF pass.
140//===----------------------------------------------------------------------===//
141
142void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
143
144std::unique_ptr<Pass> createLowerForeachToSCFPass();
145
146//===----------------------------------------------------------------------===//
147// The SparseTensorConversion pass.
148//===----------------------------------------------------------------------===//
149
150/// Sparse tensor type converter into an opaque pointer.
151class SparseTensorTypeToPtrConverter : public TypeConverter {
152public:
153 SparseTensorTypeToPtrConverter();
154};
155
156/// Sets up sparse tensor conversion rules.
157void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
158 RewritePatternSet &patterns);
159
160std::unique_ptr<Pass> createSparseTensorConversionPass();
161
162//===----------------------------------------------------------------------===//
163// The SparseTensorCodegen pass.
164//===----------------------------------------------------------------------===//
165
166/// Sparse tensor type converter into an actual buffer.
167class SparseTensorTypeToBufferConverter : public TypeConverter {
168public:
169 SparseTensorTypeToBufferConverter();
170};
171
172/// Sets up sparse tensor codegen rules.
173void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
174 RewritePatternSet &patterns,
175 bool createSparseDeallocs,
176 bool enableBufferInitialization);
177
178std::unique_ptr<Pass> createSparseTensorCodegenPass();
179std::unique_ptr<Pass>
180createSparseTensorCodegenPass(bool createSparseDeallocs,
181 bool enableBufferInitialization);
182
183//===----------------------------------------------------------------------===//
184// The SparseBufferRewrite pass.
185//===----------------------------------------------------------------------===//
186
187void populateSparseBufferRewriting(RewritePatternSet &patterns,
188 bool enableBufferInitialization);
189
190std::unique_ptr<Pass> createSparseBufferRewritePass();
191std::unique_ptr<Pass>
192createSparseBufferRewritePass(bool enableBufferInitialization);
193
194//===----------------------------------------------------------------------===//
195// The SparseVectorization pass.
196//===----------------------------------------------------------------------===//
197
198void populateSparseVectorizationPatterns(RewritePatternSet &patterns,
199 unsigned vectorLength,
200 bool enableVLAVectorization,
201 bool enableSIMDIndex32);
202
203std::unique_ptr<Pass> createSparseVectorizationPass();
204std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
205 bool enableVLAVectorization,
206 bool enableSIMDIndex32);
207
208//===----------------------------------------------------------------------===//
209// The SparseGPU pass.
210//===----------------------------------------------------------------------===//
211
212void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
213 unsigned numThreads);
214
215void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
216 bool enableRT);
217
218std::unique_ptr<Pass> createSparseGPUCodegenPass();
219std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads,
220 bool enableRT);
221
222//===----------------------------------------------------------------------===//
223// The SparseStorageSpecifierToLLVM pass.
224//===----------------------------------------------------------------------===//
225
226class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
227public:
228 StorageSpecifierToLLVMTypeConverter();
229};
230
231void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
232 RewritePatternSet &patterns);
233std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
234
235//===----------------------------------------------------------------------===//
236// The mini-pipeline for sparsification and bufferization.
237//===----------------------------------------------------------------------===//
238
239bufferization::OneShotBufferizationOptions
240getBufferizationOptionsForSparsification(bool analysisOnly);
241
242std::unique_ptr<Pass> createSparsificationAndBufferizationPass();
243
244std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
245 const bufferization::OneShotBufferizationOptions &bufferizationOptions,
246 const SparsificationOptions &sparsificationOptions,
247 bool createSparseDeallocs, bool enableRuntimeLibrary,
248 bool enableBufferInitialization, unsigned vectorLength,
249 bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
250
251//===----------------------------------------------------------------------===//
252// Registration.
253//===----------------------------------------------------------------------===//
254
255/// Generate the code for registering passes.
256#define GEN_PASS_REGISTRATION
257#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
258
259} // namespace mlir
260
261#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
262

source code of mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h