//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Complex/IR/Complex.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
15 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
17 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
18 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
19 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
20 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | |
namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
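
// Each GEN_PASS_DEF_* macro above selects the TableGen-generated base class
// for one pass from Passes.h.inc. As a rough sketch of what gets pulled in
// (simplified; the generated code also defines the pass options, argument
// and description strings, and clone/registration hooks):
//
//   namespace impl {
//   template <typename DerivedT>
//   class SparsificationPassBase : public ::mlir::OperationPass<...> {
//     ...
//   };
//   } // namespace impl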

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
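
// Hypothetical command-line usage (the flag and option spellings follow the
// pass definition in Passes.td; treat the exact names here as assumptions):
//
//   mlir-opt input.mlir --sparse-assembler="direct-out=false"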

struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
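
// To make the input concrete: sparsification consumes linalg ops whose
// operands carry sparse encodings, and rewrites them into loops over the
// stored nonzeros only. A minimal kernel it would handle looks roughly like
// this (encoding and #trait abbreviated; shown for intuition only):
//
//   #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
//   %0 = linalg.generic #trait
//          ins(%a : tensor<?xf64, #SV>) outs(%x : tensor<?xf64>) {
//        ^bb0(%in: f64, %out: f64):
//          %s = arith.mulf %in, %c2 : f64
//          linalg.yield %s : f64
//        } -> tensor<?xf64>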

struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
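
// For intuition, a sketch of what this lowering produces (buffer names and
// the exact control flow are simplified here): a loop such as
//
//   sparse_tensor.foreach in %a : tensor<?xf64, #SV> do {
//   ^bb0(%i: index, %v: f64):
//     ...
//   }
//
// becomes explicit scf control flow that walks the stored positions,
// coordinates, and values of %a directly.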

struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting,
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate the pattern set and apply the rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
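
// For illustration: with the pointer-based type converter above, a function
// such as
//
//   func.func @sum(%arg0: tensor<?xf64, #SparseVector>) -> f64
//
// has its sparse argument rewritten into an opaque handle into the runtime
// support library, roughly
//
//   func.func @sum(%arg0: !llvm.ptr) -> f64
//
// while all other types pass through unchanged, which is why a partial
// rather than a full conversion is applied.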

struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives the sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting, provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate the pattern set and apply the rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
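
// By contrast with the conversion pass above, codegen expands each sparse
// tensor into its constituent buffers. As a sketch (the exact field set and
// layout are defined by the codegen implementation; this is illustrative
// only), a compressed vector becomes roughly:
//
//   memref<?xindex>                          positions
//   memref<?xindex>                          coordinates
//   memref<?xf64>                            values
//   !sparse_tensor.storage_specifier<...>    sizes and related metadata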

struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
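
// Hypothetical command-line invocation (flag and option spellings follow the
// pass definition in Passes.td; treat the exact names here as assumptions):
//
//   mlir-opt input.mlir --sparse-vectorization="vl=16 enable-simd-index32=true"
//
// A vector length of zero is rejected above, since it leaves nothing to
// vectorize with.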

struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
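
// Note the two modes above: with numThreads == 0 the pass emits calls into
// vendor library routines (the "libgen" path), otherwise it generates actual
// GPU kernels using the requested number of threads. A hypothetical
// invocation of the codegen mode (the option name is an assumption):
//
//   mlir-opt input.mlir --sparse-gpu-codegen="num-threads=1024"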

struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
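
// For intuition: a storage specifier is a small metadata record (per-level
// sizes plus buffer sizes), and this lowering maps it onto an LLVM struct,
// with the get/set ops becoming llvm.extractvalue / llvm.insertvalue. The
// layout sketched here is illustrative, not normative:
//
//   !sparse_tensor.storage_specifier<#SV>  ~>  !llvm.struct<(...)>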

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT,
                                                       enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
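
// A minimal sketch of composing these factory functions into a pipeline.
// The function name and ordering shown are illustrative; the canonical
// pipeline lives in the SparseTensorPipelines library:
//
//   void buildMySparsifierPipeline(OpPassManager &pm) {
//     pm.addPass(createPreSparsificationRewritePass());
//     pm.addPass(createSparsificationPass());
//     pm.addPass(createSparseTensorConversionPass());
//     pm.addPass(createStorageSpecifierToLLVMPass());
//   }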