//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "Utils/CodegenUtils.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
15#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "llvm/Support/FormatVariadic.h"
18
19using namespace mlir;
20using namespace sparse_tensor;
21
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
25
26// Convert type range to new types range, with sparse tensors externalized.
27static void convTypes(bool &hasAnnotation, TypeRange types,
28 SmallVectorImpl<Type> &convTypes,
29 SmallVectorImpl<Type> *extraTypes, bool directOut) {
30 for (auto type : types) {
31 // All "dense" data passes through unmodified.
32 if (!getSparseTensorEncoding(type)) {
33 convTypes.push_back(Elt: type);
34 continue;
35 }
36 hasAnnotation = true;
37
38 // Convert the external representations of the pos/crd/val arrays.
39 const SparseTensorType stt(cast<RankedTensorType>(type));
40 foreachFieldAndTypeInSparseTensor(
41 stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
42 SparseTensorFieldKind kind,
43 Level, LevelType) {
44 if (kind == SparseTensorFieldKind::PosMemRef ||
45 kind == SparseTensorFieldKind::CrdMemRef ||
46 kind == SparseTensorFieldKind::ValMemRef) {
47 auto rtp = cast<ShapedType>(t);
48 if (!directOut) {
49 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
50 if (extraTypes)
51 extraTypes->push_back(Elt: rtp);
52 }
53 convTypes.push_back(Elt: rtp);
54 }
55 return true;
56 });
57 }
58}
59
60// Convert input and output values to [dis]assemble ops for sparse tensors.
61static void convVals(OpBuilder &builder, Location loc, TypeRange types,
62 ValueRange fromVals, ValueRange extraVals,
63 SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
64 bool directOut) {
65 unsigned idx = 0;
66 for (auto type : types) {
67 // All "dense" data passes through unmodified.
68 if (!getSparseTensorEncoding(type)) {
69 toVals.push_back(Elt: fromVals[idx++]);
70 continue;
71 }
72 // Handle sparse data.
73 auto rtp = cast<RankedTensorType>(type);
74 const SparseTensorType stt(rtp);
75 SmallVector<Value> inputs;
76 SmallVector<Type> retTypes;
77 SmallVector<Type> cntTypes;
78 if (!isIn)
79 inputs.push_back(Elt: fromVals[idx++]); // The sparse tensor to disassemble
80
81 // Collect the external representations of the pos/crd/val arrays.
82 foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
83 SparseTensorFieldKind kind,
84 Level lv, LevelType) {
85 if (kind == SparseTensorFieldKind::PosMemRef ||
86 kind == SparseTensorFieldKind::CrdMemRef ||
87 kind == SparseTensorFieldKind::ValMemRef) {
88 if (isIn) {
89 inputs.push_back(Elt: fromVals[idx++]);
90 } else if (directOut) {
91 Value mem;
92 if (kind == SparseTensorFieldKind::PosMemRef)
93 mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
94 lv);
95 else if (kind == SparseTensorFieldKind::CrdMemRef)
96 mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
97 lv);
98 else
99 mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
100 toVals.push_back(Elt: mem);
101 } else {
102 ShapedType rtp = cast<ShapedType>(t);
103 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
104 inputs.push_back(Elt: extraVals[extra++]);
105 retTypes.push_back(Elt: rtp);
106 cntTypes.push_back(builder.getIndexType());
107 }
108 }
109 return true;
110 });
111
112 if (isIn) {
113 // Assemble multiple inputs into a single sparse tensor.
114 auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
115 toVals.push_back(Elt: a.getResult());
116 } else if (!directOut) {
117 // Disassemble a single sparse input into multiple outputs.
118 // Note that this includes the counters, which are dropped.
119 unsigned len = retTypes.size();
120 retTypes.append(RHS: cntTypes);
121 auto d =
122 builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
123 for (unsigned i = 0; i < len; i++)
124 toVals.push_back(Elt: d.getResult(i));
125 }
126 }
127}
128
//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//
132
133namespace {
134
135// A rewriting rules that converts public entry methods that use sparse tensors
136// as input parameters and/or output return values into wrapper methods that
137// [dis]assemble the individual tensors that constitute the actual storage used
138// externally into MLIR sparse tensors before calling the original method.
139//
140// In particular, each sparse tensor input
141//
142// void foo(..., t, ...) { }
143//
144// makes the original foo() internal and adds the following wrapper method
145//
146// void foo(..., t1..tn, ...) {
147// t = assemble t1..tn
148// _internal_foo(..., t, ...)
149// }
150//
151// and likewise, each output tensor
152//
153// ... T ... bar(...) { return ..., t, ...; }
154//
155// makes the original bar() internal and adds the following wrapper method
156//
157// ... T1..TN ... bar(..., t1'..tn') {
158// ..., t, ... = _internal_bar(...)
159// t1..tn = disassemble t, t1'..tn'
160// return ..., t1..tn, ...
161// }
162//
163// (with a direct-out variant without the disassemble).
164//
165struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
166 using OpRewritePattern::OpRewritePattern;
167
168 SparseFuncAssembler(MLIRContext *context, bool dO)
169 : OpRewritePattern(context), directOut(dO) {}
170
171 LogicalResult matchAndRewrite(func::FuncOp funcOp,
172 PatternRewriter &rewriter) const override {
173 // Only rewrite public entry methods.
174 if (funcOp.isPrivate())
175 return failure();
176
177 // Translate sparse tensor types to external types.
178 SmallVector<Type> inputTypes;
179 SmallVector<Type> outputTypes;
180 SmallVector<Type> extraTypes;
181 bool hasAnnotation = false;
182 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
183 false);
184 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
185 directOut);
186
187 // Only sparse inputs or outputs need a wrapper method.
188 if (!hasAnnotation)
189 return failure();
190
191 // Modify the original method into an internal, private method.
192 auto orgName = funcOp.getName();
193 std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
194 funcOp.setName(wrapper);
195 funcOp.setPrivate();
196
197 // Start the new public wrapper method with original name.
198 Location loc = funcOp.getLoc();
199 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
200 MLIRContext *context = modOp.getContext();
201 OpBuilder moduleBuilder(modOp.getBodyRegion());
202 unsigned extra = inputTypes.size();
203 inputTypes.append(RHS: extraTypes);
204 auto func = moduleBuilder.create<func::FuncOp>(
205 loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
206 func.setPublic();
207
208 // Construct new wrapper method body.
209 OpBuilder::InsertionGuard insertionGuard(rewriter);
210 Block *body = func.addEntryBlock();
211 rewriter.setInsertionPointToStart(body);
212
213 // Convert inputs.
214 SmallVector<Value> inputs;
215 convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
216 ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
217
218 // Call the original, now private method. A subsequent inlining pass can
219 // determine whether cloning the method body in place is worthwhile.
220 auto org = SymbolRefAttr::get(context, wrapper);
221 auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
222 inputs);
223
224 // Convert outputs and return.
225 SmallVector<Value> outputs;
226 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
227 body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
228 rewriter.create<func::ReturnOp>(loc, outputs);
229
230 // Finally, migrate a potential c-interface property.
231 if (funcOp->getAttrOfType<UnitAttr>(
232 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
233 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
234 UnitAttr::get(context));
235 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
236 }
237 return success();
238 }
239
240private:
241 const bool directOut;
242};
243
244} // namespace
245
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
249
250void mlir::populateSparseAssembler(RewritePatternSet &patterns,
251 bool directOut) {
252 patterns.add<SparseFuncAssembler>(arg: patterns.getContext(), args&: directOut);
253}
254

// Source: mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp