//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with the runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
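//
// For example (illustrative only; the actual entry points are defined by
// the runtime support library), a level-size query such as
//
//   %sz = sparse_tensor.lvl %t, %c1 : tensor<?x?xf64, #CSR>
//
// is lowered to an opaque runtime call along the lines of
//
//   %sz = call @sparseLvlSize(%ptr, %c1) : (!llvm.ptr, index) -> index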
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
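/// For example (illustrative), the generated IR takes the form:
///   %sz = call @sparseLvlSize(%tensor, %lvl) : (!llvm.ptr, index) -> index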
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because `getDimPosition`
  // already checks that the expr is an `AffineDimExpr`, which is all we care
  // about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
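/// (Within this file, the expand converter allocates such buffers and the
/// compress converter deallocates them on exit of the loop nest.)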
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer with the bare pointers to the level buffers
/// (followed by the pointer to the values buffer) of the given tensors.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
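///
/// Typical usage in the converters below (sketch):
///
///   Value tensor = NewCallParams(rewriter, loc)
///                      .genBuffers(stt, dimSizesValues)
///                      .genNewCall(Action::kEmpty);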
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
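/// The callee name is suffixed with the element type (illustrative:
/// "sparseValuesF64" for an `f64` tensor).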
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array (AoS view).
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
                                      SparseTensorType stt, Value ptr,
                                      Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<25> name{"sparseCoordinatesBuffer",
                       overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

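/// Sparse conversion rule for the reinterpret map operator. Since the
/// converted sparse tensor is just an opaque pointer, the remapping is
/// simply folded away.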
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                   op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses (AoS style).
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesBufferCall(
        rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array to obtain the number of stored
    // entries.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
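    // The lowering thus stores the coordinates and the scalar into stack
    // buffers and emits a runtime call along the lines of (illustrative,
    // for an f64 element type):
    //   call @lexInsertF64(%tensor, %lvlCoords, %vref)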
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying them into external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO levels. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //   for (i = 0; i < crdLen; i++)
      //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
           SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
           SparseTensorInsertConverter, SparseTensorExpandConverter,
           SparseTensorCompressConverter, SparseTensorAssembleConverter,
           SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
          typeConverter, patterns.getContext());
}
