1//===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
10#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
11
12#include "Var.h"
13
14#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15#include "llvm/ADT/STLForwardCompat.h"
16
17namespace mlir {
18namespace sparse_tensor {
19namespace ir_detail {
20
21//===----------------------------------------------------------------------===//
22enum class ExprKind : bool { Dimension = false, Level = true };
23
24constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
25 using VK = std::underlying_type_t<VarKind>;
26 return VarKind{2 * static_cast<VK>(!llvm::to_underlying(E: ek))};
27}
28static_assert(getVarKindAllowedInExpr(ek: ExprKind::Dimension) == VarKind::Level &&
29 getVarKindAllowedInExpr(ek: ExprKind::Level) == VarKind::Dimension);
30
31//===----------------------------------------------------------------------===//
32class DimLvlExpr {
33private:
34 ExprKind kind;
35 AffineExpr expr;
36
37public:
38 constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
39
40 //
41 // Boolean operators.
42 //
43 constexpr bool operator==(DimLvlExpr other) const {
44 return kind == other.kind && expr == other.expr;
45 }
46 constexpr bool operator!=(DimLvlExpr other) const {
47 return !(*this == other);
48 }
49 explicit operator bool() const { return static_cast<bool>(expr); }
50
51 //
52 // RTTI support (for the `DimLvlExpr` class itself).
53 //
54 template <typename U>
55 constexpr bool isa() const;
56 template <typename U>
57 constexpr U cast() const;
58 template <typename U>
59 constexpr U dyn_cast() const;
60
61 //
62 // Simple getters.
63 //
64 constexpr ExprKind getExprKind() const { return kind; }
65 constexpr VarKind getAllowedVarKind() const {
66 return getVarKindAllowedInExpr(ek: kind);
67 }
68 constexpr AffineExpr getAffineExpr() const { return expr; }
69 AffineExprKind getAffineKind() const {
70 assert(expr);
71 return expr.getKind();
72 }
73 MLIRContext *tryGetContext() const {
74 return expr ? expr.getContext() : nullptr;
75 }
76
77 //
78 // Getters for handling `AffineExpr` subclasses.
79 //
80 SymVar castSymVar() const;
81 std::optional<SymVar> dyn_castSymVar() const;
82 Var castDimLvlVar() const;
83 std::optional<Var> dyn_castDimLvlVar() const;
84 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
85
86 /// Checks whether the variables bound/used by this spec are valid
87 /// with respect to the given ranks.
88 [[nodiscard]] bool isValid(Ranks const &ranks) const;
89
90protected:
91 // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
92 enum class BindingStrength : bool { Weak = false, Strong = true };
93};
94static_assert(IsZeroCostAbstraction<DimLvlExpr>);
95
96class DimExpr final : public DimLvlExpr {
97 friend class DimLvlExpr;
98 constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
99
100public:
101 static constexpr ExprKind Kind = ExprKind::Dimension;
102 static constexpr bool classof(DimLvlExpr const *expr) {
103 return expr->getExprKind() == Kind;
104 }
105 constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
106
107 LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
108 std::optional<LvlVar> dyn_castLvlVar() const {
109 const auto var = dyn_castDimLvlVar();
110 return var ? std::make_optional(t: var->cast<LvlVar>()) : std::nullopt;
111 }
112};
113static_assert(IsZeroCostAbstraction<DimExpr>);
114
115class LvlExpr final : public DimLvlExpr {
116 friend class DimLvlExpr;
117 constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
118
119public:
120 static constexpr ExprKind Kind = ExprKind::Level;
121 static constexpr bool classof(DimLvlExpr const *expr) {
122 return expr->getExprKind() == Kind;
123 }
124 constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
125
126 DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
127 std::optional<DimVar> dyn_castDimVar() const {
128 const auto var = dyn_castDimLvlVar();
129 return var ? std::make_optional(t: var->cast<DimVar>()) : std::nullopt;
130 }
131};
132static_assert(IsZeroCostAbstraction<LvlExpr>);
133
134template <typename U>
135constexpr bool DimLvlExpr::isa() const {
136 if constexpr (std::is_same_v<U, DimExpr>)
137 return getExprKind() == ExprKind::Dimension;
138 if constexpr (std::is_same_v<U, LvlExpr>)
139 return getExprKind() == ExprKind::Level;
140}
141
142template <typename U>
143constexpr U DimLvlExpr::cast() const {
144 assert(isa<U>());
145 return U(*this);
146}
147
148template <typename U>
149constexpr U DimLvlExpr::dyn_cast() const {
150 return isa<U>() ? U(*this) : U();
151}
152
153//===----------------------------------------------------------------------===//
154/// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
155class DimSpec final {
156 /// The dimension-variable bound by this specification.
157 DimVar var;
158 /// The dimension-expression. The `DimSpec` ctor treats this field
159 /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
160 /// the expression via function-inversion inference.
161 DimExpr expr;
162 /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
163 /// not (though if `expr` is null it will elide printing that); whereas
164 /// the `DimLvlMap` ctor will reset it as appropriate.
165 bool elideExpr = false;
166 /// The dimension-slice; optional, default is null.
167 SparseTensorDimSliceAttr slice;
168
169public:
170 DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);
171
172 MLIRContext *tryGetContext() const { return expr.tryGetContext(); }
173
174 constexpr DimVar getBoundVar() const { return var; }
175 bool hasExpr() const { return static_cast<bool>(expr); }
176 constexpr DimExpr getExpr() const { return expr; }
177 void setExpr(DimExpr newExpr) {
178 assert(!hasExpr());
179 expr = newExpr;
180 }
181 constexpr bool canElideExpr() const { return elideExpr; }
182 void setElideExpr(bool b) { elideExpr = b; }
183 constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
184
185 /// Checks whether the variables bound/used by this spec are valid with
186 /// respect to the given ranks. Note that null `DimExpr` is considered
187 /// to be vacuously valid, and therefore calling `setExpr` invalidates
188 /// the result of this predicate.
189 [[nodiscard]] bool isValid(Ranks const &ranks) const;
190};
191
192static_assert(IsZeroCostAbstraction<DimSpec>);
193
194//===----------------------------------------------------------------------===//
195/// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
196class LvlSpec final {
197 /// The level-variable bound by this specification.
198 LvlVar var;
199 /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not;
200 /// whereas the `DimLvlMap` ctor will reset this as appropriate.
201 bool elideVar = false;
202 /// The level-expression.
203 LvlExpr expr;
204 /// The level-type (== level-format + lvl-properties).
205 LevelType type;
206
207public:
208 LvlSpec(LvlVar var, LvlExpr expr, LevelType type);
209
210 MLIRContext *getContext() const {
211 MLIRContext *ctx = expr.tryGetContext();
212 assert(ctx);
213 return ctx;
214 }
215
216 constexpr LvlVar getBoundVar() const { return var; }
217 constexpr bool canElideVar() const { return elideVar; }
218 void setElideVar(bool b) { elideVar = b; }
219 constexpr LvlExpr getExpr() const { return expr; }
220 constexpr LevelType getType() const { return type; }
221
222 /// Checks whether the variables bound/used by this spec are valid
223 /// with respect to the given ranks.
224 [[nodiscard]] bool isValid(Ranks const &ranks) const;
225};
226
227static_assert(IsZeroCostAbstraction<LvlSpec>);
228
229//===----------------------------------------------------------------------===//
230class DimLvlMap final {
231public:
232 DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
233 ArrayRef<LvlSpec> lvlSpecs);
234
235 unsigned getSymRank() const { return symRank; }
236 unsigned getDimRank() const { return dimSpecs.size(); }
237 unsigned getLvlRank() const { return lvlSpecs.size(); }
238 unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
239 Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
240
241 ArrayRef<DimSpec> getDims() const { return dimSpecs; }
242 const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
243 SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
244 return getDim(dim).getSlice();
245 }
246
247 ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
248 const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
249 LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
250
251 AffineMap getDimToLvlMap(MLIRContext *context) const;
252 AffineMap getLvlToDimMap(MLIRContext *context) const;
253
254private:
255 /// Checks for integrity of variable-binding structure.
256 /// This is already called by the ctor.
257 [[nodiscard]] bool isWF() const;
258
259 /// Helper function to call `DimSpec::setExpr` while asserting that
260 /// the invariant established by `DimLvlMap:isWF` is maintained.
261 /// This is used by the ctor.
262 void setDimExpr(Dimension dim, DimExpr expr) {
263 assert(expr && getRanks().isValid(expr));
264 dimSpecs[dim].setExpr(expr);
265 }
266
267 // All these fields are const-after-ctor.
268 unsigned symRank;
269 SmallVector<DimSpec> dimSpecs;
270 SmallVector<LvlSpec> lvlSpecs;
271 bool mustPrintLvlVars;
272};
273
274//===----------------------------------------------------------------------===//
275
276} // namespace ir_detail
277} // namespace sparse_tensor
278} // namespace mlir
279
280#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
281

source code of mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h