//===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

namespace mlir {

#define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

namespace sparse_tensor {

/// Return `true` if one of the given types is a sparse tensor type.
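/// For example, `tensor<8x8xf64, #CSR>` (where `#CSR` refers to a
/// `#sparse_tensor.encoding` attribute) is sparse, whereas a plain
/// `tensor<8x8xf64>` is dense.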
static bool containsSparseTensor(TypeRange types) {
  for (Type t : types)
    if (isa<TensorType>(t) && getSparseTensorEncoding(t))
      return true;
  return false;
}

/// A pass that lowers tensor ops to memref ops, regardless of whether they are
/// dense or sparse.
///
/// One-Shot Analysis is used to detect RaW conflicts and to insert buffer
/// copies at the tensor level (`insertTensorCopies`). Afterwards, the lowering
/// of tensor ops to memref ops follows a different code path depending on
/// whether the op is sparse or dense:
///
/// * Sparse tensor ops are lowered through Sparsification and follow-up passes
///   that lower sparse_tensor dialect ops.
/// * Dense tensor ops are lowered through BufferizableOpInterface
///   implementations.
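///
/// As an illustrative sketch (simplified, not tied to this exact
/// implementation), a kernel such as
///
///   %0 = linalg.generic ... ins(%a : tensor<?xf64, #SparseVector>)
///                           outs(%x : tensor<?xf64>) { ... }
///
/// (where `#SparseVector` names a sparse encoding) has its sparse operand
/// traversed by Sparsification-generated loops over the underlying
/// positions/coordinates storage, while the dense output is bufferized to a
/// memref through the BufferizableOpInterface-based code path.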
class SparsificationAndBufferizationPass
    : public impl::SparsificationAndBufferizationBase<
          SparsificationAndBufferizationPass> {
public:
  // Private pass options only.
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization) {}
  // Private pass options and visible pass options.
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
      bool gpu, SparseEmitStrategy emitStrategy,
      SparseParallelizationStrategy parallelizationStrategy)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization) {
    // Set the visible pass options explicitly.
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = index32;
    enableGPULibgen = gpu;
    sparseEmitStrategy = emitStrategy;
    parallelization = parallelizationStrategy;
  }

  /// Bufferize all dense ops. This assumes that no further analysis is needed
  /// and that all required buffer copies were already inserted by
  /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
  LogicalResult runDenseBufferization() {
    bufferization::OneShotBufferizationOptions updatedOptions =
        bufferizationOptions;
    // Skip all sparse ops.
    updatedOptions.opFilter.denyOperation([&](Operation *op) {
      if (containsSparseTensor(TypeRange(op->getResults())) ||
          containsSparseTensor(TypeRange(op->getOperands())))
        return true;
      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
        FunctionType funcType = funcOp.getFunctionType();
        if (containsSparseTensor(funcType.getInputs()) ||
            containsSparseTensor(funcType.getResults()))
          return true;
      }
      return false;
    });

    bufferization::BufferizationState bufferizationState;

    if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
                                                updatedOptions,
                                                bufferizationState)))
      return failure();

    bufferization::removeBufferizationAttributesInModule(getOperation());
    return success();
  }

  void runOnOperation() override {
    // Override the default emit strategy with the user-provided value.
    this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy;

    // Override the default parallelization strategy with the user-provided
    // value.
    this->sparsificationOptions.parallelizationStrategy = parallelization;

    // Run enabling transformations.
    {
      OpPassManager pm("builtin.module");
      pm.addPass(createPreSparsificationRewritePass());
      pm.addNestedPass<func::FuncOp>(
          bufferization::createEmptyTensorToAllocTensorPass());
      if (failed(runPipeline(pm, getOperation())))
        return signalPassFailure();
    }

    // Insert tensor copies. This step runs One-Shot Analysis (which analyzes
    // SSA use-def chains of tensor IR) and decides where buffer copies are
    // needed and where buffers can be written to in-place. These decisions
    // are materialized in the IR in the form of `bufferization.alloc_tensor`
    // ops.
    //
    // Note: All following steps in this pass must be careful not to modify
    // the structure of the IR (i.e., tensor use-def chains), as that could
    // invalidate the results of the analysis. From now on, only small and
    // localized rewrites are allowed, such as replacing a tensor op with its
    // memref equivalent.
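    //
    // As a sketch, a copy decision materializes in the IR roughly as
    //   %copy = bufferization.alloc_tensor() copy(%t) : tensor<?xf64>
    // which later bufferizes to a new allocation initialized with the
    // contents of %t.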
    bufferization::BufferizationState bufferizationState;

    if (failed(bufferization::insertTensorCopies(
            getOperation(), bufferizationOptions, bufferizationState)))
      return signalPassFailure();

    // Option `testAnalysisOnly` is a debug/testing flag. If set, the results
    // of OneShotAnalysis are added to the IR via attributes. In that case, do
    // not continue with the remaining pipeline.
    if (bufferizationOptions.testAnalysisOnly)
      return;

    // Bufferize all sparse ops. No further analysis is needed. All required
    // buffer copies were already inserted by `insertTensorCopies` in the form
    // of `bufferization.alloc_tensor` ops.
    {
      OpPassManager pm("builtin.module");
      if (enableGPULibgen)
        pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
      pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
      pm.addPass(createSparsificationPass(sparsificationOptions));
      if (sparsificationOptions.sparseEmitStrategy ==
          SparseEmitStrategy::kSparseIterator) {
        pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
        pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
      }

      pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
      pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                   /*enableConvert=*/true));
      pm.addPass(
          createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
      pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
      pm.addPass(mlir::createLoopInvariantCodeMotionPass());
      if (vectorLength > 0) {
        pm.addPass(createSparseVectorizationPass(
            vectorLength, enableVLAVectorization, enableSIMDIndex32));
      }
      if (enableRuntimeLibrary) {
        pm.addPass(createSparseTensorConversionPass());
      } else {
        pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
                                                 enableBufferInitialization));
        pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
      }
      if (failed(runPipeline(pm, getOperation())))
        return signalPassFailure();
    }

    // Bufferize all dense ops.
    if (failed(runDenseBufferization()))
      signalPassFailure();
  }

private:
  bufferization::OneShotBufferizationOptions bufferizationOptions;
  SparsificationOptions sparsificationOptions;
  bool createSparseDeallocs;
  bool enableRuntimeLibrary;
  bool enableBufferInitialization;
};

} // namespace sparse_tensor
} // namespace mlir

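// Example (illustrative) use of the factory below: requesting analysis-only
// mode, which makes the pass annotate the IR with One-Shot Analysis results
// (and print detected conflicts) instead of actually bufferizing:
//   auto options =
//       getBufferizationOptionsForSparsification(/*analysisOnly=*/true);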
mlir::bufferization::OneShotBufferizationOptions
mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
  using namespace mlir::bufferization;
  OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        cast<TensorType>(value.getType()), memorySpace);
  };
  if (analysisOnly) {
    options.testAnalysisOnly = true;
    options.printConflicts = true;
  }
  // Since this mini-pipeline may be used in alternative pipelines (viz.
  // different from the default "sparsifier" pipeline) where unknown ops
  // are handled by alternative bufferization methods that are downstream
  // of this mini-pipeline, we allow unknown ops by default (failure to
  // bufferize is eventually apparent by failing to convert to LLVM IR).
  options.allowUnknownOps = true;
  return options;
}

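// Example (illustrative): adding the default-configured pass to a pass
// manager.
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createSparsificationAndBufferizationPass());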
std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
  SparsificationOptions sparseOptions;
  return std::make_unique<
      mlir::sparse_tensor::SparsificationAndBufferizationPass>(
      getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
      sparseOptions,
      /*createSparseDeallocs=*/false,
      /*enableRuntimeLibrary=*/false,
      /*enableBufferInitialization=*/false);
}

std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
    const bufferization::OneShotBufferizationOptions &bufferizationOptions,
    const SparsificationOptions &sparsificationOptions,
    bool createSparseDeallocs, bool enableRuntimeLibrary,
    bool enableBufferInitialization, unsigned vectorLength,
    bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
    SparseEmitStrategy emitStrategy,
    SparseParallelizationStrategy parallelizationStrategy) {
  return std::make_unique<
      mlir::sparse_tensor::SparsificationAndBufferizationPass>(
      bufferizationOptions, sparsificationOptions, createSparseDeallocs,
      enableRuntimeLibrary, enableBufferInitialization, vectorLength,
      enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy,
      parallelizationStrategy);
}