//===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

namespace mlir {

#define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

namespace sparse_tensor {

/// Return `true` if one of the given types is a sparse tensor type.
static bool containsSparseTensor(TypeRange types) {
  for (Type t : types)
    if (isa<TensorType>(t) && getSparseTensorEncoding(t))
      return true;
  return false;
}

/// A pass that lowers tensor ops to memref ops, regardless of whether they are
/// dense or sparse.
///
/// One-Shot Analysis is used to detect RaW conflicts and to insert buffer
/// copies at the tensor level (`insertTensorCopies`). Afterwards, the lowering
/// of tensor ops to memref ops follows a different code path depending on
/// whether the op is sparse or dense:
///
/// * Sparse tensor ops are lowered through the Sparsification pass and
///   follow-up passes that lower sparse_tensor dialect ops.
/// * Dense tensor ops are lowered through BufferizableOpInterface
///   implementations.
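///
/// For example (illustrative only), an op whose operands or results have a
/// type such as
///   tensor<100x100xf64, #sparse_tensor.encoding<{
///     map = (d0, d1) -> (d0 : dense, d1 : compressed)}>>
/// takes the sparse code path, whereas one operating on a plain
/// tensor<100x100xf64> is bufferized through the dense code path.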
class SparsificationAndBufferizationPass
    : public impl::SparsificationAndBufferizationBase<
          SparsificationAndBufferizationPass> {
public:
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization, unsigned vectorLength,
      bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization),
        vectorLength(vectorLength),
        enableVLAVectorization(enableVLAVectorization),
        enableSIMDIndex32(enableSIMDIndex32), enableGPULibgen(enableGPULibgen) {
  }

  /// Bufferize all dense ops. This assumes that no further analysis is needed
  /// and that all required buffer copies were already inserted by
  /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
  LogicalResult runDenseBufferization() {
    bufferization::OneShotBufferizationOptions updatedOptions =
        bufferizationOptions;
    // Skip all sparse ops.
    updatedOptions.opFilter.denyOperation([&](Operation *op) {
      if (containsSparseTensor(TypeRange(op->getResults())) ||
          containsSparseTensor(TypeRange(op->getOperands())))
        return true;
      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
        FunctionType funcType = funcOp.getFunctionType();
        if (containsSparseTensor(funcType.getInputs()) ||
            containsSparseTensor(funcType.getResults()))
          return true;
      }
      return false;
    });

    if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
                                                updatedOptions)))
      return failure();

    bufferization::removeBufferizationAttributesInModule(getOperation());
    return success();
  }

  void runOnOperation() override {
    // Run enabling transformations.
    {
      OpPassManager pm("builtin.module");
      pm.addPass(createPreSparsificationRewritePass());
      pm.addNestedPass<func::FuncOp>(
          bufferization::createEmptyTensorToAllocTensorPass());
      if (failed(runPipeline(pm, getOperation())))
        return signalPassFailure();
    }

    // Insert tensor copies. This step runs One-Shot Analysis (which analyzes
    // SSA use-def chains of tensor IR) and decides where buffer copies are
    // needed and where buffers can be written to in-place. These decisions are
    // materialized in the IR in the form of `bufferization.alloc_tensor` ops.
    //
    // Note: All following steps in this pass must be careful not to modify the
    // structure of the IR (i.e., tensor use-def chains), as that could
    // invalidate the results of the analysis. From now on, only small and
    // localized rewrites are allowed, such as replacing a tensor op with its
    // memref equivalent.
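    //
    // As an illustrative sketch (not tied to a particular test case), a
    // read-after-write conflict such as
    //   %0 = tensor.insert %v into %t[%i] : tensor<8xf32>  // %t read later
    // is resolved by materializing an explicit copy that the write goes into:
    //   %c = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>
    //   %0 = tensor.insert %v into %c[%i] : tensor<8xf32>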
    if (failed(bufferization::insertTensorCopies(getOperation(),
                                                 bufferizationOptions)))
      return signalPassFailure();

    // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
    // OneShotAnalysis are added to the IR via attributes. In that case, do not
    // continue with the remaining pipeline.
    if (bufferizationOptions.testAnalysisOnly)
      return;

    // Bufferize all sparse ops. No further analysis is needed. All required
    // buffer copies were already inserted by `insertTensorCopies` in the form
    // of `bufferization.alloc_tensor` ops.
    {
      OpPassManager pm("builtin.module");
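      // When GPU libgen is enabled, first try to rewrite supported sparse
      // kernels into calls to a GPU sparse library (the zero thread count
      // requests library generation rather than direct GPU codegen).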
      if (enableGPULibgen)
        pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
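      // Normalize non-trivial dim-to-level mappings on sparse tensors, then
      // run the core sparsifier, which rewrites linalg ops on sparse tensors
      // into loops that co-iterate over the stored elements.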
      pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
      pm.addPass(createSparsificationPass(sparsificationOptions));
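      // Stage complex sparse tensor operations into simpler ones, and lower
      // them (including sparse conversions) to sparse_tensor.foreach loops.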
      pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
      pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                   /*enableConvert=*/true));
      pm.addPass(
          createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
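      // Lower the remaining foreach loops to scf control flow and hoist
      // loop-invariant code out of the generated loops.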
      pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
      pm.addPass(mlir::createLoopInvariantCodeMotionPass());
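      // Optionally vectorize the generated loops with the vector dialect.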
      if (vectorLength > 0) {
        pm.addPass(createSparseVectorizationPass(
            vectorLength, enableVLAVectorization, enableSIMDIndex32));
      }
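      // Lower the sparse tensor types and remaining sparse ops either to
      // calls into the sparse runtime support library, or to actual
      // buffer-level code when running without the runtime library.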
      if (enableRuntimeLibrary) {
        pm.addPass(createSparseTensorConversionPass());
      } else {
        pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
                                                 enableBufferInitialization));
        pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
      }
      if (failed(runPipeline(pm, getOperation())))
        return signalPassFailure();
    }

    // Bufferize all dense ops.
    if (failed(runDenseBufferization()))
      signalPassFailure();
  }

private:
  bufferization::OneShotBufferizationOptions bufferizationOptions;
  SparsificationOptions sparsificationOptions;
  bool createSparseDeallocs;
  bool enableRuntimeLibrary;
  bool enableBufferInitialization;
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
  bool enableGPULibgen;
};

} // namespace sparse_tensor
} // namespace mlir

mlir::bufferization::OneShotBufferizationOptions
mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
  using namespace mlir::bufferization;
  OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
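  // Use identity layout maps at function boundaries and for ops whose
  // bufferized type is otherwise unknown, so that the resulting memrefs
  // carry no dynamic strides or offsets.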
  options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        cast<TensorType>(value.getType()), memorySpace);
  };
  if (analysisOnly) {
    options.testAnalysisOnly = true;
    options.printConflicts = true;
  }
  // Since this mini-pipeline may be used in alternative pipelines (viz.
  // different from the default "sparsifier" pipeline) where unknown ops
  // are handled by alternative bufferization methods that are downstream
  // of this mini-pipeline, we allow unknown ops by default (failure to
  // bufferize is eventually apparent by failing to convert to LLVM IR).
  options.allowUnknownOps = true;
  return options;
}

std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
  SparsificationOptions sparseOptions;
  return createSparsificationAndBufferizationPass(
      getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
      sparseOptions,
      /*createSparseDeallocs=*/false,
      /*enableRuntimeLibrary=*/false,
      /*enableBufferInitialization=*/false,
      /*vectorLength=*/0,
      /*enableVLAVectorization=*/false,
      /*enableSIMDIndex32=*/false,
      /*enableGPULibgen=*/false);
}

std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
    const bufferization::OneShotBufferizationOptions &bufferizationOptions,
    const SparsificationOptions &sparsificationOptions,
    bool createSparseDeallocs, bool enableRuntimeLibrary,
    bool enableBufferInitialization, unsigned vectorLength,
    bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen) {
  return std::make_unique<
      mlir::sparse_tensor::SparsificationAndBufferizationPass>(
      bufferizationOptions, sparsificationOptions, createSparseDeallocs,
      enableRuntimeLibrary, enableBufferInitialization, vectorLength,
      enableVLAVectorization, enableSIMDIndex32, enableGPULibgen);
}
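
// Usage sketch (illustrative only; assumes `context` and `module` are already
// set up by the caller): the pass created above is typically scheduled on a
// module through a regular pass manager, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createSparsificationAndBufferizationPass());
//   if (failed(pm.run(module)))
//     ...; // handle the error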