| 1 | //===- DimLvlMap.h ----------------------------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H |
| 10 | #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H |
| 11 | |
| 12 | #include "Var.h" |
| 13 | |
| 14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 15 | #include "llvm/ADT/STLForwardCompat.h" |
| 16 | |
| 17 | namespace mlir { |
| 18 | namespace sparse_tensor { |
| 19 | namespace ir_detail { |
| 20 | |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | enum class ExprKind : bool { Dimension = false, Level = true }; |
| 23 | |
| 24 | constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) { |
| 25 | using VK = std::underlying_type_t<VarKind>; |
| 26 | return VarKind{2 * static_cast<VK>(!llvm::to_underlying(E: ek))}; |
| 27 | } |
| 28 | static_assert(getVarKindAllowedInExpr(ek: ExprKind::Dimension) == VarKind::Level && |
| 29 | getVarKindAllowedInExpr(ek: ExprKind::Level) == VarKind::Dimension); |
| 30 | |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | class DimLvlExpr { |
| 33 | private: |
| 34 | ExprKind kind; |
| 35 | AffineExpr expr; |
| 36 | |
| 37 | public: |
| 38 | constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {} |
| 39 | |
| 40 | // |
| 41 | // Boolean operators. |
| 42 | // |
| 43 | constexpr bool operator==(DimLvlExpr other) const { |
| 44 | return kind == other.kind && expr == other.expr; |
| 45 | } |
| 46 | constexpr bool operator!=(DimLvlExpr other) const { |
| 47 | return !(*this == other); |
| 48 | } |
| 49 | explicit operator bool() const { return static_cast<bool>(expr); } |
| 50 | |
| 51 | // |
| 52 | // RTTI support (for the `DimLvlExpr` class itself). |
| 53 | // |
| 54 | template <typename U> |
| 55 | constexpr bool isa() const; |
| 56 | template <typename U> |
| 57 | constexpr U cast() const; |
| 58 | template <typename U> |
| 59 | constexpr U dyn_cast() const; |
| 60 | |
| 61 | // |
| 62 | // Simple getters. |
| 63 | // |
| 64 | constexpr ExprKind getExprKind() const { return kind; } |
| 65 | constexpr VarKind getAllowedVarKind() const { |
| 66 | return getVarKindAllowedInExpr(ek: kind); |
| 67 | } |
| 68 | constexpr AffineExpr getAffineExpr() const { return expr; } |
| 69 | AffineExprKind getAffineKind() const { |
| 70 | assert(expr); |
| 71 | return expr.getKind(); |
| 72 | } |
| 73 | MLIRContext *tryGetContext() const { |
| 74 | return expr ? expr.getContext() : nullptr; |
| 75 | } |
| 76 | |
| 77 | // |
| 78 | // Getters for handling `AffineExpr` subclasses. |
| 79 | // |
| 80 | SymVar castSymVar() const; |
| 81 | std::optional<SymVar> dyn_castSymVar() const; |
| 82 | Var castDimLvlVar() const; |
| 83 | std::optional<Var> dyn_castDimLvlVar() const; |
| 84 | std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const; |
| 85 | |
| 86 | /// Checks whether the variables bound/used by this spec are valid |
| 87 | /// with respect to the given ranks. |
| 88 | [[nodiscard]] bool isValid(Ranks const &ranks) const; |
| 89 | |
| 90 | protected: |
| 91 | // Variant of `mlir::AsmPrinter::Impl::BindingStrength` |
| 92 | enum class BindingStrength : bool { Weak = false, Strong = true }; |
| 93 | }; |
| 94 | static_assert(IsZeroCostAbstraction<DimLvlExpr>); |
| 95 | |
| 96 | class DimExpr final : public DimLvlExpr { |
| 97 | friend class DimLvlExpr; |
| 98 | constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {} |
| 99 | |
| 100 | public: |
| 101 | static constexpr ExprKind Kind = ExprKind::Dimension; |
| 102 | static constexpr bool classof(DimLvlExpr const *expr) { |
| 103 | return expr->getExprKind() == Kind; |
| 104 | } |
| 105 | constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {} |
| 106 | |
| 107 | LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); } |
| 108 | std::optional<LvlVar> dyn_castLvlVar() const { |
| 109 | const auto var = dyn_castDimLvlVar(); |
| 110 | return var ? std::make_optional(t: var->cast<LvlVar>()) : std::nullopt; |
| 111 | } |
| 112 | }; |
| 113 | static_assert(IsZeroCostAbstraction<DimExpr>); |
| 114 | |
| 115 | class LvlExpr final : public DimLvlExpr { |
| 116 | friend class DimLvlExpr; |
| 117 | constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {} |
| 118 | |
| 119 | public: |
| 120 | static constexpr ExprKind Kind = ExprKind::Level; |
| 121 | static constexpr bool classof(DimLvlExpr const *expr) { |
| 122 | return expr->getExprKind() == Kind; |
| 123 | } |
| 124 | constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {} |
| 125 | |
| 126 | DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); } |
| 127 | std::optional<DimVar> dyn_castDimVar() const { |
| 128 | const auto var = dyn_castDimLvlVar(); |
| 129 | return var ? std::make_optional(t: var->cast<DimVar>()) : std::nullopt; |
| 130 | } |
| 131 | }; |
| 132 | static_assert(IsZeroCostAbstraction<LvlExpr>); |
| 133 | |
| 134 | template <typename U> |
| 135 | constexpr bool DimLvlExpr::isa() const { |
| 136 | if constexpr (std::is_same_v<U, DimExpr>) |
| 137 | return getExprKind() == ExprKind::Dimension; |
| 138 | if constexpr (std::is_same_v<U, LvlExpr>) |
| 139 | return getExprKind() == ExprKind::Level; |
| 140 | } |
| 141 | |
| 142 | template <typename U> |
| 143 | constexpr U DimLvlExpr::cast() const { |
| 144 | assert(isa<U>()); |
| 145 | return U(*this); |
| 146 | } |
| 147 | |
| 148 | template <typename U> |
| 149 | constexpr U DimLvlExpr::dyn_cast() const { |
| 150 | return isa<U>() ? U(*this) : U(); |
| 151 | } |
| 152 | |
| 153 | //===----------------------------------------------------------------------===// |
| 154 | /// The full `dimVar = dimExpr : dimSlice` specification for a given dimension. |
| 155 | class DimSpec final { |
| 156 | /// The dimension-variable bound by this specification. |
| 157 | DimVar var; |
| 158 | /// The dimension-expression. The `DimSpec` ctor treats this field |
| 159 | /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify) |
| 160 | /// the expression via function-inversion inference. |
| 161 | DimExpr expr; |
| 162 | /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes |
| 163 | /// not (though if `expr` is null it will elide printing that); whereas |
| 164 | /// the `DimLvlMap` ctor will reset it as appropriate. |
| 165 | bool elideExpr = false; |
| 166 | /// The dimension-slice; optional, default is null. |
| 167 | SparseTensorDimSliceAttr slice; |
| 168 | |
| 169 | public: |
| 170 | DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice); |
| 171 | |
| 172 | MLIRContext *tryGetContext() const { return expr.tryGetContext(); } |
| 173 | |
| 174 | constexpr DimVar getBoundVar() const { return var; } |
| 175 | bool hasExpr() const { return static_cast<bool>(expr); } |
| 176 | constexpr DimExpr getExpr() const { return expr; } |
| 177 | void setExpr(DimExpr newExpr) { |
| 178 | assert(!hasExpr()); |
| 179 | expr = newExpr; |
| 180 | } |
| 181 | constexpr bool canElideExpr() const { return elideExpr; } |
| 182 | void setElideExpr(bool b) { elideExpr = b; } |
| 183 | constexpr SparseTensorDimSliceAttr getSlice() const { return slice; } |
| 184 | |
| 185 | /// Checks whether the variables bound/used by this spec are valid with |
| 186 | /// respect to the given ranks. Note that null `DimExpr` is considered |
| 187 | /// to be vacuously valid, and therefore calling `setExpr` invalidates |
| 188 | /// the result of this predicate. |
| 189 | [[nodiscard]] bool isValid(Ranks const &ranks) const; |
| 190 | }; |
| 191 | |
| 192 | static_assert(IsZeroCostAbstraction<DimSpec>); |
| 193 | |
| 194 | //===----------------------------------------------------------------------===// |
| 195 | /// The full `lvlVar = lvlExpr : lvlType` specification for a given level. |
| 196 | class LvlSpec final { |
| 197 | /// The level-variable bound by this specification. |
| 198 | LvlVar var; |
| 199 | /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not; |
| 200 | /// whereas the `DimLvlMap` ctor will reset this as appropriate. |
| 201 | bool elideVar = false; |
| 202 | /// The level-expression. |
| 203 | LvlExpr expr; |
| 204 | /// The level-type (== level-format + lvl-properties). |
| 205 | LevelType type; |
| 206 | |
| 207 | public: |
| 208 | LvlSpec(LvlVar var, LvlExpr expr, LevelType type); |
| 209 | |
| 210 | MLIRContext *getContext() const { |
| 211 | MLIRContext *ctx = expr.tryGetContext(); |
| 212 | assert(ctx); |
| 213 | return ctx; |
| 214 | } |
| 215 | |
| 216 | constexpr LvlVar getBoundVar() const { return var; } |
| 217 | constexpr bool canElideVar() const { return elideVar; } |
| 218 | void setElideVar(bool b) { elideVar = b; } |
| 219 | constexpr LvlExpr getExpr() const { return expr; } |
| 220 | constexpr LevelType getType() const { return type; } |
| 221 | |
| 222 | /// Checks whether the variables bound/used by this spec are valid |
| 223 | /// with respect to the given ranks. |
| 224 | [[nodiscard]] bool isValid(Ranks const &ranks) const; |
| 225 | }; |
| 226 | |
| 227 | static_assert(IsZeroCostAbstraction<LvlSpec>); |
| 228 | |
| 229 | //===----------------------------------------------------------------------===// |
| 230 | class DimLvlMap final { |
| 231 | public: |
| 232 | DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs, |
| 233 | ArrayRef<LvlSpec> lvlSpecs); |
| 234 | |
| 235 | unsigned getSymRank() const { return symRank; } |
| 236 | unsigned getDimRank() const { return dimSpecs.size(); } |
| 237 | unsigned getLvlRank() const { return lvlSpecs.size(); } |
| 238 | unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); } |
| 239 | Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; } |
| 240 | |
| 241 | ArrayRef<DimSpec> getDims() const { return dimSpecs; } |
| 242 | const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; } |
| 243 | SparseTensorDimSliceAttr getDimSlice(Dimension dim) const { |
| 244 | return getDim(dim).getSlice(); |
| 245 | } |
| 246 | |
| 247 | ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; } |
| 248 | const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; } |
| 249 | LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); } |
| 250 | |
| 251 | AffineMap getDimToLvlMap(MLIRContext *context) const; |
| 252 | AffineMap getLvlToDimMap(MLIRContext *context) const; |
| 253 | |
| 254 | private: |
| 255 | /// Checks for integrity of variable-binding structure. |
| 256 | /// This is already called by the ctor. |
| 257 | [[nodiscard]] bool isWF() const; |
| 258 | |
| 259 | /// Helper function to call `DimSpec::setExpr` while asserting that |
| 260 | /// the invariant established by `DimLvlMap:isWF` is maintained. |
| 261 | /// This is used by the ctor. |
| 262 | void setDimExpr(Dimension dim, DimExpr expr) { |
| 263 | assert(expr && getRanks().isValid(expr)); |
| 264 | dimSpecs[dim].setExpr(expr); |
| 265 | } |
| 266 | |
| 267 | // All these fields are const-after-ctor. |
| 268 | unsigned symRank; |
| 269 | SmallVector<DimSpec> dimSpecs; |
| 270 | SmallVector<LvlSpec> lvlSpecs; |
| 271 | bool mustPrintLvlVars; |
| 272 | }; |
| 273 | |
| 274 | //===----------------------------------------------------------------------===// |
| 275 | |
| 276 | } // namespace ir_detail |
| 277 | } // namespace sparse_tensor |
| 278 | } // namespace mlir |
| 279 | |
| 280 | #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H |
| 281 | |