1//===- SparseTensorDescriptor.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "SparseTensorDescriptor.h"
10#include "CodegenUtils.h"
11
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
15#include "mlir/Transforms/DialectConversion.h"
16
17using namespace mlir;
18using namespace sparse_tensor;
19
20//===----------------------------------------------------------------------===//
// Private helper functions.
22//===----------------------------------------------------------------------===//
23
24/// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
25static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
26 std::optional<Level> lvl) {
27 return lvl ? IntegerAttr::get(type: IndexType::get(context: ctx), value: lvl.value())
28 : IntegerAttr();
29}
30
31// This is only ever called from `SparseTensorTypeToBufferConverter`,
32// which is why the first argument is `RankedTensorType` rather than
33// `SparseTensorType`.
34static std::optional<LogicalResult>
35convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
36 const SparseTensorType stt(rtp);
37 if (!stt.hasEncoding())
38 return std::nullopt;
39
40 unsigned numFields = fields.size();
41 (void)numFields;
42 foreachFieldAndTypeInSparseTensor(
43 stt,
44 [&](Type fieldType, FieldIndex fieldIdx,
45 SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
46 LevelType /*lt*/) -> bool {
47 assert(numFields + fieldIdx == fields.size());
48 fields.push_back(Elt: fieldType);
49 return true;
50 });
51 return success();
52}
53
54//===----------------------------------------------------------------------===//
// The sparse tensor type converter (declared in Passes.h).
56//===----------------------------------------------------------------------===//
57
58static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
59 ValueRange inputs, Location loc) {
60 if (!getSparseTensorEncoding(type: tp))
61 // Not a sparse tensor.
62 return Value();
63 // Sparsifier knows how to cancel out these casts.
64 return genTuple(builder, loc, tp, values: inputs);
65}
66
67SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
68 addConversion(callback: [](Type type) { return type; });
69 addConversion(callback&: convertSparseTensorType);
70
71 // Required by scf.for 1:N type conversion.
72 addSourceMaterialization(callback&: materializeTuple);
73}
74
75//===----------------------------------------------------------------------===//
// SparseTensorSpecifier methods.
77//===----------------------------------------------------------------------===//
78
79Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
80 SparseTensorType stt) {
81 return builder.create<StorageSpecifierInitOp>(
82 location: loc, args: StorageSpecifierType::get(encoding: stt.getEncoding()));
83}
84
85Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
86 StorageSpecifierKind kind,
87 std::optional<Level> lvl) {
88 return builder.create<GetStorageSpecifierOp>(
89 location: loc, args&: specifier, args&: kind, args: optionalLevelAttr(ctx: specifier.getContext(), lvl));
90}
91
92void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
93 Value v,
94 StorageSpecifierKind kind,
95 std::optional<Level> lvl) {
96 // TODO: make `v` have type `TypedValue<IndexType>` instead.
97 assert(v.getType().isIndex());
98 specifier = builder.create<SetStorageSpecifierOp>(
99 location: loc, args&: specifier, args&: kind, args: optionalLevelAttr(ctx: specifier.getContext(), lvl), args&: v);
100}
101
102//===----------------------------------------------------------------------===//
103// SparseTensorDescriptor methods.
104//===----------------------------------------------------------------------===//
105
106Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
107 OpBuilder &builder, Location loc, Level lvl) const {
108 const Level cooStart = rType.getAoSCOOStart();
109 if (lvl < cooStart)
110 return getMemRefField(kind: SparseTensorFieldKind::CrdMemRef, lvl);
111
112 Value stride = constantIndex(builder, loc, i: rType.getLvlRank() - cooStart);
113 Value size = getCrdMemSize(builder, loc, lvl: cooStart);
114 size = builder.create<arith::DivUIOp>(location: loc, args&: size, args&: stride);
115 return builder.create<memref::SubViewOp>(
116 location: loc, args: getMemRefField(kind: SparseTensorFieldKind::CrdMemRef, lvl: cooStart),
117 /*offset=*/args: ValueRange{constantIndex(builder, loc, i: lvl - cooStart)},
118 /*size=*/args: ValueRange{size},
119 /*step=*/args: ValueRange{stride});
120}
121

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp