| 1 | //===- Var.h ----------------------------------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H |
| 10 | #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H |
| 11 | |
| 12 | #include "TemplateExtras.h" |
| 13 | |
| 14 | #include "mlir/IR/OpImplementation.h" |
| 15 | #include "llvm/ADT/EnumeratedArray.h" |
| 16 | #include "llvm/ADT/STLForwardCompat.h" |
| 17 | #include "llvm/ADT/SmallBitVector.h" |
| 18 | #include "llvm/ADT/StringMap.h" |
| 19 | |
| 20 | namespace mlir { |
| 21 | namespace sparse_tensor { |
| 22 | namespace ir_detail { |
| 23 | |
//===----------------------------------------------------------------------===//
/// The three kinds of variables that `Var` can be.
///
/// NOTE: The numerical values used to represent this enum should be
/// treated as an implementation detail, not as part of the API.  The API
/// below uses the canonical ordering `{Symbol,Dimension,Level}`, even though
/// that ordering does not agree with the ordering of the underlying
/// numerical values.
enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };
| 33 | |
| 34 | [[nodiscard]] constexpr bool isWF(VarKind vk) { |
| 35 | const auto vk_ = llvm::to_underlying(E: vk); |
| 36 | return 0 <= vk_ && vk_ <= 2; |
| 37 | } |
| 38 | |
| 39 | /// Gets the ASCII character used as the prefix when printing `Var`. |
| 40 | constexpr char toChar(VarKind vk) { |
| 41 | // If `isWF(vk)` then this computation's intermediate results are always |
| 42 | // in the range [-44..126] (where that lower bound is under worst-case |
| 43 | // rearranging of the expression); and `int_fast8_t` is the fastest type |
| 44 | // which can support that range without over-/underflow. |
| 45 | const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(E: vk)); |
| 46 | return static_cast<char>(100 + vk_ * (26 - vk_ * 11)); |
| 47 | } |
| 48 | static_assert(toChar(vk: VarKind::Symbol) == 's' && |
| 49 | toChar(vk: VarKind::Dimension) == 'd' && |
| 50 | toChar(vk: VarKind::Level) == 'l'); |
| 51 | |
| 52 | //===----------------------------------------------------------------------===// |
| 53 | /// The type of arrays indexed by `VarKind`. |
| 54 | template <typename T> |
| 55 | using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>; |
| 56 | |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | /// A concrete variable, to be used in our variant of `AffineExpr`. |
| 59 | /// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI |
| 60 | /// support for subclasses with a fixed `VarKind`. |
| 61 | class Var { |
| 62 | public: |
| 63 | /// Typedef for the type of variable numbers. |
| 64 | using Num = unsigned; |
| 65 | |
| 66 | private: |
| 67 | /// Typedef for the underlying storage of `Var::Impl`. |
| 68 | using Storage = unsigned; |
| 69 | |
| 70 | /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`. |
| 71 | /// Two low-order bits are reserved for storing the `VarKind`, |
| 72 | /// and one high-order bit is reserved for future use (e.g., to support |
| 73 | /// `DenseMapInfo<Var>` while maintaining the usual numeric values for |
| 74 | /// "empty" and "tombstone"). |
| 75 | static constexpr Num kMaxNum = |
| 76 | static_cast<Num>(std::numeric_limits<Storage>::max() >> 3); |
| 77 | |
| 78 | public: |
| 79 | /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`. |
| 80 | // |
| 81 | // This must be public for `VarInfo` to use it (whereas we don't want |
| 82 | // to expose the `impl` field via friendship). |
| 83 | [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; } |
| 84 | |
| 85 | protected: |
| 86 | /// The underlying implementation of `Var`. Note that this must be kept |
| 87 | /// distinct from `Var` itself, since we want to ensure that the RTTI |
| 88 | /// methods will select the `U(Var::Impl)` ctor rather than selecting |
| 89 | /// the `U(Var::Num)` ctor. |
| 90 | class Impl final { |
| 91 | Storage data; |
| 92 | |
| 93 | public: |
| 94 | constexpr Impl(VarKind vk, Num n) |
| 95 | : data((static_cast<Storage>(n) << 2) | |
| 96 | static_cast<Storage>(llvm::to_underlying(E: vk))) { |
| 97 | assert(isWF(vk) && "unknown VarKind" ); |
| 98 | assert(isWF_Num(n) && "Var::Num is too large" ); |
| 99 | } |
| 100 | constexpr bool operator==(Impl other) const { return data == other.data; } |
| 101 | constexpr bool operator!=(Impl other) const { return !(*this == other); } |
| 102 | constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); } |
| 103 | constexpr Num getNum() const { return static_cast<Num>(data >> 2); } |
| 104 | }; |
| 105 | static_assert(IsZeroCostAbstraction<Impl>); |
| 106 | |
| 107 | private: |
| 108 | Impl impl; |
| 109 | |
| 110 | protected: |
| 111 | /// Protected ctor for the RTTI methods to use. |
| 112 | constexpr explicit Var(Impl impl) : impl(impl) {} |
| 113 | |
| 114 | public: |
| 115 | constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {} |
| 116 | Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {} |
| 117 | Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) { |
| 118 | assert(vk != VarKind::Symbol); |
| 119 | } |
| 120 | |
| 121 | constexpr bool operator==(Var other) const { return impl == other.impl; } |
| 122 | constexpr bool operator!=(Var other) const { return !(*this == other); } |
| 123 | |
| 124 | constexpr VarKind getKind() const { return impl.getKind(); } |
| 125 | constexpr Num getNum() const { return impl.getNum(); } |
| 126 | |
| 127 | template <typename U> |
| 128 | constexpr bool isa() const; |
| 129 | template <typename U> |
| 130 | constexpr U cast() const; |
| 131 | template <typename U> |
| 132 | constexpr std::optional<U> dyn_cast() const; |
| 133 | |
| 134 | std::string str() const; |
| 135 | void print(llvm::raw_ostream &os) const; |
| 136 | void print(AsmPrinter &printer) const; |
| 137 | void dump() const; |
| 138 | }; |
| 139 | static_assert(IsZeroCostAbstraction<Var>); |
| 140 | |
| 141 | class SymVar final : public Var { |
| 142 | using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. |
| 143 | public: |
| 144 | static constexpr VarKind Kind = VarKind::Symbol; |
| 145 | static constexpr bool classof(Var const *var) { |
| 146 | return var->getKind() == Kind; |
| 147 | } |
| 148 | constexpr SymVar(Num sym) : Var(Kind, sym) {} |
| 149 | SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {} |
| 150 | }; |
| 151 | static_assert(IsZeroCostAbstraction<SymVar>); |
| 152 | |
| 153 | class DimVar final : public Var { |
| 154 | using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. |
| 155 | public: |
| 156 | static constexpr VarKind Kind = VarKind::Dimension; |
| 157 | static constexpr bool classof(Var const *var) { |
| 158 | return var->getKind() == Kind; |
| 159 | } |
| 160 | constexpr DimVar(Num dim) : Var(Kind, dim) {} |
| 161 | DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {} |
| 162 | }; |
| 163 | static_assert(IsZeroCostAbstraction<DimVar>); |
| 164 | |
| 165 | class LvlVar final : public Var { |
| 166 | using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. |
| 167 | public: |
| 168 | static constexpr VarKind Kind = VarKind::Level; |
| 169 | static constexpr bool classof(Var const *var) { |
| 170 | return var->getKind() == Kind; |
| 171 | } |
| 172 | constexpr LvlVar(Num lvl) : Var(Kind, lvl) {} |
| 173 | LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {} |
| 174 | }; |
| 175 | static_assert(IsZeroCostAbstraction<LvlVar>); |
| 176 | |
| 177 | template <typename U> |
| 178 | constexpr bool Var::isa() const { |
| 179 | if constexpr (std::is_same_v<U, SymVar>) |
| 180 | return getKind() == VarKind::Symbol; |
| 181 | if constexpr (std::is_same_v<U, DimVar>) |
| 182 | return getKind() == VarKind::Dimension; |
| 183 | if constexpr (std::is_same_v<U, LvlVar>) |
| 184 | return getKind() == VarKind::Level; |
| 185 | } |
| 186 | |
| 187 | template <typename U> |
| 188 | constexpr U Var::cast() const { |
| 189 | assert(isa<U>()); |
| 190 | // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)` |
| 191 | return U(impl); |
| 192 | } |
| 193 | |
| 194 | template <typename U> |
| 195 | constexpr std::optional<U> Var::dyn_cast() const { |
| 196 | // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)` |
| 197 | return isa<U>() ? std::make_optional(U(impl)) : std::nullopt; |
| 198 | } |
| 199 | |
//===----------------------------------------------------------------------===//
// Forward-decl so that we can declare methods of `Ranks` and `VarSet`
// which take a `DimLvlExpr` parameter.
class DimLvlExpr;
| 203 | |
| 204 | //===----------------------------------------------------------------------===// |
| 205 | class Ranks final { |
| 206 | // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr. |
| 207 | unsigned impl[3]; |
| 208 | |
| 209 | static constexpr unsigned to_index(VarKind vk) { |
| 210 | assert(isWF(vk) && "unknown VarKind" ); |
| 211 | return static_cast<unsigned>(llvm::to_underlying(E: vk)); |
| 212 | } |
| 213 | |
| 214 | public: |
| 215 | constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank) |
| 216 | : impl() { |
| 217 | impl[to_index(vk: VarKind::Symbol)] = symRank; |
| 218 | impl[to_index(vk: VarKind::Dimension)] = dimRank; |
| 219 | impl[to_index(vk: VarKind::Level)] = lvlRank; |
| 220 | } |
| 221 | Ranks(VarKindArray<unsigned> const &ranks) |
| 222 | : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension], |
| 223 | ranks[VarKind::Level]) {} |
| 224 | |
| 225 | bool operator==(Ranks const &other) const; |
| 226 | bool operator!=(Ranks const &other) const { return !(*this == other); } |
| 227 | |
| 228 | constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; } |
| 229 | constexpr unsigned getSymRank() const { return getRank(vk: VarKind::Symbol); } |
| 230 | constexpr unsigned getDimRank() const { return getRank(vk: VarKind::Dimension); } |
| 231 | constexpr unsigned getLvlRank() const { return getRank(vk: VarKind::Level); } |
| 232 | |
| 233 | [[nodiscard]] constexpr bool isValid(Var var) const { |
| 234 | return var.getNum() < getRank(vk: var.getKind()); |
| 235 | } |
| 236 | [[nodiscard]] bool isValid(DimLvlExpr expr) const; |
| 237 | }; |
| 238 | static_assert(IsZeroCostAbstraction<Ranks>); |
| 239 | |
| 240 | //===----------------------------------------------------------------------===// |
| 241 | /// Efficient representation of a set of `Var`. |
| 242 | class VarSet final { |
| 243 | VarKindArray<llvm::SmallBitVector> impl; |
| 244 | |
| 245 | public: |
| 246 | explicit VarSet(Ranks const &ranks); |
| 247 | |
| 248 | unsigned getRank(VarKind vk) const { return impl[vk].size(); } |
| 249 | unsigned getSymRank() const { return getRank(vk: VarKind::Symbol); } |
| 250 | unsigned getDimRank() const { return getRank(vk: VarKind::Dimension); } |
| 251 | unsigned getLvlRank() const { return getRank(vk: VarKind::Level); } |
| 252 | Ranks getRanks() const { |
| 253 | return Ranks(getSymRank(), getDimRank(), getLvlRank()); |
| 254 | } |
| 255 | /// For the `contains` method: if variables occurring in |
| 256 | /// the method parameter are OOB for the `VarSet`, then these methods will |
| 257 | /// always return false. |
| 258 | bool contains(Var var) const; |
| 259 | |
| 260 | /// For the `add` methods: OOB parameters cause undefined behavior. |
| 261 | /// Currently the `add` methods will raise an assertion error. |
| 262 | void add(Var var); |
| 263 | void add(VarSet const &vars); |
| 264 | void add(DimLvlExpr expr); |
| 265 | }; |
| 266 | |
| 267 | //===----------------------------------------------------------------------===// |
| 268 | /// A record of metadata for/about a variable, used by `VarEnv`. |
| 269 | /// The principal goal of this record is to enable `VarEnv` to be used for |
| 270 | /// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to |
| 271 | /// remain unknown, since each record is instead identified by `VarInfo::ID`. |
| 272 | /// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever |
| 273 | /// order it likes, irrespective of the binding order (`Var::Num`) of the |
| 274 | /// associated variable. |
| 275 | class VarInfo final { |
| 276 | public: |
| 277 | /// Newtype for unique identifiers of `VarInfo` records, to ensure |
| 278 | /// they aren't confused with `Var::Num`. |
| 279 | enum class ID : unsigned {}; |
| 280 | |
| 281 | private: |
| 282 | StringRef name; // The bare-id used in the MLIR source. |
| 283 | llvm::SMLoc loc; // The location of the first occurence. |
| 284 | ID id; // The unique `VarInfo`-identifier. |
| 285 | std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved). |
| 286 | VarKind kind; // The kind of variable. |
| 287 | |
| 288 | public: |
| 289 | constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk, |
| 290 | std::optional<Var::Num> n = {}) |
| 291 | : name(name), loc(loc), id(id), num(n), kind(vk) { |
| 292 | assert(!name.empty() && "null StringRef" ); |
| 293 | assert(loc.isValid() && "null SMLoc" ); |
| 294 | assert(isWF(vk) && "unknown VarKind" ); |
| 295 | assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large" ); |
| 296 | } |
| 297 | |
| 298 | constexpr StringRef getName() const { return name; } |
| 299 | constexpr llvm::SMLoc getLoc() const { return loc; } |
| 300 | Location getLocation(AsmParser &parser) const { |
| 301 | return parser.getEncodedSourceLoc(loc); |
| 302 | } |
| 303 | constexpr ID getID() const { return id; } |
| 304 | constexpr VarKind getKind() const { return kind; } |
| 305 | constexpr std::optional<Var::Num> getNum() const { return num; } |
| 306 | constexpr bool hasNum() const { return num.has_value(); } |
| 307 | void setNum(Var::Num n); |
| 308 | constexpr Var getVar() const { |
| 309 | assert(hasNum()); |
| 310 | return Var(kind, *num); |
| 311 | } |
| 312 | }; |
| 313 | |
//===----------------------------------------------------------------------===//
/// The creation policy accepted by `VarEnv::lookupOrCreate`.
enum class Policy { MustNot, May, Must };
| 316 | |
| 317 | //===----------------------------------------------------------------------===// |
| 318 | class VarEnv final { |
| 319 | /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`. |
| 320 | VarKindArray<Var::Num> nextNum; |
| 321 | /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects. |
| 322 | SmallVector<VarInfo> vars; |
| 323 | /// Map from variable names to their `VarInfo::ID`. |
| 324 | llvm::StringMap<VarInfo::ID> ids; |
| 325 | |
| 326 | VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); } |
| 327 | |
| 328 | public: |
| 329 | VarEnv() : nextNum(0) {} |
| 330 | |
| 331 | /// Gets the underlying storage for the `VarInfo` identified by |
| 332 | /// the `VarInfo::ID`. |
| 333 | /// |
| 334 | /// NOTE: The returned reference can become dangling if the `VarEnv` |
| 335 | /// object is mutated during the lifetime of the pointer. Therefore, |
| 336 | /// client code should not store the reference nor otherwise allow it |
| 337 | /// to live too long. |
| 338 | VarInfo const &access(VarInfo::ID id) const { |
| 339 | // `SmallVector::operator[]` already asserts the index is in-bounds. |
| 340 | return vars[llvm::to_underlying(E: id)]; |
| 341 | } |
| 342 | VarInfo const *access(std::optional<VarInfo::ID> oid) const { |
| 343 | return oid ? &access(id: *oid) : nullptr; |
| 344 | } |
| 345 | |
| 346 | private: |
| 347 | VarInfo &access(VarInfo::ID id) { |
| 348 | return const_cast<VarInfo &>(std::as_const(t&: *this).access(id)); |
| 349 | } |
| 350 | VarInfo *access(std::optional<VarInfo::ID> oid) { |
| 351 | return const_cast<VarInfo *>(std::as_const(t&: *this).access(oid)); |
| 352 | } |
| 353 | |
| 354 | public: |
| 355 | /// Looks up the variable with the given name. |
| 356 | std::optional<VarInfo::ID> lookup(StringRef name) const; |
| 357 | |
| 358 | /// Creates a new currently-unbound variable. When a variable |
| 359 | /// of that name already exists: if `verifyUsage` is true, then will assert |
| 360 | /// that the variable has the same kind and a consistent location; otherwise, |
| 361 | /// when `verifyUsage` is false, this is a noop. Returns the identifier |
| 362 | /// for the variable with the given name, and a bool indicating whether |
| 363 | /// a new variable was created. |
| 364 | std::optional<std::pair<VarInfo::ID, bool>> |
| 365 | create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false); |
| 366 | |
| 367 | /// Looks up or creates a variable according to the given |
| 368 | /// `Policy`. Returns nullopt in one of two circumstances: |
| 369 | /// (1) the policy says we `Must` create, yet the variable already exists; |
| 370 | /// (2) the policy says we `MustNot` create, yet no such variable exists. |
| 371 | /// Otherwise, if the variable already exists then it is validated against |
| 372 | /// the given kind and location to ensure consistency. |
| 373 | std::optional<std::pair<VarInfo::ID, bool>> |
| 374 | lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, |
| 375 | VarKind vk); |
| 376 | |
| 377 | /// Binds the given variable to the next free `Var::Num` for its `VarKind`. |
| 378 | Var bindVar(VarInfo::ID id); |
| 379 | |
| 380 | /// Creates a new variable of the given kind and immediately binds it. |
| 381 | /// This should only be used whenever the variable is known to be unused |
| 382 | /// and therefore does not have a name. |
| 383 | Var bindUnusedVar(VarKind vk); |
| 384 | |
| 385 | InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const; |
| 386 | |
| 387 | /// Returns the current ranks of bound variables. This method should |
| 388 | /// only be used after the environment is "finished", since binding new |
| 389 | /// variables will (semantically) invalidate any previously returned `Ranks`. |
| 390 | Ranks getRanks() const { return Ranks(nextNum); } |
| 391 | |
| 392 | /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion |
| 393 | /// failure if the variable is not bound. |
| 394 | Var getVar(VarInfo::ID id) const { return access(id).getVar(); } |
| 395 | }; |
| 396 | |
| 397 | //===----------------------------------------------------------------------===// |
| 398 | |
| 399 | } // namespace ir_detail |
| 400 | } // namespace sparse_tensor |
| 401 | } // namespace mlir |
| 402 | |
| 403 | #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H |
| 404 | |