//===- SparseTensorDescriptor.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "SparseTensorDescriptor.h"
10#include "CodegenUtils.h"
11
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18using namespace mlir;
19using namespace sparse_tensor;
20
21//===----------------------------------------------------------------------===//
22// Private helper methods.
23//===----------------------------------------------------------------------===//
24
25/// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
26static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
27 std::optional<Level> lvl) {
28 return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
29 : IntegerAttr();
30}
31
32// This is only ever called from `SparseTensorTypeToBufferConverter`,
33// which is why the first argument is `RankedTensorType` rather than
34// `SparseTensorType`.
35static std::optional<LogicalResult>
36convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
37 const SparseTensorType stt(rtp);
38 if (!stt.hasEncoding())
39 return std::nullopt;
40
41 unsigned numFields = fields.size();
42 (void)numFields;
43 foreachFieldAndTypeInSparseTensor(
44 stt,
45 [&](Type fieldType, FieldIndex fieldIdx,
46 SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
47 LevelType /*lt*/) -> bool {
48 assert(numFields + fieldIdx == fields.size());
49 fields.push_back(Elt: fieldType);
50 return true;
51 });
52 return success();
53}
54
55//===----------------------------------------------------------------------===//
56// The sparse tensor type converter (defined in Passes.h).
57//===----------------------------------------------------------------------===//
58
59static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
60 ValueRange inputs, Location loc) {
61 if (!getSparseTensorEncoding(tp))
62 // Not a sparse tensor.
63 return Value();
64 // Sparsifier knows how to cancel out these casts.
65 return genTuple(builder, loc, tp, inputs);
66}
67
68SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
69 addConversion(callback: [](Type type) { return type; });
70 addConversion(callback&: convertSparseTensorType);
71
72 // Required by scf.for 1:N type conversion.
73 addSourceMaterialization(callback&: materializeTuple);
74}
75
76//===----------------------------------------------------------------------===//
77// StorageTensorSpecifier methods.
78//===----------------------------------------------------------------------===//
79
80Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
81 SparseTensorType stt) {
82 return builder.create<StorageSpecifierInitOp>(
83 loc, StorageSpecifierType::get(stt.getEncoding()));
84}
85
86Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
87 StorageSpecifierKind kind,
88 std::optional<Level> lvl) {
89 return builder.create<GetStorageSpecifierOp>(
90 loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl));
91}
92
93void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
94 Value v,
95 StorageSpecifierKind kind,
96 std::optional<Level> lvl) {
97 // TODO: make `v` have type `TypedValue<IndexType>` instead.
98 assert(v.getType().isIndex());
99 specifier = builder.create<SetStorageSpecifierOp>(
100 loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v);
101}
102
103//===----------------------------------------------------------------------===//
104// SparseTensorDescriptor methods.
105//===----------------------------------------------------------------------===//
106
107Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
108 OpBuilder &builder, Location loc, Level lvl) const {
109 const Level cooStart = rType.getAoSCOOStart();
110 if (lvl < cooStart)
111 return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
112
113 Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
114 Value size = getCrdMemSize(builder, loc, cooStart);
115 size = builder.create<arith::DivUIOp>(loc, size, stride);
116 return builder.create<memref::SubViewOp>(
117 loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
118 /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
119 /*size=*/ValueRange{size},
120 /*step=*/ValueRange{stride});
121}
122

// Source: mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp