//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "Utils/CodegenUtils.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
15 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace sparse_tensor; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Helper methods. |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | // Convert type range to new types range, with sparse tensors externalized. |
27 | static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes, |
28 | SmallVectorImpl<Type> *, bool directOut) { |
29 | for (auto type : types) { |
30 | // All "dense" data passes through unmodified. |
31 | if (!getSparseTensorEncoding(type)) { |
32 | convTypes.push_back(Elt: type); |
33 | continue; |
34 | } |
35 | |
36 | // Convert the external representations of the pos/crd/val arrays. |
37 | const SparseTensorType stt(cast<RankedTensorType>(type)); |
38 | foreachFieldAndTypeInSparseTensor( |
39 | stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex, |
40 | SparseTensorFieldKind kind, |
41 | Level, LevelType) { |
42 | if (kind == SparseTensorFieldKind::PosMemRef || |
43 | kind == SparseTensorFieldKind::CrdMemRef || |
44 | kind == SparseTensorFieldKind::ValMemRef) { |
45 | auto rtp = cast<ShapedType>(t); |
46 | if (!directOut) { |
47 | rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
48 | if (extraTypes) |
49 | extraTypes->push_back(Elt: rtp); |
50 | } |
51 | convTypes.push_back(Elt: rtp); |
52 | } |
53 | return true; |
54 | }); |
55 | } |
56 | } |
57 | |
58 | // Convert input and output values to [dis]assemble ops for sparse tensors. |
59 | static void convVals(OpBuilder &builder, Location loc, TypeRange types, |
60 | ValueRange fromVals, ValueRange , |
61 | SmallVectorImpl<Value> &toVals, unsigned , bool isIn, |
62 | bool directOut) { |
63 | unsigned idx = 0; |
64 | for (auto type : types) { |
65 | // All "dense" data passes through unmodified. |
66 | if (!getSparseTensorEncoding(type)) { |
67 | toVals.push_back(Elt: fromVals[idx++]); |
68 | continue; |
69 | } |
70 | // Handle sparse data. |
71 | auto rtp = cast<RankedTensorType>(type); |
72 | const SparseTensorType stt(rtp); |
73 | SmallVector<Value> inputs; |
74 | SmallVector<Type> retTypes; |
75 | SmallVector<Type> cntTypes; |
76 | if (!isIn) |
77 | inputs.push_back(Elt: fromVals[idx++]); // The sparse tensor to disassemble |
78 | |
79 | // Collect the external representations of the pos/crd/val arrays. |
80 | foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex, |
81 | SparseTensorFieldKind kind, |
82 | Level lv, LevelType) { |
83 | if (kind == SparseTensorFieldKind::PosMemRef || |
84 | kind == SparseTensorFieldKind::CrdMemRef || |
85 | kind == SparseTensorFieldKind::ValMemRef) { |
86 | if (isIn) { |
87 | inputs.push_back(Elt: fromVals[idx++]); |
88 | } else if (directOut) { |
89 | Value mem; |
90 | if (kind == SparseTensorFieldKind::PosMemRef) |
91 | mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0], |
92 | lv); |
93 | else if (kind == SparseTensorFieldKind::CrdMemRef) |
94 | mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0], |
95 | lv); |
96 | else |
97 | mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]); |
98 | toVals.push_back(Elt: mem); |
99 | } else { |
100 | ShapedType rtp = cast<ShapedType>(t); |
101 | rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
102 | inputs.push_back(Elt: extraVals[extra++]); |
103 | retTypes.push_back(Elt: rtp); |
104 | cntTypes.push_back(builder.getIndexType()); |
105 | } |
106 | } |
107 | return true; |
108 | }); |
109 | |
110 | if (isIn) { |
111 | // Assemble multiple inputs into a single sparse tensor. |
112 | auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs); |
113 | toVals.push_back(Elt: a.getResult()); |
114 | } else if (!directOut) { |
115 | // Disassemble a single sparse input into multiple outputs. |
116 | // Note that this includes the counters, which are dropped. |
117 | unsigned len = retTypes.size(); |
118 | retTypes.append(RHS: cntTypes); |
119 | auto d = |
120 | builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs); |
121 | for (unsigned i = 0; i < len; i++) |
122 | toVals.push_back(Elt: d.getResult(i)); |
123 | } |
124 | } |
125 | } |
126 | |
127 | //===----------------------------------------------------------------------===// |
128 | // Rewriting rules. |
129 | //===----------------------------------------------------------------------===// |
130 | |
131 | namespace { |
132 | |
// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
//
163 | struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> { |
164 | using OpRewritePattern::OpRewritePattern; |
165 | |
166 | SparseFuncAssembler(MLIRContext *context, bool dO) |
167 | : OpRewritePattern(context), directOut(dO) {} |
168 | |
169 | LogicalResult matchAndRewrite(func::FuncOp funcOp, |
170 | PatternRewriter &rewriter) const override { |
171 | // Only rewrite public entry methods. |
172 | if (funcOp.isPrivate()) |
173 | return failure(); |
174 | |
175 | // Translate sparse tensor types to external types. |
176 | SmallVector<Type> inputTypes; |
177 | SmallVector<Type> outputTypes; |
178 | SmallVector<Type> ; |
179 | convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false); |
180 | convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut); |
181 | |
182 | // Only sparse inputs or outputs need a wrapper method. |
183 | if (inputTypes.size() == funcOp.getArgumentTypes().size() && |
184 | outputTypes.size() == funcOp.getResultTypes().size()) |
185 | return failure(); |
186 | |
187 | // Modify the original method into an internal, private method. |
188 | auto orgName = funcOp.getName(); |
189 | std::string wrapper = llvm::formatv("_internal_{0}" , orgName).str(); |
190 | funcOp.setName(wrapper); |
191 | funcOp.setPrivate(); |
192 | |
193 | // Start the new public wrapper method with original name. |
194 | Location loc = funcOp.getLoc(); |
195 | ModuleOp modOp = funcOp->getParentOfType<ModuleOp>(); |
196 | MLIRContext *context = modOp.getContext(); |
197 | OpBuilder moduleBuilder(modOp.getBodyRegion()); |
198 | unsigned = inputTypes.size(); |
199 | inputTypes.append(RHS: extraTypes); |
200 | auto func = moduleBuilder.create<func::FuncOp>( |
201 | loc, orgName, FunctionType::get(context, inputTypes, outputTypes)); |
202 | func.setPublic(); |
203 | |
204 | // Construct new wrapper method body. |
205 | OpBuilder::InsertionGuard insertionGuard(rewriter); |
206 | Block *body = func.addEntryBlock(); |
207 | rewriter.setInsertionPointToStart(body); |
208 | |
209 | // Convert inputs. |
210 | SmallVector<Value> inputs; |
211 | convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(), |
212 | ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut); |
213 | |
214 | // Call the original, now private method. A subsequent inlining pass can |
215 | // determine whether cloning the method body in place is worthwhile. |
216 | auto org = SymbolRefAttr::get(context, wrapper); |
217 | auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org, |
218 | inputs); |
219 | |
220 | // Convert outputs and return. |
221 | SmallVector<Value> outputs; |
222 | convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(), |
223 | body->getArguments(), outputs, extra, /*isIn=*/false, directOut); |
224 | rewriter.create<func::ReturnOp>(loc, outputs); |
225 | |
226 | // Finally, migrate a potential c-interface property. |
227 | if (funcOp->getAttrOfType<UnitAttr>( |
228 | LLVM::LLVMDialect::getEmitCWrapperAttrName())) { |
229 | func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), |
230 | UnitAttr::get(context)); |
231 | funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName()); |
232 | } |
233 | return success(); |
234 | } |
235 | |
236 | private: |
237 | const bool directOut; |
238 | }; |
239 | |
240 | } // namespace |
241 | |
242 | //===----------------------------------------------------------------------===// |
243 | // Public method for populating conversion rules. |
244 | //===----------------------------------------------------------------------===// |
245 | |
246 | void mlir::populateSparseAssembler(RewritePatternSet &patterns, |
247 | bool directOut) { |
248 | patterns.add<SparseFuncAssembler>(arg: patterns.getContext(), args&: directOut); |
249 | } |
250 | |