1 | //===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
13 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
15 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
16 | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" |
17 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
18 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
21 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
22 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
26 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
27 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
28 | #include "mlir/Pass/PassManager.h" |
29 | #include "mlir/Transforms/Passes.h" |
30 | |
31 | using namespace mlir; |
32 | |
33 | namespace mlir { |
34 | |
35 | #define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION |
36 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" |
37 | |
38 | namespace sparse_tensor { |
39 | |
40 | /// Return `true` if one of the given types is a sparse tensor type. |
41 | static bool containsSparseTensor(TypeRange types) { |
42 | for (Type t : types) |
43 | if (isa<TensorType>(t) && getSparseTensorEncoding(t)) |
44 | return true; |
45 | return false; |
46 | } |
47 | |
48 | /// A pass that lowers tensor ops to memref ops, regardless of whether they are |
49 | /// dense or sparse. |
50 | /// |
51 | /// One-Shot Analysis is used to detect RaW conflicts and to insert buffer |
52 | /// copies of the tensor level (`insertTensorCopies`). Afterwards, the lowering |
53 | /// of tensor ops to memref ops follows a different code path depending on |
54 | /// whether the op is sparse or dense: |
55 | /// |
56 | /// * Sparse tensor ops are lowered through Sparsification and follow-up pass |
57 | /// that lowers sparse_tensor dialect ops. |
58 | /// * Dense tensor ops are lowered through BufferizableOpInterface |
59 | /// implementations. |
60 | class SparsificationAndBufferizationPass |
61 | : public impl::SparsificationAndBufferizationBase< |
62 | SparsificationAndBufferizationPass> { |
63 | public: |
64 | SparsificationAndBufferizationPass( |
65 | const bufferization::OneShotBufferizationOptions &bufferizationOptions, |
66 | const SparsificationOptions &sparsificationOptions, |
67 | bool createSparseDeallocs, bool enableRuntimeLibrary, |
68 | bool enableBufferInitialization, unsigned vectorLength, |
69 | bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen) |
70 | : bufferizationOptions(bufferizationOptions), |
71 | sparsificationOptions(sparsificationOptions), |
72 | createSparseDeallocs(createSparseDeallocs), |
73 | enableRuntimeLibrary(enableRuntimeLibrary), |
74 | enableBufferInitialization(enableBufferInitialization), |
75 | vectorLength(vectorLength), |
76 | enableVLAVectorization(enableVLAVectorization), |
77 | enableSIMDIndex32(enableSIMDIndex32), enableGPULibgen(enableGPULibgen) { |
78 | } |
79 | |
80 | /// Bufferize all dense ops. This assumes that no further analysis is needed |
81 | /// and that all required buffer copies were already inserted by |
82 | /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops. |
83 | LogicalResult runDenseBufferization() { |
84 | bufferization::OneShotBufferizationOptions updatedOptions = |
85 | bufferizationOptions; |
86 | // Skip all sparse ops. |
87 | updatedOptions.opFilter.denyOperation([&](Operation *op) { |
88 | if (containsSparseTensor(types: TypeRange(op->getResults())) || |
89 | containsSparseTensor(types: TypeRange(op->getOperands()))) |
90 | return true; |
91 | if (auto funcOp = dyn_cast<func::FuncOp>(op)) { |
92 | FunctionType funcType = funcOp.getFunctionType(); |
93 | if (containsSparseTensor(funcType.getInputs()) || |
94 | containsSparseTensor(funcType.getResults())) |
95 | return true; |
96 | } |
97 | return false; |
98 | }); |
99 | |
100 | if (failed(bufferization::bufferizeModuleOp(moduleOp: cast<ModuleOp>(getOperation()), |
101 | options: updatedOptions))) |
102 | return failure(); |
103 | |
104 | bufferization::removeBufferizationAttributesInModule(moduleOp: getOperation()); |
105 | return success(); |
106 | } |
107 | |
108 | void runOnOperation() override { |
109 | // Run enabling transformations. |
110 | { |
111 | OpPassManager pm("builtin.module" ); |
112 | pm.addPass(pass: createPreSparsificationRewritePass()); |
113 | pm.addNestedPass<func::FuncOp>( |
114 | bufferization::createEmptyTensorToAllocTensorPass()); |
115 | if (failed(runPipeline(pm, getOperation()))) |
116 | return signalPassFailure(); |
117 | } |
118 | |
119 | // Insert tensor copies. This step runs One-Shot Analysis (which analyzes |
120 | // SSA use-def chains of tensor IR) and decides where buffer copies are |
121 | // needed and where buffers can be written to in-place. These decisions are |
122 | // materialized in the IR in the form of `bufferization.alloc_tensor` ops. |
123 | // |
124 | // Note: All following steps in this pass must be careful not to modify the |
125 | // structure of the IR (i.e., tensor use-def chains), as that could |
126 | // invalidate the results of the analysis. From now on, only small and |
127 | // localized rewrites are allowed, such as replacing a tensor op with its |
128 | // memref equivalent. |
129 | if (failed(bufferization::insertTensorCopies(getOperation(), |
130 | bufferizationOptions))) |
131 | return signalPassFailure(); |
132 | |
133 | // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of |
134 | // OneShotAnalysis are added to the IR via attributes. In that case, do not |
135 | // continue with the remaining pipeline. |
136 | if (bufferizationOptions.testAnalysisOnly) |
137 | return; |
138 | |
139 | // Bufferize all sparse ops. No further analysis is needed. All required |
140 | // buffer copies were already inserted by `insertTensorCopies` in the form |
141 | // of `bufferization.alloc_tensor` ops. |
142 | { |
143 | OpPassManager pm("builtin.module" ); |
144 | if (enableGPULibgen) |
145 | pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary)); |
146 | pm.addPass(pass: createSparseReinterpretMapPass(scope: ReinterpretMapScope::kAll)); |
147 | pm.addPass(createSparsificationPass(sparsificationOptions)); |
148 | pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass()); |
149 | pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary, |
150 | /*enableConvert=*/true)); |
151 | pm.addPass( |
152 | pass: createSparseReinterpretMapPass(scope: ReinterpretMapScope::kExceptGeneric)); |
153 | pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass()); |
154 | pm.addPass(mlir::pass: createLoopInvariantCodeMotionPass()); |
155 | if (vectorLength > 0) { |
156 | pm.addPass(createSparseVectorizationPass( |
157 | vectorLength, enableVLAVectorization, enableSIMDIndex32)); |
158 | } |
159 | if (enableRuntimeLibrary) { |
160 | pm.addPass(createSparseTensorConversionPass()); |
161 | } else { |
162 | pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs, |
163 | enableBufferInitialization)); |
164 | pm.addPass(createSparseBufferRewritePass(enableBufferInitialization)); |
165 | } |
166 | if (failed(runPipeline(pm, getOperation()))) |
167 | return signalPassFailure(); |
168 | } |
169 | |
170 | // Bufferize all dense ops. |
171 | if (failed(result: runDenseBufferization())) |
172 | signalPassFailure(); |
173 | } |
174 | |
175 | private: |
176 | bufferization::OneShotBufferizationOptions bufferizationOptions; |
177 | SparsificationOptions sparsificationOptions; |
178 | bool createSparseDeallocs; |
179 | bool enableRuntimeLibrary; |
180 | bool enableBufferInitialization; |
181 | unsigned vectorLength; |
182 | bool enableVLAVectorization; |
183 | bool enableSIMDIndex32; |
184 | bool enableGPULibgen; |
185 | }; |
186 | |
187 | } // namespace sparse_tensor |
188 | } // namespace mlir |
189 | |
190 | mlir::bufferization::OneShotBufferizationOptions |
191 | mlir::getBufferizationOptionsForSparsification(bool analysisOnly) { |
192 | using namespace mlir::bufferization; |
193 | OneShotBufferizationOptions options; |
194 | options.bufferizeFunctionBoundaries = true; |
195 | options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap); |
196 | options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, |
197 | const BufferizationOptions &options) { |
198 | return getMemRefTypeWithStaticIdentityLayout( |
199 | tensorType: cast<TensorType>(Val: value.getType()), memorySpace); |
200 | }; |
201 | if (analysisOnly) { |
202 | options.testAnalysisOnly = true; |
203 | options.printConflicts = true; |
204 | } |
205 | // Since this mini-pipeline may be used in alternative pipelines (viz. |
206 | // different from the default "sparsifier" pipeline) where unknown ops |
207 | // are handled by alternative bufferization methods that are downstream |
208 | // of this mini-pipeline, we allow unknown ops by default (failure to |
209 | // bufferize is eventually apparent by failing to convert to LLVM IR). |
210 | options.allowUnknownOps = true; |
211 | return options; |
212 | } |
213 | |
214 | std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() { |
215 | SparsificationOptions sparseOptions; |
216 | return createSparsificationAndBufferizationPass( |
217 | getBufferizationOptionsForSparsification(/*analysisOnly=*/false), |
218 | sparseOptions, |
219 | /*createSparseDeallocs=*/false, |
220 | /*enableRuntimeLibrary=*/false, |
221 | /*enableBufferInitialization=*/false, |
222 | /*vectorLength=*/0, |
223 | /*enableVLAVectorization=*/false, |
224 | /*enableSIMDIndex32=*/false, |
225 | /*enableGPULibgen=*/false); |
226 | } |
227 | |
228 | std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass( |
229 | const bufferization::OneShotBufferizationOptions &bufferizationOptions, |
230 | const SparsificationOptions &sparsificationOptions, |
231 | bool createSparseDeallocs, bool enableRuntimeLibrary, |
232 | bool enableBufferInitialization, unsigned vectorLength, |
233 | bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen) { |
234 | return std::make_unique< |
235 | mlir::sparse_tensor::SparsificationAndBufferizationPass>( |
236 | args: bufferizationOptions, args: sparsificationOptions, args&: createSparseDeallocs, |
237 | args&: enableRuntimeLibrary, args&: enableBufferInitialization, args&: vectorLength, |
238 | args&: enableVLAVectorization, args&: enableSIMDIndex32, args&: enableGPULibgen); |
239 | } |
240 | |