1//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils/CodegenUtils.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
15#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "llvm/Support/FormatVariadic.h"
18
19using namespace mlir;
20using namespace sparse_tensor;
21
22//===----------------------------------------------------------------------===//
23// Helper methods.
24//===----------------------------------------------------------------------===//
25
26// Convert type range to new types range, with sparse tensors externalized.
27static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
28 SmallVectorImpl<Type> *extraTypes, bool directOut) {
29 for (auto type : types) {
30 // All "dense" data passes through unmodified.
31 if (!getSparseTensorEncoding(type)) {
32 convTypes.push_back(Elt: type);
33 continue;
34 }
35
36 // Convert the external representations of the pos/crd/val arrays.
37 const SparseTensorType stt(cast<RankedTensorType>(type));
38 foreachFieldAndTypeInSparseTensor(
39 stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
40 SparseTensorFieldKind kind,
41 Level, LevelType) {
42 if (kind == SparseTensorFieldKind::PosMemRef ||
43 kind == SparseTensorFieldKind::CrdMemRef ||
44 kind == SparseTensorFieldKind::ValMemRef) {
45 auto rtp = cast<ShapedType>(t);
46 if (!directOut) {
47 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
48 if (extraTypes)
49 extraTypes->push_back(Elt: rtp);
50 }
51 convTypes.push_back(Elt: rtp);
52 }
53 return true;
54 });
55 }
56}
57
58// Convert input and output values to [dis]assemble ops for sparse tensors.
59static void convVals(OpBuilder &builder, Location loc, TypeRange types,
60 ValueRange fromVals, ValueRange extraVals,
61 SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
62 bool directOut) {
63 unsigned idx = 0;
64 for (auto type : types) {
65 // All "dense" data passes through unmodified.
66 if (!getSparseTensorEncoding(type)) {
67 toVals.push_back(Elt: fromVals[idx++]);
68 continue;
69 }
70 // Handle sparse data.
71 auto rtp = cast<RankedTensorType>(type);
72 const SparseTensorType stt(rtp);
73 SmallVector<Value> inputs;
74 SmallVector<Type> retTypes;
75 SmallVector<Type> cntTypes;
76 if (!isIn)
77 inputs.push_back(Elt: fromVals[idx++]); // The sparse tensor to disassemble
78
79 // Collect the external representations of the pos/crd/val arrays.
80 foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
81 SparseTensorFieldKind kind,
82 Level lv, LevelType) {
83 if (kind == SparseTensorFieldKind::PosMemRef ||
84 kind == SparseTensorFieldKind::CrdMemRef ||
85 kind == SparseTensorFieldKind::ValMemRef) {
86 if (isIn) {
87 inputs.push_back(Elt: fromVals[idx++]);
88 } else if (directOut) {
89 Value mem;
90 if (kind == SparseTensorFieldKind::PosMemRef)
91 mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
92 lv);
93 else if (kind == SparseTensorFieldKind::CrdMemRef)
94 mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
95 lv);
96 else
97 mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
98 toVals.push_back(Elt: mem);
99 } else {
100 ShapedType rtp = cast<ShapedType>(t);
101 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
102 inputs.push_back(Elt: extraVals[extra++]);
103 retTypes.push_back(Elt: rtp);
104 cntTypes.push_back(builder.getIndexType());
105 }
106 }
107 return true;
108 });
109
110 if (isIn) {
111 // Assemble multiple inputs into a single sparse tensor.
112 auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
113 toVals.push_back(Elt: a.getResult());
114 } else if (!directOut) {
115 // Disassemble a single sparse input into multiple outputs.
116 // Note that this includes the counters, which are dropped.
117 unsigned len = retTypes.size();
118 retTypes.append(RHS: cntTypes);
119 auto d =
120 builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
121 for (unsigned i = 0; i < len; i++)
122 toVals.push_back(Elt: d.getResult(i));
123 }
124 }
125}
126
127//===----------------------------------------------------------------------===//
128// Rewriting rules.
129//===----------------------------------------------------------------------===//
130
131namespace {
132
133// A rewriting rules that converts public entry methods that use sparse tensors
134// as input parameters and/or output return values into wrapper methods that
135// [dis]assemble the individual tensors that constitute the actual storage used
136// externally into MLIR sparse tensors before calling the original method.
137//
138// In particular, each sparse tensor input
139//
140// void foo(..., t, ...) { }
141//
142// makes the original foo() internal and adds the following wrapper method
143//
144// void foo(..., t1..tn, ...) {
145// t = assemble t1..tn
146// _internal_foo(..., t, ...)
147// }
148//
149// and likewise, each output tensor
150//
151// ... T ... bar(...) { return ..., t, ...; }
152//
153// makes the original bar() internal and adds the following wrapper method
154//
155// ... T1..TN ... bar(..., t1'..tn') {
156// ..., t, ... = _internal_bar(...)
157// t1..tn = disassemble t, t1'..tn'
158// return ..., t1..tn, ...
159// }
160//
161// (with a direct-out variant without the disassemble).
162//
163struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
164 using OpRewritePattern::OpRewritePattern;
165
166 SparseFuncAssembler(MLIRContext *context, bool dO)
167 : OpRewritePattern(context), directOut(dO) {}
168
169 LogicalResult matchAndRewrite(func::FuncOp funcOp,
170 PatternRewriter &rewriter) const override {
171 // Only rewrite public entry methods.
172 if (funcOp.isPrivate())
173 return failure();
174
175 // Translate sparse tensor types to external types.
176 SmallVector<Type> inputTypes;
177 SmallVector<Type> outputTypes;
178 SmallVector<Type> extraTypes;
179 convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
180 convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
181
182 // Only sparse inputs or outputs need a wrapper method.
183 if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
184 outputTypes.size() == funcOp.getResultTypes().size())
185 return failure();
186
187 // Modify the original method into an internal, private method.
188 auto orgName = funcOp.getName();
189 std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
190 funcOp.setName(wrapper);
191 funcOp.setPrivate();
192
193 // Start the new public wrapper method with original name.
194 Location loc = funcOp.getLoc();
195 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
196 MLIRContext *context = modOp.getContext();
197 OpBuilder moduleBuilder(modOp.getBodyRegion());
198 unsigned extra = inputTypes.size();
199 inputTypes.append(RHS: extraTypes);
200 auto func = moduleBuilder.create<func::FuncOp>(
201 loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
202 func.setPublic();
203
204 // Construct new wrapper method body.
205 OpBuilder::InsertionGuard insertionGuard(rewriter);
206 Block *body = func.addEntryBlock();
207 rewriter.setInsertionPointToStart(body);
208
209 // Convert inputs.
210 SmallVector<Value> inputs;
211 convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
212 ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
213
214 // Call the original, now private method. A subsequent inlining pass can
215 // determine whether cloning the method body in place is worthwhile.
216 auto org = SymbolRefAttr::get(context, wrapper);
217 auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
218 inputs);
219
220 // Convert outputs and return.
221 SmallVector<Value> outputs;
222 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
223 body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
224 rewriter.create<func::ReturnOp>(loc, outputs);
225
226 // Finally, migrate a potential c-interface property.
227 if (funcOp->getAttrOfType<UnitAttr>(
228 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
229 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
230 UnitAttr::get(context));
231 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
232 }
233 return success();
234 }
235
236private:
237 const bool directOut;
238};
239
240} // namespace
241
242//===----------------------------------------------------------------------===//
243// Public method for populating conversion rules.
244//===----------------------------------------------------------------------===//
245
246void mlir::populateSparseAssembler(RewritePatternSet &patterns,
247 bool directOut) {
248 patterns.add<SparseFuncAssembler>(arg: patterns.getContext(), args&: directOut);
249}
250

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp