1//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Complex/IR/Complex.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15#include "mlir/Dialect/GPU/IR/GPUDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18#include "mlir/Dialect/SCF/Transforms/Patterns.h"
19#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23
24namespace mlir {
25#define GEN_PASS_DEF_SPARSEASSEMBLER
26#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28#define GEN_PASS_DEF_SPARSIFICATIONPASS
29#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
30#define GEN_PASS_DEF_LOWERFOREACHTOSCF
31#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
32#define GEN_PASS_DEF_SPARSETENSORCODEGEN
33#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
34#define GEN_PASS_DEF_SPARSEVECTORIZATION
35#define GEN_PASS_DEF_SPARSEGPUCODEGEN
36#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
37#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
38#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
39} // namespace mlir
40
41using namespace mlir;
42using namespace mlir::sparse_tensor;
43
44namespace {
45
46//===----------------------------------------------------------------------===//
47// Passes implementation.
48//===----------------------------------------------------------------------===//
49
50struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
51 SparseAssembler() = default;
52 SparseAssembler(const SparseAssembler &pass) = default;
53 SparseAssembler(bool dO) { directOut = dO; }
54
55 void runOnOperation() override {
56 auto *ctx = &getContext();
57 RewritePatternSet patterns(ctx);
58 populateSparseAssembler(patterns, directOut);
59 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
60 }
61};
62
63struct SparseReinterpretMap
64 : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
65 SparseReinterpretMap() = default;
66 SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
67 SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
68 scope = options.scope;
69 }
70
71 void runOnOperation() override {
72 auto *ctx = &getContext();
73 RewritePatternSet patterns(ctx);
74 populateSparseReinterpretMap(patterns, scope);
75 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
76 }
77};
78
79struct PreSparsificationRewritePass
80 : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
81 PreSparsificationRewritePass() = default;
82 PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
83 default;
84
85 void runOnOperation() override {
86 auto *ctx = &getContext();
87 RewritePatternSet patterns(ctx);
88 populatePreSparsificationRewriting(patterns);
89 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
90 }
91};
92
93struct SparsificationPass
94 : public impl::SparsificationPassBase<SparsificationPass> {
95 SparsificationPass() = default;
96 SparsificationPass(const SparsificationPass &pass) = default;
97 SparsificationPass(const SparsificationOptions &options) {
98 parallelization = options.parallelizationStrategy;
99 sparseEmitStrategy = options.sparseEmitStrategy;
100 enableRuntimeLibrary = options.enableRuntimeLibrary;
101 }
102
103 void runOnOperation() override {
104 auto *ctx = &getContext();
105 // Translate strategy flags to strategy options.
106 SparsificationOptions options(parallelization, sparseEmitStrategy,
107 enableRuntimeLibrary);
108 // Apply sparsification and cleanup rewriting.
109 RewritePatternSet patterns(ctx);
110 populateSparsificationPatterns(patterns, options);
111 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
112 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
113 }
114};
115
116struct StageSparseOperationsPass
117 : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
118 StageSparseOperationsPass() = default;
119 StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
120 void runOnOperation() override {
121 auto *ctx = &getContext();
122 RewritePatternSet patterns(ctx);
123 populateStageSparseOperationsPatterns(patterns);
124 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
125 }
126};
127
128struct LowerSparseOpsToForeachPass
129 : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
130 LowerSparseOpsToForeachPass() = default;
131 LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
132 default;
133 LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
134 enableRuntimeLibrary = enableRT;
135 enableConvert = convert;
136 }
137
138 void runOnOperation() override {
139 auto *ctx = &getContext();
140 RewritePatternSet patterns(ctx);
141 populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
142 enableConvert);
143 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
144 }
145};
146
147struct LowerForeachToSCFPass
148 : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
149 LowerForeachToSCFPass() = default;
150 LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
151
152 void runOnOperation() override {
153 auto *ctx = &getContext();
154 RewritePatternSet patterns(ctx);
155 populateLowerForeachToSCFPatterns(patterns);
156 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
157 }
158};
159
160struct SparseTensorConversionPass
161 : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
162 SparseTensorConversionPass() = default;
163 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
164
165 void runOnOperation() override {
166 auto *ctx = &getContext();
167 RewritePatternSet patterns(ctx);
168 SparseTensorTypeToPtrConverter converter;
169 ConversionTarget target(*ctx);
170 // Everything in the sparse dialect must go!
171 target.addIllegalDialect<SparseTensorDialect>();
172 // All dynamic rules below accept new function, call, return, and various
173 // tensor and bufferization operations as legal output of the rewriting
174 // provided that all sparse tensor types have been fully rewritten.
175 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
176 return converter.isSignatureLegal(op.getFunctionType());
177 });
178 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
179 return converter.isSignatureLegal(op.getCalleeType());
180 });
181 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
182 return converter.isLegal(op.getOperandTypes());
183 });
184 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
185 return converter.isLegal(op.getOperandTypes());
186 });
187 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
188 return converter.isLegal(op.getSource().getType()) &&
189 converter.isLegal(op.getDest().getType());
190 });
191 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
192 [&](tensor::ExpandShapeOp op) {
193 return converter.isLegal(op.getSrc().getType()) &&
194 converter.isLegal(op.getResult().getType());
195 });
196 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
197 [&](tensor::CollapseShapeOp op) {
198 return converter.isLegal(op.getSrc().getType()) &&
199 converter.isLegal(op.getResult().getType());
200 });
201 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
202 [&](bufferization::AllocTensorOp op) {
203 return converter.isLegal(op.getType());
204 });
205 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
206 [&](bufferization::DeallocTensorOp op) {
207 return converter.isLegal(op.getTensor().getType());
208 });
209 // The following operations and dialects may be introduced by the
210 // rewriting rules, and are therefore marked as legal.
211 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
212 linalg::YieldOp, tensor::ExtractOp,
213 tensor::FromElementsOp>();
214 target.addLegalDialect<
215 arith::ArithDialect, bufferization::BufferizationDialect,
216 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
217
218 // Populate with rules and apply rewriting rules.
219 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
220 converter);
221 populateCallOpTypeConversionPattern(patterns, converter);
222 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter&: converter, patterns,
223 target);
224 populateSparseTensorConversionPatterns(typeConverter&: converter, patterns);
225 if (failed(applyPartialConversion(getOperation(), target,
226 std::move(patterns))))
227 signalPassFailure();
228 }
229};
230
231struct SparseTensorCodegenPass
232 : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
233 SparseTensorCodegenPass() = default;
234 SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
235 SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
236 createSparseDeallocs = createDeallocs;
237 enableBufferInitialization = enableInit;
238 }
239
240 void runOnOperation() override {
241 auto *ctx = &getContext();
242 RewritePatternSet patterns(ctx);
243 SparseTensorTypeToBufferConverter converter;
244 ConversionTarget target(*ctx);
245 // Most ops in the sparse dialect must go!
246 target.addIllegalDialect<SparseTensorDialect>();
247 target.addLegalOp<SortOp>();
248 target.addLegalOp<PushBackOp>();
249 // Storage specifier outlives sparse tensor pipeline.
250 target.addLegalOp<GetStorageSpecifierOp>();
251 target.addLegalOp<SetStorageSpecifierOp>();
252 target.addLegalOp<StorageSpecifierInitOp>();
253 // Note that tensor::FromElementsOp might be yield after lowering unpack.
254 target.addLegalOp<tensor::FromElementsOp>();
255 // All dynamic rules below accept new function, call, return, and
256 // various tensor and bufferization operations as legal output of the
257 // rewriting provided that all sparse tensor types have been fully
258 // rewritten.
259 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
260 return converter.isSignatureLegal(op.getFunctionType());
261 });
262 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
263 return converter.isSignatureLegal(op.getCalleeType());
264 });
265 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
266 return converter.isLegal(op.getOperandTypes());
267 });
268 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
269 [&](bufferization::AllocTensorOp op) {
270 return converter.isLegal(op.getType());
271 });
272 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
273 [&](bufferization::DeallocTensorOp op) {
274 return converter.isLegal(op.getTensor().getType());
275 });
276 // The following operations and dialects may be introduced by the
277 // codegen rules, and are therefore marked as legal.
278 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
279 target.addLegalDialect<
280 arith::ArithDialect, bufferization::BufferizationDialect,
281 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
282 target.addLegalOp<UnrealizedConversionCastOp>();
283 // Populate with rules and apply rewriting rules.
284 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
285 converter);
286 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter&: converter, patterns,
287 target);
288 populateSparseTensorCodegenPatterns(
289 converter, patterns, createSparseDeallocs, enableBufferInitialization);
290 if (failed(applyPartialConversion(getOperation(), target,
291 std::move(patterns))))
292 signalPassFailure();
293 }
294};
295
296struct SparseBufferRewritePass
297 : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
298 SparseBufferRewritePass() = default;
299 SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
300 SparseBufferRewritePass(bool enableInit) {
301 enableBufferInitialization = enableInit;
302 }
303
304 void runOnOperation() override {
305 auto *ctx = &getContext();
306 RewritePatternSet patterns(ctx);
307 populateSparseBufferRewriting(patterns, enableBufferInitialization);
308 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
309 }
310};
311
312struct SparseVectorizationPass
313 : public impl::SparseVectorizationBase<SparseVectorizationPass> {
314 SparseVectorizationPass() = default;
315 SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
316 SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
317 vectorLength = vl;
318 enableVLAVectorization = vla;
319 enableSIMDIndex32 = sidx32;
320 }
321
322 void runOnOperation() override {
323 if (vectorLength == 0)
324 return signalPassFailure();
325 auto *ctx = &getContext();
326 RewritePatternSet patterns(ctx);
327 populateSparseVectorizationPatterns(
328 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
329 vector::populateVectorToVectorCanonicalizationPatterns(patterns);
330 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
331 }
332};
333
334struct SparseGPUCodegenPass
335 : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
336 SparseGPUCodegenPass() = default;
337 SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
338 SparseGPUCodegenPass(unsigned nT, bool enableRT) {
339 numThreads = nT;
340 enableRuntimeLibrary = enableRT;
341 }
342
343 void runOnOperation() override {
344 auto *ctx = &getContext();
345 RewritePatternSet patterns(ctx);
346 if (numThreads == 0)
347 populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
348 else
349 populateSparseGPUCodegenPatterns(patterns, numThreads);
350 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
351 }
352};
353
354struct StorageSpecifierToLLVMPass
355 : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
356 StorageSpecifierToLLVMPass() = default;
357
358 void runOnOperation() override {
359 auto *ctx = &getContext();
360 ConversionTarget target(*ctx);
361 RewritePatternSet patterns(ctx);
362 StorageSpecifierToLLVMTypeConverter converter;
363
364 // All ops in the sparse dialect must go!
365 target.addIllegalDialect<SparseTensorDialect>();
366 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
367 return converter.isSignatureLegal(op.getFunctionType());
368 });
369 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
370 return converter.isSignatureLegal(op.getCalleeType());
371 });
372 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
373 return converter.isLegal(op.getOperandTypes());
374 });
375 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
376
377 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
378 converter);
379 populateCallOpTypeConversionPattern(patterns, converter);
380 populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
381 populateReturnOpTypeConversionPattern(patterns, converter);
382 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter&: converter, patterns,
383 target);
384 populateStorageSpecifierToLLVMPatterns(converter, patterns);
385 if (failed(applyPartialConversion(getOperation(), target,
386 std::move(patterns))))
387 signalPassFailure();
388 }
389};
390
391} // namespace
392
393//===----------------------------------------------------------------------===//
394// Pass creation methods.
395//===----------------------------------------------------------------------===//
396
/// Creates a default-configured SparseAssembler pass instance.
std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}
400
/// Creates a default-configured SparseReinterpretMap pass instance.
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

/// Creates a SparseReinterpretMap pass instance restricted to the given
/// `scope`.
std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}
411
/// Creates a default-configured PreSparsificationRewrite pass instance.
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}
415
416std::unique_ptr<Pass> mlir::createSparsificationPass() {
417 return std::make_unique<SparsificationPass>();
418}
419
420std::unique_ptr<Pass>
421mlir::createSparsificationPass(const SparsificationOptions &options) {
422 return std::make_unique<SparsificationPass>(args: options);
423}
424
/// Creates a default-configured StageSparseOperations pass instance.
std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}
428
429std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
430 return std::make_unique<LowerSparseOpsToForeachPass>();
431}
432
433std::unique_ptr<Pass>
434mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
435 return std::make_unique<LowerSparseOpsToForeachPass>(args&: enableRT, args&: enableConvert);
436}
437
/// Creates a default-configured LowerForeachToSCF pass instance.
std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}
441
/// Creates a default-configured SparseTensorConversion pass instance.
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}
445
446std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
447 return std::make_unique<SparseTensorCodegenPass>();
448}
449
450std::unique_ptr<Pass>
451mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
452 bool enableBufferInitialization) {
453 return std::make_unique<SparseTensorCodegenPass>(args&: createSparseDeallocs,
454 args&: enableBufferInitialization);
455}
456
457std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
458 return std::make_unique<SparseBufferRewritePass>();
459}
460
461std::unique_ptr<Pass>
462mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
463 return std::make_unique<SparseBufferRewritePass>(args&: enableBufferInitialization);
464}
465
466std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
467 return std::make_unique<SparseVectorizationPass>();
468}
469
470std::unique_ptr<Pass>
471mlir::createSparseVectorizationPass(unsigned vectorLength,
472 bool enableVLAVectorization,
473 bool enableSIMDIndex32) {
474 return std::make_unique<SparseVectorizationPass>(
475 args&: vectorLength, args&: enableVLAVectorization, args&: enableSIMDIndex32);
476}
477
478std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
479 return std::make_unique<SparseGPUCodegenPass>();
480}
481
482std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
483 bool enableRT) {
484 return std::make_unique<SparseGPUCodegenPass>(args&: numThreads, args&: enableRT);
485}
486
/// Creates a default-configured StorageSpecifierToLLVM pass instance.
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
490

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp