1 | //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Complex/IR/Complex.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
15 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
17 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
18 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
19 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
20 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | |
24 | namespace mlir { |
25 | #define GEN_PASS_DEF_SPARSEASSEMBLER |
26 | #define GEN_PASS_DEF_SPARSEREINTERPRETMAP |
27 | #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE |
28 | #define GEN_PASS_DEF_SPARSIFICATIONPASS |
29 | #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF |
30 | #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH |
31 | #define GEN_PASS_DEF_LOWERFOREACHTOSCF |
32 | #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS |
33 | #define GEN_PASS_DEF_SPARSETENSORCODEGEN |
34 | #define GEN_PASS_DEF_SPARSEBUFFERREWRITE |
35 | #define GEN_PASS_DEF_SPARSEVECTORIZATION |
36 | #define GEN_PASS_DEF_SPARSEGPUCODEGEN |
37 | #define GEN_PASS_DEF_STAGESPARSEOPERATIONS |
38 | #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM |
39 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" |
40 | } // namespace mlir |
41 | |
42 | using namespace mlir; |
43 | using namespace mlir::sparse_tensor; |
44 | |
45 | namespace { |
46 | |
47 | //===----------------------------------------------------------------------===// |
48 | // Passes implementation. |
49 | //===----------------------------------------------------------------------===// |
50 | |
51 | struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> { |
52 | SparseAssembler() = default; |
53 | SparseAssembler(const SparseAssembler &pass) = default; |
54 | SparseAssembler(bool dO) { directOut = dO; } |
55 | |
56 | void runOnOperation() override { |
57 | auto *ctx = &getContext(); |
58 | RewritePatternSet patterns(ctx); |
59 | populateSparseAssembler(patterns, directOut); |
60 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
61 | } |
62 | }; |
63 | |
64 | struct SparseReinterpretMap |
65 | : public impl::SparseReinterpretMapBase<SparseReinterpretMap> { |
66 | SparseReinterpretMap() = default; |
67 | SparseReinterpretMap(const SparseReinterpretMap &pass) = default; |
68 | SparseReinterpretMap(const SparseReinterpretMapOptions &options) { |
69 | scope = options.scope; |
70 | } |
71 | |
72 | void runOnOperation() override { |
73 | auto *ctx = &getContext(); |
74 | RewritePatternSet patterns(ctx); |
75 | populateSparseReinterpretMap(patterns, scope); |
76 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
77 | } |
78 | }; |
79 | |
80 | struct PreSparsificationRewritePass |
81 | : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> { |
82 | PreSparsificationRewritePass() = default; |
83 | PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) = |
84 | default; |
85 | |
86 | void runOnOperation() override { |
87 | auto *ctx = &getContext(); |
88 | RewritePatternSet patterns(ctx); |
89 | populatePreSparsificationRewriting(patterns); |
90 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
91 | } |
92 | }; |
93 | |
94 | struct SparsificationPass |
95 | : public impl::SparsificationPassBase<SparsificationPass> { |
96 | SparsificationPass() = default; |
97 | SparsificationPass(const SparsificationPass &pass) = default; |
98 | SparsificationPass(const SparsificationOptions &options) { |
99 | parallelization = options.parallelizationStrategy; |
100 | sparseEmitStrategy = options.sparseEmitStrategy; |
101 | enableRuntimeLibrary = options.enableRuntimeLibrary; |
102 | } |
103 | |
104 | void runOnOperation() override { |
105 | auto *ctx = &getContext(); |
106 | // Translate strategy flags to strategy options. |
107 | SparsificationOptions options(parallelization, sparseEmitStrategy, |
108 | enableRuntimeLibrary); |
109 | // Apply sparsification and cleanup rewriting. |
110 | RewritePatternSet patterns(ctx); |
111 | populateSparsificationPatterns(patterns, options); |
112 | scf::ForOp::getCanonicalizationPatterns(patterns, ctx); |
113 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
114 | } |
115 | }; |
116 | |
117 | struct StageSparseOperationsPass |
118 | : public impl::StageSparseOperationsBase<StageSparseOperationsPass> { |
119 | StageSparseOperationsPass() = default; |
120 | StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default; |
121 | void runOnOperation() override { |
122 | auto *ctx = &getContext(); |
123 | RewritePatternSet patterns(ctx); |
124 | populateStageSparseOperationsPatterns(patterns); |
125 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
126 | } |
127 | }; |
128 | |
129 | struct LowerSparseOpsToForeachPass |
130 | : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> { |
131 | LowerSparseOpsToForeachPass() = default; |
132 | LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) = |
133 | default; |
134 | LowerSparseOpsToForeachPass(bool enableRT, bool convert) { |
135 | enableRuntimeLibrary = enableRT; |
136 | enableConvert = convert; |
137 | } |
138 | |
139 | void runOnOperation() override { |
140 | auto *ctx = &getContext(); |
141 | RewritePatternSet patterns(ctx); |
142 | populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary, |
143 | enableConvert); |
144 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
145 | } |
146 | }; |
147 | |
148 | struct LowerForeachToSCFPass |
149 | : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> { |
150 | LowerForeachToSCFPass() = default; |
151 | LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default; |
152 | |
153 | void runOnOperation() override { |
154 | auto *ctx = &getContext(); |
155 | RewritePatternSet patterns(ctx); |
156 | populateLowerForeachToSCFPatterns(patterns); |
157 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
158 | } |
159 | }; |
160 | |
161 | struct LowerSparseIterationToSCFPass |
162 | : public impl::LowerSparseIterationToSCFBase< |
163 | LowerSparseIterationToSCFPass> { |
164 | LowerSparseIterationToSCFPass() = default; |
165 | LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) = |
166 | default; |
167 | |
168 | void runOnOperation() override { |
169 | auto *ctx = &getContext(); |
170 | RewritePatternSet patterns(ctx); |
171 | SparseIterationTypeConverter converter; |
172 | ConversionTarget target(*ctx); |
173 | |
174 | // The actual conversion. |
175 | target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect, |
176 | memref::MemRefDialect, scf::SCFDialect, |
177 | sparse_tensor::SparseTensorDialect>(); |
178 | target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp, |
179 | IterateOp>(); |
180 | target.addLegalOp<UnrealizedConversionCastOp>(); |
181 | populateLowerSparseIterationToSCFPatterns(converter, patterns); |
182 | |
183 | if (failed(applyPartialConversion(getOperation(), target, |
184 | std::move(patterns)))) |
185 | signalPassFailure(); |
186 | } |
187 | }; |
188 | |
189 | struct SparseTensorConversionPass |
190 | : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> { |
191 | SparseTensorConversionPass() = default; |
192 | SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; |
193 | |
194 | void runOnOperation() override { |
195 | auto *ctx = &getContext(); |
196 | RewritePatternSet patterns(ctx); |
197 | SparseTensorTypeToPtrConverter converter; |
198 | ConversionTarget target(*ctx); |
199 | // Everything in the sparse dialect must go! |
200 | target.addIllegalDialect<SparseTensorDialect>(); |
201 | // All dynamic rules below accept new function, call, return, and various |
202 | // tensor and bufferization operations as legal output of the rewriting |
203 | // provided that all sparse tensor types have been fully rewritten. |
204 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
205 | return converter.isSignatureLegal(op.getFunctionType()); |
206 | }); |
207 | target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { |
208 | return converter.isSignatureLegal(op.getCalleeType()); |
209 | }); |
210 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
211 | return converter.isLegal(op.getOperandTypes()); |
212 | }); |
213 | target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { |
214 | return converter.isLegal(op.getOperandTypes()); |
215 | }); |
216 | target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { |
217 | return converter.isLegal(op.getSource().getType()) && |
218 | converter.isLegal(op.getDest().getType()); |
219 | }); |
220 | target.addDynamicallyLegalOp<tensor::ExpandShapeOp>( |
221 | [&](tensor::ExpandShapeOp op) { |
222 | return converter.isLegal(op.getSrc().getType()) && |
223 | converter.isLegal(op.getResult().getType()); |
224 | }); |
225 | target.addDynamicallyLegalOp<tensor::CollapseShapeOp>( |
226 | [&](tensor::CollapseShapeOp op) { |
227 | return converter.isLegal(op.getSrc().getType()) && |
228 | converter.isLegal(op.getResult().getType()); |
229 | }); |
230 | target.addDynamicallyLegalOp<bufferization::AllocTensorOp>( |
231 | [&](bufferization::AllocTensorOp op) { |
232 | return converter.isLegal(op.getType()); |
233 | }); |
234 | target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>( |
235 | [&](bufferization::DeallocTensorOp op) { |
236 | return converter.isLegal(op.getTensor().getType()); |
237 | }); |
238 | // The following operations and dialects may be introduced by the |
239 | // rewriting rules, and are therefore marked as legal. |
240 | target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp, |
241 | linalg::YieldOp, tensor::ExtractOp, |
242 | tensor::FromElementsOp>(); |
243 | target.addLegalDialect< |
244 | arith::ArithDialect, bufferization::BufferizationDialect, |
245 | LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>(); |
246 | |
247 | // Populate with rules and apply rewriting rules. |
248 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
249 | converter); |
250 | populateCallOpTypeConversionPattern(patterns, converter); |
251 | scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns, |
252 | target); |
253 | populateSparseTensorConversionPatterns(typeConverter: converter, patterns); |
254 | if (failed(applyPartialConversion(getOperation(), target, |
255 | std::move(patterns)))) |
256 | signalPassFailure(); |
257 | } |
258 | }; |
259 | |
260 | struct SparseTensorCodegenPass |
261 | : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> { |
262 | SparseTensorCodegenPass() = default; |
263 | SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default; |
264 | SparseTensorCodegenPass(bool createDeallocs, bool enableInit) { |
265 | createSparseDeallocs = createDeallocs; |
266 | enableBufferInitialization = enableInit; |
267 | } |
268 | |
269 | void runOnOperation() override { |
270 | auto *ctx = &getContext(); |
271 | RewritePatternSet patterns(ctx); |
272 | SparseTensorTypeToBufferConverter converter; |
273 | ConversionTarget target(*ctx); |
274 | // Most ops in the sparse dialect must go! |
275 | target.addIllegalDialect<SparseTensorDialect>(); |
276 | target.addLegalOp<SortOp>(); |
277 | target.addLegalOp<PushBackOp>(); |
278 | // Storage specifier outlives sparse tensor pipeline. |
279 | target.addLegalOp<GetStorageSpecifierOp>(); |
280 | target.addLegalOp<SetStorageSpecifierOp>(); |
281 | target.addLegalOp<StorageSpecifierInitOp>(); |
282 | // Note that tensor::FromElementsOp might be yield after lowering unpack. |
283 | target.addLegalOp<tensor::FromElementsOp>(); |
284 | // All dynamic rules below accept new function, call, return, and |
285 | // various tensor and bufferization operations as legal output of the |
286 | // rewriting provided that all sparse tensor types have been fully |
287 | // rewritten. |
288 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
289 | return converter.isSignatureLegal(op.getFunctionType()); |
290 | }); |
291 | target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { |
292 | return converter.isSignatureLegal(op.getCalleeType()); |
293 | }); |
294 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
295 | return converter.isLegal(op.getOperandTypes()); |
296 | }); |
297 | target.addDynamicallyLegalOp<bufferization::AllocTensorOp>( |
298 | [&](bufferization::AllocTensorOp op) { |
299 | return converter.isLegal(op.getType()); |
300 | }); |
301 | target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>( |
302 | [&](bufferization::DeallocTensorOp op) { |
303 | return converter.isLegal(op.getTensor().getType()); |
304 | }); |
305 | // The following operations and dialects may be introduced by the |
306 | // codegen rules, and are therefore marked as legal. |
307 | target.addLegalOp<linalg::FillOp, linalg::YieldOp>(); |
308 | target.addLegalDialect< |
309 | arith::ArithDialect, bufferization::BufferizationDialect, |
310 | complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); |
311 | target.addLegalOp<UnrealizedConversionCastOp>(); |
312 | // Populate with rules and apply rewriting rules. |
313 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
314 | converter); |
315 | scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns, |
316 | target); |
317 | populateSparseTensorCodegenPatterns( |
318 | converter, patterns, createSparseDeallocs, enableBufferInitialization); |
319 | if (failed(applyPartialConversion(getOperation(), target, |
320 | std::move(patterns)))) |
321 | signalPassFailure(); |
322 | } |
323 | }; |
324 | |
325 | struct SparseBufferRewritePass |
326 | : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> { |
327 | SparseBufferRewritePass() = default; |
328 | SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default; |
329 | SparseBufferRewritePass(bool enableInit) { |
330 | enableBufferInitialization = enableInit; |
331 | } |
332 | |
333 | void runOnOperation() override { |
334 | auto *ctx = &getContext(); |
335 | RewritePatternSet patterns(ctx); |
336 | populateSparseBufferRewriting(patterns, enableBufferInitialization); |
337 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
338 | } |
339 | }; |
340 | |
341 | struct SparseVectorizationPass |
342 | : public impl::SparseVectorizationBase<SparseVectorizationPass> { |
343 | SparseVectorizationPass() = default; |
344 | SparseVectorizationPass(const SparseVectorizationPass &pass) = default; |
345 | SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) { |
346 | vectorLength = vl; |
347 | enableVLAVectorization = vla; |
348 | enableSIMDIndex32 = sidx32; |
349 | } |
350 | |
351 | void runOnOperation() override { |
352 | if (vectorLength == 0) |
353 | return signalPassFailure(); |
354 | auto *ctx = &getContext(); |
355 | RewritePatternSet patterns(ctx); |
356 | populateSparseVectorizationPatterns( |
357 | patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32); |
358 | vector::populateVectorToVectorCanonicalizationPatterns(patterns); |
359 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
360 | } |
361 | }; |
362 | |
363 | struct SparseGPUCodegenPass |
364 | : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> { |
365 | SparseGPUCodegenPass() = default; |
366 | SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default; |
367 | SparseGPUCodegenPass(unsigned nT, bool enableRT) { |
368 | numThreads = nT; |
369 | enableRuntimeLibrary = enableRT; |
370 | } |
371 | |
372 | void runOnOperation() override { |
373 | auto *ctx = &getContext(); |
374 | RewritePatternSet patterns(ctx); |
375 | if (numThreads == 0) |
376 | populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary); |
377 | else |
378 | populateSparseGPUCodegenPatterns(patterns, numThreads); |
379 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
380 | } |
381 | }; |
382 | |
383 | struct StorageSpecifierToLLVMPass |
384 | : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> { |
385 | StorageSpecifierToLLVMPass() = default; |
386 | |
387 | void runOnOperation() override { |
388 | auto *ctx = &getContext(); |
389 | ConversionTarget target(*ctx); |
390 | RewritePatternSet patterns(ctx); |
391 | StorageSpecifierToLLVMTypeConverter converter; |
392 | |
393 | // All ops in the sparse dialect must go! |
394 | target.addIllegalDialect<SparseTensorDialect>(); |
395 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
396 | return converter.isSignatureLegal(op.getFunctionType()); |
397 | }); |
398 | target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { |
399 | return converter.isSignatureLegal(op.getCalleeType()); |
400 | }); |
401 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
402 | return converter.isLegal(op.getOperandTypes()); |
403 | }); |
404 | target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>(); |
405 | |
406 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
407 | converter); |
408 | populateCallOpTypeConversionPattern(patterns, converter); |
409 | populateBranchOpInterfaceTypeConversionPattern(patterns, converter); |
410 | populateReturnOpTypeConversionPattern(patterns, converter); |
411 | scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns, |
412 | target); |
413 | populateStorageSpecifierToLLVMPatterns(converter, patterns); |
414 | if (failed(applyPartialConversion(getOperation(), target, |
415 | std::move(patterns)))) |
416 | signalPassFailure(); |
417 | } |
418 | }; |
419 | |
420 | } // namespace |
421 | |
422 | //===----------------------------------------------------------------------===// |
423 | // Pass creation methods. |
424 | //===----------------------------------------------------------------------===// |
425 | |
426 | std::unique_ptr<Pass> mlir::createSparseAssembler() { |
427 | return std::make_unique<SparseAssembler>(); |
428 | } |
429 | |
430 | std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() { |
431 | return std::make_unique<SparseReinterpretMap>(); |
432 | } |
433 | |
434 | std::unique_ptr<Pass> |
435 | mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) { |
436 | SparseReinterpretMapOptions options; |
437 | options.scope = scope; |
438 | return std::make_unique<SparseReinterpretMap>(options); |
439 | } |
440 | |
441 | std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() { |
442 | return std::make_unique<PreSparsificationRewritePass>(); |
443 | } |
444 | |
445 | std::unique_ptr<Pass> mlir::createSparsificationPass() { |
446 | return std::make_unique<SparsificationPass>(); |
447 | } |
448 | |
449 | std::unique_ptr<Pass> |
450 | mlir::createSparsificationPass(const SparsificationOptions &options) { |
451 | return std::make_unique<SparsificationPass>(args: options); |
452 | } |
453 | |
454 | std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() { |
455 | return std::make_unique<StageSparseOperationsPass>(); |
456 | } |
457 | |
458 | std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() { |
459 | return std::make_unique<LowerSparseOpsToForeachPass>(); |
460 | } |
461 | |
462 | std::unique_ptr<Pass> |
463 | mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) { |
464 | return std::make_unique<LowerSparseOpsToForeachPass>(args&: enableRT, args&: enableConvert); |
465 | } |
466 | |
467 | std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() { |
468 | return std::make_unique<LowerForeachToSCFPass>(); |
469 | } |
470 | |
471 | std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() { |
472 | return std::make_unique<LowerSparseIterationToSCFPass>(); |
473 | } |
474 | |
475 | std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { |
476 | return std::make_unique<SparseTensorConversionPass>(); |
477 | } |
478 | |
479 | std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() { |
480 | return std::make_unique<SparseTensorCodegenPass>(); |
481 | } |
482 | |
483 | std::unique_ptr<Pass> |
484 | mlir::createSparseTensorCodegenPass(bool createSparseDeallocs, |
485 | bool enableBufferInitialization) { |
486 | return std::make_unique<SparseTensorCodegenPass>(args&: createSparseDeallocs, |
487 | args&: enableBufferInitialization); |
488 | } |
489 | |
490 | std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() { |
491 | return std::make_unique<SparseBufferRewritePass>(); |
492 | } |
493 | |
494 | std::unique_ptr<Pass> |
495 | mlir::createSparseBufferRewritePass(bool enableBufferInitialization) { |
496 | return std::make_unique<SparseBufferRewritePass>(args&: enableBufferInitialization); |
497 | } |
498 | |
499 | std::unique_ptr<Pass> mlir::createSparseVectorizationPass() { |
500 | return std::make_unique<SparseVectorizationPass>(); |
501 | } |
502 | |
503 | std::unique_ptr<Pass> |
504 | mlir::createSparseVectorizationPass(unsigned vectorLength, |
505 | bool enableVLAVectorization, |
506 | bool enableSIMDIndex32) { |
507 | return std::make_unique<SparseVectorizationPass>( |
508 | args&: vectorLength, args&: enableVLAVectorization, args&: enableSIMDIndex32); |
509 | } |
510 | |
511 | std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() { |
512 | return std::make_unique<SparseGPUCodegenPass>(); |
513 | } |
514 | |
515 | std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads, |
516 | bool enableRT) { |
517 | return std::make_unique<SparseGPUCodegenPass>(args&: numThreads, args&: enableRT); |
518 | } |
519 | |
520 | std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() { |
521 | return std::make_unique<StorageSpecifierToLLVMPass>(); |
522 | } |
523 | |