1 | //===- DimLvlMap.h ----------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H |
10 | #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H |
11 | |
12 | #include "Var.h" |
13 | |
14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
15 | #include "llvm/ADT/STLForwardCompat.h" |
16 | |
17 | namespace mlir { |
18 | namespace sparse_tensor { |
19 | namespace ir_detail { |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | enum class ExprKind : bool { Dimension = false, Level = true }; |
23 | |
24 | constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) { |
25 | using VK = std::underlying_type_t<VarKind>; |
26 | return VarKind{2 * static_cast<VK>(!llvm::to_underlying(E: ek))}; |
27 | } |
28 | static_assert(getVarKindAllowedInExpr(ek: ExprKind::Dimension) == VarKind::Level && |
29 | getVarKindAllowedInExpr(ek: ExprKind::Level) == VarKind::Dimension); |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | class DimLvlExpr { |
33 | private: |
34 | ExprKind kind; |
35 | AffineExpr expr; |
36 | |
37 | public: |
38 | constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {} |
39 | |
40 | // |
41 | // Boolean operators. |
42 | // |
43 | constexpr bool operator==(DimLvlExpr other) const { |
44 | return kind == other.kind && expr == other.expr; |
45 | } |
46 | constexpr bool operator!=(DimLvlExpr other) const { |
47 | return !(*this == other); |
48 | } |
49 | explicit operator bool() const { return static_cast<bool>(expr); } |
50 | |
51 | // |
52 | // RTTI support (for the `DimLvlExpr` class itself). |
53 | // |
54 | template <typename U> |
55 | constexpr bool isa() const; |
56 | template <typename U> |
57 | constexpr U cast() const; |
58 | template <typename U> |
59 | constexpr U dyn_cast() const; |
60 | |
61 | // |
62 | // Simple getters. |
63 | // |
64 | constexpr ExprKind getExprKind() const { return kind; } |
65 | constexpr VarKind getAllowedVarKind() const { |
66 | return getVarKindAllowedInExpr(ek: kind); |
67 | } |
68 | constexpr AffineExpr getAffineExpr() const { return expr; } |
69 | AffineExprKind getAffineKind() const { |
70 | assert(expr); |
71 | return expr.getKind(); |
72 | } |
73 | MLIRContext *tryGetContext() const { |
74 | return expr ? expr.getContext() : nullptr; |
75 | } |
76 | |
77 | // |
78 | // Getters for handling `AffineExpr` subclasses. |
79 | // |
80 | SymVar castSymVar() const; |
81 | std::optional<SymVar> dyn_castSymVar() const; |
82 | Var castDimLvlVar() const; |
83 | std::optional<Var> dyn_castDimLvlVar() const; |
84 | std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const; |
85 | |
86 | /// Checks whether the variables bound/used by this spec are valid |
87 | /// with respect to the given ranks. |
88 | [[nodiscard]] bool isValid(Ranks const &ranks) const; |
89 | |
90 | protected: |
91 | // Variant of `mlir::AsmPrinter::Impl::BindingStrength` |
92 | enum class BindingStrength : bool { Weak = false, Strong = true }; |
93 | }; |
94 | static_assert(IsZeroCostAbstraction<DimLvlExpr>); |
95 | |
96 | class DimExpr final : public DimLvlExpr { |
97 | friend class DimLvlExpr; |
98 | constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {} |
99 | |
100 | public: |
101 | static constexpr ExprKind Kind = ExprKind::Dimension; |
102 | static constexpr bool classof(DimLvlExpr const *expr) { |
103 | return expr->getExprKind() == Kind; |
104 | } |
105 | constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {} |
106 | |
107 | LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); } |
108 | std::optional<LvlVar> dyn_castLvlVar() const { |
109 | const auto var = dyn_castDimLvlVar(); |
110 | return var ? std::make_optional(t: var->cast<LvlVar>()) : std::nullopt; |
111 | } |
112 | }; |
113 | static_assert(IsZeroCostAbstraction<DimExpr>); |
114 | |
115 | class LvlExpr final : public DimLvlExpr { |
116 | friend class DimLvlExpr; |
117 | constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {} |
118 | |
119 | public: |
120 | static constexpr ExprKind Kind = ExprKind::Level; |
121 | static constexpr bool classof(DimLvlExpr const *expr) { |
122 | return expr->getExprKind() == Kind; |
123 | } |
124 | constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {} |
125 | |
126 | DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); } |
127 | std::optional<DimVar> dyn_castDimVar() const { |
128 | const auto var = dyn_castDimLvlVar(); |
129 | return var ? std::make_optional(t: var->cast<DimVar>()) : std::nullopt; |
130 | } |
131 | }; |
132 | static_assert(IsZeroCostAbstraction<LvlExpr>); |
133 | |
134 | template <typename U> |
135 | constexpr bool DimLvlExpr::isa() const { |
136 | if constexpr (std::is_same_v<U, DimExpr>) |
137 | return getExprKind() == ExprKind::Dimension; |
138 | if constexpr (std::is_same_v<U, LvlExpr>) |
139 | return getExprKind() == ExprKind::Level; |
140 | } |
141 | |
142 | template <typename U> |
143 | constexpr U DimLvlExpr::cast() const { |
144 | assert(isa<U>()); |
145 | return U(*this); |
146 | } |
147 | |
148 | template <typename U> |
149 | constexpr U DimLvlExpr::dyn_cast() const { |
150 | return isa<U>() ? U(*this) : U(); |
151 | } |
152 | |
153 | //===----------------------------------------------------------------------===// |
154 | /// The full `dimVar = dimExpr : dimSlice` specification for a given dimension. |
155 | class DimSpec final { |
156 | /// The dimension-variable bound by this specification. |
157 | DimVar var; |
158 | /// The dimension-expression. The `DimSpec` ctor treats this field |
159 | /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify) |
160 | /// the expression via function-inversion inference. |
161 | DimExpr expr; |
162 | /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes |
163 | /// not (though if `expr` is null it will elide printing that); whereas |
164 | /// the `DimLvlMap` ctor will reset it as appropriate. |
165 | bool elideExpr = false; |
166 | /// The dimension-slice; optional, default is null. |
167 | SparseTensorDimSliceAttr slice; |
168 | |
169 | public: |
170 | DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice); |
171 | |
172 | MLIRContext *tryGetContext() const { return expr.tryGetContext(); } |
173 | |
174 | constexpr DimVar getBoundVar() const { return var; } |
175 | bool hasExpr() const { return static_cast<bool>(expr); } |
176 | constexpr DimExpr getExpr() const { return expr; } |
177 | void setExpr(DimExpr newExpr) { |
178 | assert(!hasExpr()); |
179 | expr = newExpr; |
180 | } |
181 | constexpr bool canElideExpr() const { return elideExpr; } |
182 | void setElideExpr(bool b) { elideExpr = b; } |
183 | constexpr SparseTensorDimSliceAttr getSlice() const { return slice; } |
184 | |
185 | /// Checks whether the variables bound/used by this spec are valid with |
186 | /// respect to the given ranks. Note that null `DimExpr` is considered |
187 | /// to be vacuously valid, and therefore calling `setExpr` invalidates |
188 | /// the result of this predicate. |
189 | [[nodiscard]] bool isValid(Ranks const &ranks) const; |
190 | }; |
191 | |
192 | static_assert(IsZeroCostAbstraction<DimSpec>); |
193 | |
194 | //===----------------------------------------------------------------------===// |
195 | /// The full `lvlVar = lvlExpr : lvlType` specification for a given level. |
196 | class LvlSpec final { |
197 | /// The level-variable bound by this specification. |
198 | LvlVar var; |
199 | /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not; |
200 | /// whereas the `DimLvlMap` ctor will reset this as appropriate. |
201 | bool elideVar = false; |
202 | /// The level-expression. |
203 | LvlExpr expr; |
204 | /// The level-type (== level-format + lvl-properties). |
205 | LevelType type; |
206 | |
207 | public: |
208 | LvlSpec(LvlVar var, LvlExpr expr, LevelType type); |
209 | |
210 | MLIRContext *getContext() const { |
211 | MLIRContext *ctx = expr.tryGetContext(); |
212 | assert(ctx); |
213 | return ctx; |
214 | } |
215 | |
216 | constexpr LvlVar getBoundVar() const { return var; } |
217 | constexpr bool canElideVar() const { return elideVar; } |
218 | void setElideVar(bool b) { elideVar = b; } |
219 | constexpr LvlExpr getExpr() const { return expr; } |
220 | constexpr LevelType getType() const { return type; } |
221 | |
222 | /// Checks whether the variables bound/used by this spec are valid |
223 | /// with respect to the given ranks. |
224 | [[nodiscard]] bool isValid(Ranks const &ranks) const; |
225 | }; |
226 | |
227 | static_assert(IsZeroCostAbstraction<LvlSpec>); |
228 | |
229 | //===----------------------------------------------------------------------===// |
230 | class DimLvlMap final { |
231 | public: |
232 | DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs, |
233 | ArrayRef<LvlSpec> lvlSpecs); |
234 | |
235 | unsigned getSymRank() const { return symRank; } |
236 | unsigned getDimRank() const { return dimSpecs.size(); } |
237 | unsigned getLvlRank() const { return lvlSpecs.size(); } |
238 | unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); } |
239 | Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; } |
240 | |
241 | ArrayRef<DimSpec> getDims() const { return dimSpecs; } |
242 | const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; } |
243 | SparseTensorDimSliceAttr getDimSlice(Dimension dim) const { |
244 | return getDim(dim).getSlice(); |
245 | } |
246 | |
247 | ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; } |
248 | const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; } |
249 | LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); } |
250 | |
251 | AffineMap getDimToLvlMap(MLIRContext *context) const; |
252 | AffineMap getLvlToDimMap(MLIRContext *context) const; |
253 | |
254 | private: |
255 | /// Checks for integrity of variable-binding structure. |
256 | /// This is already called by the ctor. |
257 | [[nodiscard]] bool isWF() const; |
258 | |
259 | /// Helper function to call `DimSpec::setExpr` while asserting that |
260 | /// the invariant established by `DimLvlMap:isWF` is maintained. |
261 | /// This is used by the ctor. |
262 | void setDimExpr(Dimension dim, DimExpr expr) { |
263 | assert(expr && getRanks().isValid(expr)); |
264 | dimSpecs[dim].setExpr(expr); |
265 | } |
266 | |
267 | // All these fields are const-after-ctor. |
268 | unsigned symRank; |
269 | SmallVector<DimSpec> dimSpecs; |
270 | SmallVector<LvlSpec> lvlSpecs; |
271 | bool mustPrintLvlVars; |
272 | }; |
273 | |
274 | //===----------------------------------------------------------------------===// |
275 | |
276 | } // namespace ir_detail |
277 | } // namespace sparse_tensor |
278 | } // namespace mlir |
279 | |
280 | #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H |
281 | |