1 | //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Utils/CodegenUtils.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
15 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace sparse_tensor; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Helper methods. |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | // Convert type range to new types range, with sparse tensors externalized. |
27 | static void convTypes(bool &hasAnnotation, TypeRange types, |
28 | SmallVectorImpl<Type> &convTypes, |
29 | SmallVectorImpl<Type> *, bool directOut) { |
30 | for (auto type : types) { |
31 | // All "dense" data passes through unmodified. |
32 | if (!getSparseTensorEncoding(type)) { |
33 | convTypes.push_back(Elt: type); |
34 | continue; |
35 | } |
36 | hasAnnotation = true; |
37 | |
38 | // Convert the external representations of the pos/crd/val arrays. |
39 | const SparseTensorType stt(cast<RankedTensorType>(type)); |
40 | foreachFieldAndTypeInSparseTensor( |
41 | stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex, |
42 | SparseTensorFieldKind kind, |
43 | Level, LevelType) { |
44 | if (kind == SparseTensorFieldKind::PosMemRef || |
45 | kind == SparseTensorFieldKind::CrdMemRef || |
46 | kind == SparseTensorFieldKind::ValMemRef) { |
47 | auto rtp = cast<ShapedType>(t); |
48 | if (!directOut) { |
49 | rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
50 | if (extraTypes) |
51 | extraTypes->push_back(Elt: rtp); |
52 | } |
53 | convTypes.push_back(Elt: rtp); |
54 | } |
55 | return true; |
56 | }); |
57 | } |
58 | } |
59 | |
60 | // Convert input and output values to [dis]assemble ops for sparse tensors. |
61 | static void convVals(OpBuilder &builder, Location loc, TypeRange types, |
62 | ValueRange fromVals, ValueRange , |
63 | SmallVectorImpl<Value> &toVals, unsigned , bool isIn, |
64 | bool directOut) { |
65 | unsigned idx = 0; |
66 | for (auto type : types) { |
67 | // All "dense" data passes through unmodified. |
68 | if (!getSparseTensorEncoding(type)) { |
69 | toVals.push_back(Elt: fromVals[idx++]); |
70 | continue; |
71 | } |
72 | // Handle sparse data. |
73 | auto rtp = cast<RankedTensorType>(type); |
74 | const SparseTensorType stt(rtp); |
75 | SmallVector<Value> inputs; |
76 | SmallVector<Type> retTypes; |
77 | SmallVector<Type> cntTypes; |
78 | if (!isIn) |
79 | inputs.push_back(Elt: fromVals[idx++]); // The sparse tensor to disassemble |
80 | |
81 | // Collect the external representations of the pos/crd/val arrays. |
82 | foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex, |
83 | SparseTensorFieldKind kind, |
84 | Level lv, LevelType) { |
85 | if (kind == SparseTensorFieldKind::PosMemRef || |
86 | kind == SparseTensorFieldKind::CrdMemRef || |
87 | kind == SparseTensorFieldKind::ValMemRef) { |
88 | if (isIn) { |
89 | inputs.push_back(Elt: fromVals[idx++]); |
90 | } else if (directOut) { |
91 | Value mem; |
92 | if (kind == SparseTensorFieldKind::PosMemRef) |
93 | mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0], |
94 | lv); |
95 | else if (kind == SparseTensorFieldKind::CrdMemRef) |
96 | mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0], |
97 | lv); |
98 | else |
99 | mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]); |
100 | toVals.push_back(Elt: mem); |
101 | } else { |
102 | ShapedType rtp = cast<ShapedType>(t); |
103 | rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
104 | inputs.push_back(Elt: extraVals[extra++]); |
105 | retTypes.push_back(Elt: rtp); |
106 | cntTypes.push_back(builder.getIndexType()); |
107 | } |
108 | } |
109 | return true; |
110 | }); |
111 | |
112 | if (isIn) { |
113 | // Assemble multiple inputs into a single sparse tensor. |
114 | auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs); |
115 | toVals.push_back(Elt: a.getResult()); |
116 | } else if (!directOut) { |
117 | // Disassemble a single sparse input into multiple outputs. |
118 | // Note that this includes the counters, which are dropped. |
119 | unsigned len = retTypes.size(); |
120 | retTypes.append(RHS: cntTypes); |
121 | auto d = |
122 | builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs); |
123 | for (unsigned i = 0; i < len; i++) |
124 | toVals.push_back(Elt: d.getResult(i)); |
125 | } |
126 | } |
127 | } |
128 | |
129 | //===----------------------------------------------------------------------===// |
130 | // Rewriting rules. |
131 | //===----------------------------------------------------------------------===// |
132 | |
133 | namespace { |
134 | |
135 | // A rewriting rules that converts public entry methods that use sparse tensors |
136 | // as input parameters and/or output return values into wrapper methods that |
137 | // [dis]assemble the individual tensors that constitute the actual storage used |
138 | // externally into MLIR sparse tensors before calling the original method. |
139 | // |
140 | // In particular, each sparse tensor input |
141 | // |
142 | // void foo(..., t, ...) { } |
143 | // |
144 | // makes the original foo() internal and adds the following wrapper method |
145 | // |
146 | // void foo(..., t1..tn, ...) { |
147 | // t = assemble t1..tn |
148 | // _internal_foo(..., t, ...) |
149 | // } |
150 | // |
151 | // and likewise, each output tensor |
152 | // |
153 | // ... T ... bar(...) { return ..., t, ...; } |
154 | // |
155 | // makes the original bar() internal and adds the following wrapper method |
156 | // |
157 | // ... T1..TN ... bar(..., t1'..tn') { |
158 | // ..., t, ... = _internal_bar(...) |
159 | // t1..tn = disassemble t, t1'..tn' |
160 | // return ..., t1..tn, ... |
161 | // } |
162 | // |
163 | // (with a direct-out variant without the disassemble). |
164 | // |
165 | struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> { |
166 | using OpRewritePattern::OpRewritePattern; |
167 | |
168 | SparseFuncAssembler(MLIRContext *context, bool dO) |
169 | : OpRewritePattern(context), directOut(dO) {} |
170 | |
171 | LogicalResult matchAndRewrite(func::FuncOp funcOp, |
172 | PatternRewriter &rewriter) const override { |
173 | // Only rewrite public entry methods. |
174 | if (funcOp.isPrivate()) |
175 | return failure(); |
176 | |
177 | // Translate sparse tensor types to external types. |
178 | SmallVector<Type> inputTypes; |
179 | SmallVector<Type> outputTypes; |
180 | SmallVector<Type> ; |
181 | bool hasAnnotation = false; |
182 | convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr, |
183 | false); |
184 | convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes, |
185 | directOut); |
186 | |
187 | // Only sparse inputs or outputs need a wrapper method. |
188 | if (!hasAnnotation) |
189 | return failure(); |
190 | |
191 | // Modify the original method into an internal, private method. |
192 | auto orgName = funcOp.getName(); |
193 | std::string wrapper = llvm::formatv("_internal_{0}" , orgName).str(); |
194 | funcOp.setName(wrapper); |
195 | funcOp.setPrivate(); |
196 | |
197 | // Start the new public wrapper method with original name. |
198 | Location loc = funcOp.getLoc(); |
199 | ModuleOp modOp = funcOp->getParentOfType<ModuleOp>(); |
200 | MLIRContext *context = modOp.getContext(); |
201 | OpBuilder moduleBuilder(modOp.getBodyRegion()); |
202 | unsigned = inputTypes.size(); |
203 | inputTypes.append(RHS: extraTypes); |
204 | auto func = moduleBuilder.create<func::FuncOp>( |
205 | loc, orgName, FunctionType::get(context, inputTypes, outputTypes)); |
206 | func.setPublic(); |
207 | |
208 | // Construct new wrapper method body. |
209 | OpBuilder::InsertionGuard insertionGuard(rewriter); |
210 | Block *body = func.addEntryBlock(); |
211 | rewriter.setInsertionPointToStart(body); |
212 | |
213 | // Convert inputs. |
214 | SmallVector<Value> inputs; |
215 | convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(), |
216 | ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut); |
217 | |
218 | // Call the original, now private method. A subsequent inlining pass can |
219 | // determine whether cloning the method body in place is worthwhile. |
220 | auto org = SymbolRefAttr::get(context, wrapper); |
221 | auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org, |
222 | inputs); |
223 | |
224 | // Convert outputs and return. |
225 | SmallVector<Value> outputs; |
226 | convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(), |
227 | body->getArguments(), outputs, extra, /*isIn=*/false, directOut); |
228 | rewriter.create<func::ReturnOp>(loc, outputs); |
229 | |
230 | // Finally, migrate a potential c-interface property. |
231 | if (funcOp->getAttrOfType<UnitAttr>( |
232 | LLVM::LLVMDialect::getEmitCWrapperAttrName())) { |
233 | func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), |
234 | UnitAttr::get(context)); |
235 | funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName()); |
236 | } |
237 | return success(); |
238 | } |
239 | |
240 | private: |
241 | const bool directOut; |
242 | }; |
243 | |
244 | } // namespace |
245 | |
246 | //===----------------------------------------------------------------------===// |
247 | // Public method for populating conversion rules. |
248 | //===----------------------------------------------------------------------===// |
249 | |
250 | void mlir::populateSparseAssembler(RewritePatternSet &patterns, |
251 | bool directOut) { |
252 | patterns.add<SparseFuncAssembler>(arg: patterns.getContext(), args&: directOut); |
253 | } |
254 | |