1//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils/CodegenUtils.h"
10
11#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
14#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
15#include "llvm/Support/FormatVariadic.h"
16
17using namespace mlir;
18using namespace sparse_tensor;
19
20//===----------------------------------------------------------------------===//
21// Helper methods.
22//===----------------------------------------------------------------------===//
23
24// Convert type range to new types range, with sparse tensors externalized.
25static void convTypes(bool &hasAnnotation, TypeRange types,
26 SmallVectorImpl<Type> &convTypes,
27 SmallVectorImpl<Type> *extraTypes, bool directOut) {
28 for (auto type : types) {
29 // All "dense" data passes through unmodified.
30 if (!getSparseTensorEncoding(type)) {
31 convTypes.push_back(Elt: type);
32 continue;
33 }
34 hasAnnotation = true;
35
36 // Convert the external representations of the pos/crd/val arrays.
37 const SparseTensorType stt(cast<RankedTensorType>(Val&: type));
38 foreachFieldAndTypeInSparseTensor(
39 stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
40 SparseTensorFieldKind kind,
41 Level, LevelType) {
42 if (kind == SparseTensorFieldKind::PosMemRef ||
43 kind == SparseTensorFieldKind::CrdMemRef ||
44 kind == SparseTensorFieldKind::ValMemRef) {
45 auto rtp = cast<ShapedType>(Val&: t);
46 if (!directOut) {
47 rtp = RankedTensorType::get(shape: rtp.getShape(), elementType: rtp.getElementType());
48 if (extraTypes)
49 extraTypes->push_back(Elt: rtp);
50 }
51 convTypes.push_back(Elt: rtp);
52 }
53 return true;
54 });
55 }
56}
57
58// Convert input and output values to [dis]assemble ops for sparse tensors.
59static void convVals(OpBuilder &builder, Location loc, TypeRange types,
60 ValueRange fromVals, ValueRange extraVals,
61 SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
62 bool directOut) {
63 unsigned idx = 0;
64 for (auto type : types) {
65 // All "dense" data passes through unmodified.
66 if (!getSparseTensorEncoding(type)) {
67 toVals.push_back(Elt: fromVals[idx++]);
68 continue;
69 }
70 // Handle sparse data.
71 auto rtp = cast<RankedTensorType>(Val&: type);
72 const SparseTensorType stt(rtp);
73 SmallVector<Value> inputs;
74 SmallVector<Type> retTypes;
75 SmallVector<Type> cntTypes;
76 if (!isIn)
77 inputs.push_back(Elt: fromVals[idx++]); // The sparse tensor to disassemble
78
79 // Collect the external representations of the pos/crd/val arrays.
80 foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
81 SparseTensorFieldKind kind,
82 Level lv, LevelType) {
83 if (kind == SparseTensorFieldKind::PosMemRef ||
84 kind == SparseTensorFieldKind::CrdMemRef ||
85 kind == SparseTensorFieldKind::ValMemRef) {
86 if (isIn) {
87 inputs.push_back(Elt: fromVals[idx++]);
88 } else if (directOut) {
89 Value mem;
90 if (kind == SparseTensorFieldKind::PosMemRef)
91 mem = builder.create<sparse_tensor::ToPositionsOp>(location: loc, args&: inputs[0],
92 args&: lv);
93 else if (kind == SparseTensorFieldKind::CrdMemRef)
94 mem = builder.create<sparse_tensor::ToCoordinatesOp>(location: loc, args&: inputs[0],
95 args&: lv);
96 else
97 mem = builder.create<sparse_tensor::ToValuesOp>(location: loc, args&: inputs[0]);
98 toVals.push_back(Elt: mem);
99 } else {
100 ShapedType rtp = cast<ShapedType>(Val&: t);
101 rtp = RankedTensorType::get(shape: rtp.getShape(), elementType: rtp.getElementType());
102 inputs.push_back(Elt: extraVals[extra++]);
103 retTypes.push_back(Elt: rtp);
104 cntTypes.push_back(Elt: builder.getIndexType());
105 }
106 }
107 return true;
108 });
109
110 if (isIn) {
111 // Assemble multiple inputs into a single sparse tensor.
112 auto a = builder.create<sparse_tensor::AssembleOp>(location: loc, args&: rtp, args&: inputs);
113 toVals.push_back(Elt: a.getResult());
114 } else if (!directOut) {
115 // Disassemble a single sparse input into multiple outputs.
116 // Note that this includes the counters, which are dropped.
117 unsigned len = retTypes.size();
118 retTypes.append(RHS: cntTypes);
119 auto d =
120 builder.create<sparse_tensor::DisassembleOp>(location: loc, args&: retTypes, args&: inputs);
121 for (unsigned i = 0; i < len; i++)
122 toVals.push_back(Elt: d.getResult(i));
123 }
124 }
125}
126
127//===----------------------------------------------------------------------===//
128// Rewriting rules.
129//===----------------------------------------------------------------------===//
130
131namespace {
132
// A rewriting rule that converts public entry methods that use sparse tensors
134// as input parameters and/or output return values into wrapper methods that
135// [dis]assemble the individual tensors that constitute the actual storage used
136// externally into MLIR sparse tensors before calling the original method.
137//
138// In particular, each sparse tensor input
139//
140// void foo(..., t, ...) { }
141//
142// makes the original foo() internal and adds the following wrapper method
143//
144// void foo(..., t1..tn, ...) {
145// t = assemble t1..tn
146// _internal_foo(..., t, ...)
147// }
148//
149// and likewise, each output tensor
150//
151// ... T ... bar(...) { return ..., t, ...; }
152//
153// makes the original bar() internal and adds the following wrapper method
154//
155// ... T1..TN ... bar(..., t1'..tn') {
156// ..., t, ... = _internal_bar(...)
157// t1..tn = disassemble t, t1'..tn'
158// return ..., t1..tn, ...
159// }
160//
161// (with a direct-out variant without the disassemble).
162//
163struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
164 using OpRewritePattern::OpRewritePattern;
165
166 SparseFuncAssembler(MLIRContext *context, bool dO)
167 : OpRewritePattern(context), directOut(dO) {}
168
169 LogicalResult matchAndRewrite(func::FuncOp funcOp,
170 PatternRewriter &rewriter) const override {
171 // Only rewrite public entry methods.
172 if (funcOp.isPrivate())
173 return failure();
174
175 // Translate sparse tensor types to external types.
176 SmallVector<Type> inputTypes;
177 SmallVector<Type> outputTypes;
178 SmallVector<Type> extraTypes;
179 bool hasAnnotation = false;
180 convTypes(hasAnnotation, types: funcOp.getArgumentTypes(), convTypes&: inputTypes, extraTypes: nullptr,
181 directOut: false);
182 convTypes(hasAnnotation, types: funcOp.getResultTypes(), convTypes&: outputTypes, extraTypes: &extraTypes,
183 directOut);
184
185 // Only sparse inputs or outputs need a wrapper method.
186 if (!hasAnnotation)
187 return failure();
188
189 // Modify the original method into an internal, private method.
190 auto orgName = funcOp.getName();
191 std::string wrapper = llvm::formatv(Fmt: "_internal_{0}", Vals&: orgName).str();
192 funcOp.setName(wrapper);
193 funcOp.setPrivate();
194
195 // Start the new public wrapper method with original name.
196 Location loc = funcOp.getLoc();
197 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
198 MLIRContext *context = modOp.getContext();
199 OpBuilder moduleBuilder(modOp.getBodyRegion());
200 unsigned extra = inputTypes.size();
201 inputTypes.append(RHS: extraTypes);
202 auto func = moduleBuilder.create<func::FuncOp>(
203 location: loc, args&: orgName, args: FunctionType::get(context, inputs: inputTypes, results: outputTypes));
204 func.setPublic();
205
206 // Construct new wrapper method body.
207 OpBuilder::InsertionGuard insertionGuard(rewriter);
208 Block *body = func.addEntryBlock();
209 rewriter.setInsertionPointToStart(body);
210
211 // Convert inputs.
212 SmallVector<Value> inputs;
213 convVals(builder&: rewriter, loc, types: funcOp.getArgumentTypes(), fromVals: body->getArguments(),
214 extraVals: ValueRange(), toVals&: inputs, /*extra=*/0, /*isIn=*/true, directOut);
215
216 // Call the original, now private method. A subsequent inlining pass can
217 // determine whether cloning the method body in place is worthwhile.
218 auto org = SymbolRefAttr::get(ctx: context, value: wrapper);
219 auto call = rewriter.create<func::CallOp>(location: loc, args: funcOp.getResultTypes(), args&: org,
220 args&: inputs);
221
222 // Convert outputs and return.
223 SmallVector<Value> outputs;
224 convVals(builder&: rewriter, loc, types: funcOp.getResultTypes(), fromVals: call.getResults(),
225 extraVals: body->getArguments(), toVals&: outputs, extra, /*isIn=*/false, directOut);
226 rewriter.create<func::ReturnOp>(location: loc, args&: outputs);
227
228 // Finally, migrate a potential c-interface property.
229 if (funcOp->getAttrOfType<UnitAttr>(
230 name: LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
231 func->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(),
232 value: UnitAttr::get(context));
233 funcOp->removeAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName());
234 }
235 return success();
236 }
237
238private:
239 const bool directOut;
240};
241
242} // namespace
243
244//===----------------------------------------------------------------------===//
245// Public method for populating conversion rules.
246//===----------------------------------------------------------------------===//
247
248void mlir::populateSparseAssembler(RewritePatternSet &patterns,
249 bool directOut) {
250 patterns.add<SparseFuncAssembler>(arg: patterns.getContext(), args&: directOut);
251}
252

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp