1//===- SparseTensorDescriptor.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "SparseTensorDescriptor.h"
10#include "CodegenUtils.h"
11
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18using namespace mlir;
19using namespace sparse_tensor;
20
21//===----------------------------------------------------------------------===//
22// Private helper methods.
23//===----------------------------------------------------------------------===//
24
25/// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
26static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
27 std::optional<Level> lvl) {
28 return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
29 : IntegerAttr();
30}
31
32// This is only ever called from `SparseTensorTypeToBufferConverter`,
33// which is why the first argument is `RankedTensorType` rather than
34// `SparseTensorType`.
35static std::optional<LogicalResult>
36convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
37 const SparseTensorType stt(rtp);
38 if (!stt.hasEncoding())
39 return std::nullopt;
40
41 foreachFieldAndTypeInSparseTensor(
42 stt,
43 [&fields](Type fieldType, FieldIndex fieldIdx,
44 SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
45 LevelType /*lt*/) -> bool {
46 assert(fieldIdx == fields.size());
47 fields.push_back(Elt: fieldType);
48 return true;
49 });
50 return success();
51}
52
53//===----------------------------------------------------------------------===//
// The sparse tensor type converter (declared in Passes.h).
55//===----------------------------------------------------------------------===//
56
57SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
58 addConversion(callback: [](Type type) { return type; });
59 addConversion(callback&: convertSparseTensorType);
60
61 // Required by scf.for 1:N type conversion.
62 addSourceMaterialization(callback: [](OpBuilder &builder, RankedTensorType tp,
63 ValueRange inputs,
64 Location loc) -> std::optional<Value> {
65 if (!getSparseTensorEncoding(tp))
66 // Not a sparse tensor.
67 return std::nullopt;
68 // Sparsifier knows how to cancel out these casts.
69 return genTuple(builder, loc, tp, inputs);
70 });
71}
72
73//===----------------------------------------------------------------------===//
// SparseTensorSpecifier methods.
75//===----------------------------------------------------------------------===//
76
77Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
78 SparseTensorType stt) {
79 return builder.create<StorageSpecifierInitOp>(
80 loc, StorageSpecifierType::get(stt.getEncoding()));
81}
82
83Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
84 StorageSpecifierKind kind,
85 std::optional<Level> lvl) {
86 return builder.create<GetStorageSpecifierOp>(
87 loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl));
88}
89
90void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
91 Value v,
92 StorageSpecifierKind kind,
93 std::optional<Level> lvl) {
94 // TODO: make `v` have type `TypedValue<IndexType>` instead.
95 assert(v.getType().isIndex());
96 specifier = builder.create<SetStorageSpecifierOp>(
97 loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v);
98}
99
100//===----------------------------------------------------------------------===//
101// SparseTensorDescriptor methods.
102//===----------------------------------------------------------------------===//
103
104Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
105 OpBuilder &builder, Location loc, Level lvl) const {
106 const Level cooStart = rType.getAoSCOOStart();
107 if (lvl < cooStart)
108 return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
109
110 Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
111 Value size = getCrdMemSize(builder, loc, cooStart);
112 size = builder.create<arith::DivUIOp>(loc, size, stride);
113 return builder.create<memref::SubViewOp>(
114 loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
115 /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
116 /*size=*/ValueRange{size},
117 /*step=*/ValueRange{stride});
118}
119

// Source: mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp