//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

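// Pull in the TableGen-generated pass definitions; each GEN_PASS_DEF_* macro
// below instantiates the impl::...Base class (with options) for one pass.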
namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

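// Wraps sparse tensor arguments and results of public functions in explicit
// sparse_tensor.assemble/disassemble operations, so that external callers
// can exchange the underlying storage buffers directly.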
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

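// Makes non-trivial dimension-to-level mappings in sparse tensor encodings
// explicit in the IR, within the scope selected by the pass options.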
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

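// Applies sparse tensor rewriting rules that must run prior to the actual
// sparsification of linalg operations.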
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

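// The core sparsification pass: lowers operations on sparse tensor types
// into actual (co-)iterating sparse code, guided by the parallelization
// and code emission strategies.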
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

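// Decomposes complex sparse tensor operations into simpler staged operations
// (e.g. a direct conversion between two sparse formats may be staged through
// an intermediate sorted COO tensor).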
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

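// Lowers high-level sparse tensor operations into sparse_tensor.foreach
// loops over their elements.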
struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

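// Lowers sparse_tensor.foreach operations into scf control flow.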
struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

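// Lowers the sparse iteration space and iteration operations (iterate,
// coiterate) into scf control flow, using a partial dialect conversion.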
struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);

    // The actual conversion.
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    populateLowerSparseIterationToSCFPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

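// Converts sparse tensor types and primitives into opaque pointers and
// calls into a runtime support library.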
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

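// Converts sparse tensor types and primitives into actual compiler-visible
// buffers; the direct codegen alternative to the runtime library path.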
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

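// Rewrites sparse buffer primitives such as sort and push_back into actual
// buffer manipulation code.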
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

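// Vectorizes loops generated by sparsification, with options for the vector
// length, vector-length-agnostic (scalable) vectors, and 32-bit SIMD indices.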
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

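// Generates GPU code for sparse kernels: either calls into a sparse GPU
// library (when numThreads is zero) or generates actual GPU kernels.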
struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

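// Lowers the sparse_tensor.storage_specifier type and its operations into
// LLVM dialect struct values.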
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT,
                                                       enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
