1//===- SparseTensor.h - Sparse tensor dialect -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
10#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
11
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/Dialect/SparseTensor/IR/Enums.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/IR/OpImplementation.h"
19#include "mlir/IR/TensorEncoding.h"
20#include "mlir/Interfaces/InferTypeOpInterface.h"
21#include "mlir/Interfaces/SideEffectInterfaces.h"
22
23//===----------------------------------------------------------------------===//
24//
25// Type aliases to help code be more self-documenting. Unfortunately
26// these are not type-checked, so they only provide documentation rather
27// than doing anything to prevent mixups.
28//
29//===----------------------------------------------------------------------===//
30
namespace mlir::sparse_tensor {

/// Identifier type for dimensions; also used to express dimension-ranks.
using Dimension = uint64_t;

/// Identifier type for levels; also used to express level-ranks.
using Level = uint64_t;

/// Element type of a compile-time shape. Besides concrete extents, a value of
/// this type may hold the sentinel `ShapedType::kDynamic` (for shapes), which
/// is why it is signed.
using Size = int64_t;

} // namespace mlir::sparse_tensor
46
47//===----------------------------------------------------------------------===//
48// TableGen-defined classes
49//===----------------------------------------------------------------------===//
50
51#define GET_ATTRDEF_CLASSES
52#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc"
53
54#define GET_ATTRDEF_CLASSES
55#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.h.inc"
56
57#define GET_TYPEDEF_CLASSES
58#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.h.inc"
59
60#define GET_OP_CLASSES
61#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.h.inc"
62
63#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc"
64
65//===----------------------------------------------------------------------===//
66// Additional convenience methods.
67//===----------------------------------------------------------------------===//
68
69namespace mlir {
70namespace sparse_tensor {
71
72/// Convenience method to abbreviate casting `getType()`.
73template <typename T>
74inline RankedTensorType getRankedTensorType(T &&t) {
75 assert(static_cast<bool>(std::forward<T>(t)) &&
76 "getRankedTensorType got null argument");
77 return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
78}
79
80/// Convenience method to abbreviate casting `getType()`.
81template <typename T>
82inline MemRefType getMemRefType(T &&t) {
83 assert(static_cast<bool>(std::forward<T>(t)) &&
84 "getMemRefType got null argument");
85 return cast<MemRefType>(std::forward<T>(t).getType());
86}
87
88/// Convenience method to get a sparse encoding attribute from a type.
89/// Returns null-attribute for any type without an encoding.
90SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
91
92/// Returns true iff MLIR operand has any sparse operand.
93inline bool hasAnySparseOperand(Operation *op) {
94 return llvm::any_of(Range: op->getOperands().getTypes(), P: [](Type t) {
95 return getSparseTensorEncoding(t) != nullptr;
96 });
97}
98
99/// Returns true iff MLIR operand has any sparse result.
100inline bool hasAnySparseResult(Operation *op) {
101 return llvm::any_of(Range: op->getResults().getTypes(), P: [](Type t) {
102 return getSparseTensorEncoding(t) != nullptr;
103 });
104}
105
106/// Returns true iff MLIR operand has any sparse operand or result.
107inline bool hasAnySparseOperandOrResult(Operation *op) {
108 return hasAnySparseOperand(op) || hasAnySparseResult(op);
109}
110
111/// Returns true iff MLIR operation has any sparse tensor with non-identity
112/// dim2lvl maps.
113bool hasAnyNonIdentityOperandsOrResults(Operation *op);
114
115//
116// Inference.
117//
118
119/// Given the dimToLvl map, infers the lvlToDim map, or returns
120/// empty Affine map when inference fails.
121AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
122
123/// Returns the lvlToDim map for the given dimToLvl map specific
124/// to the block sparse cases.
125/// Asserts on failure (so only use when known to succeed).
126AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
127
128/// Given the dimToLvl map, returns the block sizes in a vector.
129/// For instance, a 2x3 block will return [2, 3]. Unblocked dimension i
130/// will return 0, and i floordiv 1, i mod 1 will return 1. Therefore,
131/// the example below will return [0, 1].
132/// map = ( i, j ) ->
133/// ( i : dense,
134/// j floordiv 1 : compressed,
135/// j mod 1 : dense
136/// )
137/// Only valid block sparsity will be accepted.
138SmallVector<unsigned> getBlockSize(AffineMap dimToLvl);
139
140/// Given the dimToLvl map, returns if it's block sparsity.
141bool isBlockSparsity(AffineMap dimToLvl);
142
143//
144// Reordering.
145//
146
147/// Convenience method to translate the given level to the corresponding
148/// dimension.
149/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
150Dimension toDim(SparseTensorEncodingAttr enc, Level l);
151
152/// Convenience method to translate the given dimension to the corresponding
153/// level.
154/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
155Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
156
157} // namespace sparse_tensor
158} // namespace mlir
159
160#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
161

// Source file: mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h