1//===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
16#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
17#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
18#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/GPU/IR/GPUDialect.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/Linalg/IR/Linalg.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
25#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
26#include "mlir/Dialect/Vector/IR/VectorOps.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/Passes.h"
29
30using namespace mlir;
31
32namespace mlir {
33
34#define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
35#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
36
37namespace sparse_tensor {
38
39/// Return `true` if one of the given types is a sparse tensor type.
40static bool containsSparseTensor(TypeRange types) {
41 for (Type t : types)
42 if (isa<TensorType>(Val: t) && getSparseTensorEncoding(type: t))
43 return true;
44 return false;
45}
46
/// A pass that lowers tensor ops to memref ops, regardless of whether they are
/// dense or sparse.
///
/// One-Shot Analysis is used to detect RaW conflicts and to insert buffer
/// copies of the tensor level (`insertTensorCopies`). Afterwards, the lowering
/// of tensor ops to memref ops follows a different code path depending on
/// whether the op is sparse or dense:
///
/// * Sparse tensor ops are lowered through Sparsification and follow-up pass
///   that lowers sparse_tensor dialect ops.
/// * Dense tensor ops are lowered through BufferizableOpInterface
///   implementations.
class SparsificationAndBufferizationPass
    : public impl::SparsificationAndBufferizationBase<
          SparsificationAndBufferizationPass> {
public:
  // Private pass options only. The visible (tablegen-declared) pass options
  // keep their default values.
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization) {}

  // Private pass options and visible pass options. The visible options are
  // assigned in the constructor body below rather than the init list because
  // they are members of the tablegen-generated base class.
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
      bool gpu, SparseEmitStrategy emitStrategy,
      SparseParallelizationStrategy parallelizationStrategy)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization) {
    // Set the visible pass options explicitly.
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = index32;
    enableGPULibgen = gpu;
    sparseEmitStrategy = emitStrategy;
    parallelization = parallelizationStrategy;
  }

  /// Bufferize all dense ops. This assumes that no further analysis is needed
  /// and that all required buffer copies were already inserted by
  /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
  LogicalResult runDenseBufferization() {
    bufferization::OneShotBufferizationOptions updatedOptions =
        bufferizationOptions;
    // Skip all sparse ops: any op that produces or consumes a sparse tensor,
    // and any function whose signature mentions one, is handled by the sparse
    // pipeline in `runOnOperation` instead.
    updatedOptions.opFilter.denyOperation(fn: [&](Operation *op) {
      if (containsSparseTensor(types: TypeRange(op->getResults())) ||
          containsSparseTensor(types: TypeRange(op->getOperands())))
        return true;
      if (auto funcOp = dyn_cast<func::FuncOp>(Val: op)) {
        FunctionType funcType = funcOp.getFunctionType();
        if (containsSparseTensor(types: funcType.getInputs()) ||
            containsSparseTensor(types: funcType.getResults()))
          return true;
      }
      return false;
    });

    bufferization::BufferizationState bufferizationState;

    if (failed(Result: bufferization::bufferizeModuleOp(moduleOp: cast<ModuleOp>(Val: getOperation()),
                                             options: updatedOptions,
                                             state&: bufferizationState)))
      return failure();

    // Drop the bufferization-internal attributes once bufferization has been
    // fully materialized; they carry no meaning afterwards.
    bufferization::removeBufferizationAttributesInModule(moduleOp: getOperation());
    return success();
  }

  void runOnOperation() override {
    // Overrides the default emit strategy using user-provided value.
    this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy;

    // Overrides the default parallelization strategy using user-provided value.
    this->sparsificationOptions.parallelizationStrategy = parallelization;

    // Run enabling transformations.
    {
      OpPassManager pm("builtin.module");
      pm.addPass(pass: createPreSparsificationRewritePass());
      pm.addNestedPass<func::FuncOp>(
          pass: bufferization::createEmptyTensorToAllocTensorPass());
      if (failed(Result: runPipeline(pipeline&: pm, op: getOperation())))
        return signalPassFailure();
    }

    // Insert tensor copies. This step runs One-Shot Analysis (which analyzes
    // SSA use-def chains of tensor IR) and decides where buffer copies are
    // needed and where buffers can be written to in-place. These decisions are
    // materialized in the IR in the form of `bufferization.alloc_tensor` ops.
    //
    // Note: All following steps in this pass must be careful not to modify the
    // structure of the IR (i.e., tensor use-def chains), as that could
    // invalidate the results of the analysis. From now on, only small and
    // localized rewrites are allowed, such as replacing a tensor op with its
    // memref equivalent.
    bufferization::BufferizationState bufferizationState;

    if (failed(Result: bufferization::insertTensorCopies(
            op: getOperation(), options: bufferizationOptions, bufferizationState)))
      return signalPassFailure();

    // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
    // OneShotAnalysis are added to the IR via attributes. In that case, do not
    // continue with the remaining pipeline.
    if (bufferizationOptions.testAnalysisOnly)
      return;

    // Bufferize all sparse ops. No further analysis is needed. All required
    // buffer copies were already inserted by `insertTensorCopies` in the form
    // of `bufferization.alloc_tensor` ops.
    {
      OpPassManager pm("builtin.module");
      // Optionally offload sparse kernels to the GPU before sparsification.
      if (enableGPULibgen)
        pm.addPass(pass: createSparseGPUCodegenPass(numThreads: 0, enableRT: enableRuntimeLibrary));
      pm.addPass(pass: createSparseReinterpretMapPass(scope: ReinterpretMapScope::kAll));
      pm.addPass(pass: createSparsificationPass(options: sparsificationOptions));
      // When emitting sparse iterators, collapse iteration spaces and lower
      // the iteration ops to SCF before the remaining lowering steps.
      if (sparsificationOptions.sparseEmitStrategy ==
          SparseEmitStrategy::kSparseIterator) {
        pm.addNestedPass<func::FuncOp>(pass: createSparseSpaceCollapsePass());
        pm.addNestedPass<func::FuncOp>(pass: createLowerSparseIterationToSCFPass());
      }

      pm.addNestedPass<func::FuncOp>(pass: createStageSparseOperationsPass());
      pm.addPass(pass: createLowerSparseOpsToForeachPass(enableRT: enableRuntimeLibrary,
                                                /*enableConvert=*/true));
      pm.addPass(
          pass: createSparseReinterpretMapPass(scope: ReinterpretMapScope::kExceptGeneric));
      pm.addNestedPass<func::FuncOp>(pass: createLowerForeachToSCFPass());
      pm.addPass(pass: mlir::createLoopInvariantCodeMotionPass());
      // Vectorize sparse loops when a vector length was requested.
      if (vectorLength > 0) {
        pm.addPass(pass: createSparseVectorizationPass(
            vectorLength, enableVLAVectorization, enableSIMDIndex32));
      }
      // Lower the remaining sparse_tensor ops either to runtime library calls
      // or to actual code (direct codegen).
      if (enableRuntimeLibrary) {
        pm.addPass(pass: createSparseTensorConversionPass());
      } else {
        pm.addPass(pass: createSparseTensorCodegenPass(createSparseDeallocs,
                                              enableBufferInitialization));
        pm.addPass(pass: createSparseBufferRewritePass(enableBufferInitialization));
      }
      if (failed(Result: runPipeline(pipeline&: pm, op: getOperation())))
        return signalPassFailure();
    }

    // Bufferize all dense ops.
    if (failed(Result: runDenseBufferization()))
      signalPassFailure();
  }

private:
  // Options used for both One-Shot Analysis and the dense bufferization step.
  bufferization::OneShotBufferizationOptions bufferizationOptions;
  // Options forwarded to the Sparsification pass (emit strategy and
  // parallelization strategy are overridden in `runOnOperation`).
  SparsificationOptions sparsificationOptions;
  // Forwarded to sparse tensor codegen (codegen path only).
  bool createSparseDeallocs;
  // Selects the runtime-library conversion path over direct codegen; also
  // forwarded to GPU codegen and the foreach lowering.
  bool enableRuntimeLibrary;
  // Forwarded to sparse tensor codegen and sparse buffer rewriting.
  bool enableBufferInitialization;
};
215
216} // namespace sparse_tensor
217} // namespace mlir
218
219mlir::bufferization::OneShotBufferizationOptions
220mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
221 using namespace mlir::bufferization;
222 OneShotBufferizationOptions options;
223 options.bufferizeFunctionBoundaries = true;
224 options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
225 options.unknownTypeConverterFn = [](TensorType tensorType,
226 Attribute memorySpace,
227 const BufferizationOptions &options) {
228 return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
229 };
230 if (analysisOnly) {
231 options.testAnalysisOnly = true;
232 options.printConflicts = true;
233 }
234 // Since this mini-pipeline may be used in alternative pipelines (viz.
235 // different from the default "sparsifier" pipeline) where unknown ops
236 // are handled by alternative bufferization methods that are downstream
237 // of this mini-pipeline, we allow unknown ops by default (failure to
238 // bufferize is eventually apparent by failing to convert to LLVM IR).
239 options.allowUnknownOps = true;
240 return options;
241}
242
243std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
244 SparsificationOptions sparseOptions;
245 return std::make_unique<
246 mlir::sparse_tensor::SparsificationAndBufferizationPass>(
247 args: getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
248 args&: sparseOptions,
249 /*createSparseDeallocs=*/args: false,
250 /*enableRuntimeLibrary=*/args: false,
251 /*enableBufferInitialization=*/args: false);
252}
253
254std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
255 const bufferization::OneShotBufferizationOptions &bufferizationOptions,
256 const SparsificationOptions &sparsificationOptions,
257 bool createSparseDeallocs, bool enableRuntimeLibrary,
258 bool enableBufferInitialization, unsigned vectorLength,
259 bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
260 SparseEmitStrategy emitStrategy,
261 SparseParallelizationStrategy parallelizationStrategy) {
262 return std::make_unique<
263 mlir::sparse_tensor::SparsificationAndBufferizationPass>(
264 args: bufferizationOptions, args: sparsificationOptions, args&: createSparseDeallocs,
265 args&: enableRuntimeLibrary, args&: enableBufferInitialization, args&: vectorLength,
266 args&: enableVLAVectorization, args&: enableSIMDIndex32, args&: enableGPULibgen, args&: emitStrategy,
267 args&: parallelizationStrategy);
268}
269

// Source: mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp