1//===- Var.h ----------------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
10#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
11
12#include "TemplateExtras.h"
13
14#include "mlir/IR/OpImplementation.h"
15#include "llvm/ADT/EnumeratedArray.h"
16#include "llvm/ADT/STLForwardCompat.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/StringMap.h"
19
20namespace mlir {
21namespace sparse_tensor {
22namespace ir_detail {
23
24//===----------------------------------------------------------------------===//
/// The three kinds of variables that a `Var` can be.
///
/// NOTE: The integer values backing these enumerators are an implementation
/// detail and not part of the API. The API below always lists the kinds in
/// the canonical order `{Symbol,Dimension,Level}`. That order deliberately
/// differs from the numeric order of the representation.
enum class VarKind {
  Dimension = 0,
  Symbol = 1,
  Level = 2,
};
33
34[[nodiscard]] constexpr bool isWF(VarKind vk) {
35 const auto vk_ = llvm::to_underlying(E: vk);
36 return 0 <= vk_ && vk_ <= 2;
37}
38
39/// Gets the ASCII character used as the prefix when printing `Var`.
40constexpr char toChar(VarKind vk) {
41 // If `isWF(vk)` then this computation's intermediate results are always
42 // in the range [-44..126] (where that lower bound is under worst-case
43 // rearranging of the expression); and `int_fast8_t` is the fastest type
44 // which can support that range without over-/underflow.
45 const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(E: vk));
46 return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
47}
48static_assert(toChar(vk: VarKind::Symbol) == 's' &&
49 toChar(vk: VarKind::Dimension) == 'd' &&
50 toChar(vk: VarKind::Level) == 'l');
51
52//===----------------------------------------------------------------------===//
53/// The type of arrays indexed by `VarKind`.
54template <typename T>
55using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
56
57//===----------------------------------------------------------------------===//
58/// A concrete variable, to be used in our variant of `AffineExpr`.
59/// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
60/// support for subclasses with a fixed `VarKind`.
61class Var {
62public:
63 /// Typedef for the type of variable numbers.
64 using Num = unsigned;
65
66private:
67 /// Typedef for the underlying storage of `Var::Impl`.
68 using Storage = unsigned;
69
70 /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
71 /// Two low-order bits are reserved for storing the `VarKind`,
72 /// and one high-order bit is reserved for future use (e.g., to support
73 /// `DenseMapInfo<Var>` while maintaining the usual numeric values for
74 /// "empty" and "tombstone").
75 static constexpr Num kMaxNum =
76 static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);
77
78public:
79 /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
80 //
81 // This must be public for `VarInfo` to use it (whereas we don't want
82 // to expose the `impl` field via friendship).
83 [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
84
85protected:
86 /// The underlying implementation of `Var`. Note that this must be kept
87 /// distinct from `Var` itself, since we want to ensure that the RTTI
88 /// methods will select the `U(Var::Impl)` ctor rather than selecting
89 /// the `U(Var::Num)` ctor.
90 class Impl final {
91 Storage data;
92
93 public:
94 constexpr Impl(VarKind vk, Num n)
95 : data((static_cast<Storage>(n) << 2) |
96 static_cast<Storage>(llvm::to_underlying(E: vk))) {
97 assert(isWF(vk) && "unknown VarKind");
98 assert(isWF_Num(n) && "Var::Num is too large");
99 }
100 constexpr bool operator==(Impl other) const { return data == other.data; }
101 constexpr bool operator!=(Impl other) const { return !(*this == other); }
102 constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
103 constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
104 };
105 static_assert(IsZeroCostAbstraction<Impl>);
106
107private:
108 Impl impl;
109
110protected:
111 /// Protected ctor for the RTTI methods to use.
112 constexpr explicit Var(Impl impl) : impl(impl) {}
113
114public:
115 constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
116 Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
117 Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
118 assert(vk != VarKind::Symbol);
119 }
120
121 constexpr bool operator==(Var other) const { return impl == other.impl; }
122 constexpr bool operator!=(Var other) const { return !(*this == other); }
123
124 constexpr VarKind getKind() const { return impl.getKind(); }
125 constexpr Num getNum() const { return impl.getNum(); }
126
127 template <typename U>
128 constexpr bool isa() const;
129 template <typename U>
130 constexpr U cast() const;
131 template <typename U>
132 constexpr std::optional<U> dyn_cast() const;
133
134 std::string str() const;
135 void print(llvm::raw_ostream &os) const;
136 void print(AsmPrinter &printer) const;
137 void dump() const;
138};
139static_assert(IsZeroCostAbstraction<Var>);
140
141class SymVar final : public Var {
142 using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
143public:
144 static constexpr VarKind Kind = VarKind::Symbol;
145 static constexpr bool classof(Var const *var) {
146 return var->getKind() == Kind;
147 }
148 constexpr SymVar(Num sym) : Var(Kind, sym) {}
149 SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
150};
151static_assert(IsZeroCostAbstraction<SymVar>);
152
153class DimVar final : public Var {
154 using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
155public:
156 static constexpr VarKind Kind = VarKind::Dimension;
157 static constexpr bool classof(Var const *var) {
158 return var->getKind() == Kind;
159 }
160 constexpr DimVar(Num dim) : Var(Kind, dim) {}
161 DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
162};
163static_assert(IsZeroCostAbstraction<DimVar>);
164
165class LvlVar final : public Var {
166 using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
167public:
168 static constexpr VarKind Kind = VarKind::Level;
169 static constexpr bool classof(Var const *var) {
170 return var->getKind() == Kind;
171 }
172 constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
173 LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
174};
175static_assert(IsZeroCostAbstraction<LvlVar>);
176
177template <typename U>
178constexpr bool Var::isa() const {
179 if constexpr (std::is_same_v<U, SymVar>)
180 return getKind() == VarKind::Symbol;
181 if constexpr (std::is_same_v<U, DimVar>)
182 return getKind() == VarKind::Dimension;
183 if constexpr (std::is_same_v<U, LvlVar>)
184 return getKind() == VarKind::Level;
185}
186
187template <typename U>
188constexpr U Var::cast() const {
189 assert(isa<U>());
190 // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
191 return U(impl);
192}
193
194template <typename U>
195constexpr std::optional<U> Var::dyn_cast() const {
196 // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
197 return isa<U>() ? std::make_optional(U(impl)) : std::nullopt;
198}
199
200//===----------------------------------------------------------------------===//
201// Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
202class DimLvlExpr;
203
204//===----------------------------------------------------------------------===//
205class Ranks final {
206 // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
207 unsigned impl[3];
208
209 static constexpr unsigned to_index(VarKind vk) {
210 assert(isWF(vk) && "unknown VarKind");
211 return static_cast<unsigned>(llvm::to_underlying(E: vk));
212 }
213
214public:
215 constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
216 : impl() {
217 impl[to_index(vk: VarKind::Symbol)] = symRank;
218 impl[to_index(vk: VarKind::Dimension)] = dimRank;
219 impl[to_index(vk: VarKind::Level)] = lvlRank;
220 }
221 Ranks(VarKindArray<unsigned> const &ranks)
222 : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
223 ranks[VarKind::Level]) {}
224
225 bool operator==(Ranks const &other) const;
226 bool operator!=(Ranks const &other) const { return !(*this == other); }
227
228 constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
229 constexpr unsigned getSymRank() const { return getRank(vk: VarKind::Symbol); }
230 constexpr unsigned getDimRank() const { return getRank(vk: VarKind::Dimension); }
231 constexpr unsigned getLvlRank() const { return getRank(vk: VarKind::Level); }
232
233 [[nodiscard]] constexpr bool isValid(Var var) const {
234 return var.getNum() < getRank(vk: var.getKind());
235 }
236 [[nodiscard]] bool isValid(DimLvlExpr expr) const;
237};
238static_assert(IsZeroCostAbstraction<Ranks>);
239
240//===----------------------------------------------------------------------===//
241/// Efficient representation of a set of `Var`.
242class VarSet final {
243 VarKindArray<llvm::SmallBitVector> impl;
244
245public:
246 explicit VarSet(Ranks const &ranks);
247
248 unsigned getRank(VarKind vk) const { return impl[vk].size(); }
249 unsigned getSymRank() const { return getRank(vk: VarKind::Symbol); }
250 unsigned getDimRank() const { return getRank(vk: VarKind::Dimension); }
251 unsigned getLvlRank() const { return getRank(vk: VarKind::Level); }
252 Ranks getRanks() const {
253 return Ranks(getSymRank(), getDimRank(), getLvlRank());
254 }
255 /// For the `contains` method: if variables occurring in
256 /// the method parameter are OOB for the `VarSet`, then these methods will
257 /// always return false.
258 bool contains(Var var) const;
259
260 /// For the `add` methods: OOB parameters cause undefined behavior.
261 /// Currently the `add` methods will raise an assertion error.
262 void add(Var var);
263 void add(VarSet const &vars);
264 void add(DimLvlExpr expr);
265};
266
267//===----------------------------------------------------------------------===//
268/// A record of metadata for/about a variable, used by `VarEnv`.
269/// The principal goal of this record is to enable `VarEnv` to be used for
270/// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to
271/// remain unknown, since each record is instead identified by `VarInfo::ID`.
272/// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever
273/// order it likes, irrespective of the binding order (`Var::Num`) of the
274/// associated variable.
275class VarInfo final {
276public:
277 /// Newtype for unique identifiers of `VarInfo` records, to ensure
278 /// they aren't confused with `Var::Num`.
279 enum class ID : unsigned {};
280
281private:
282 StringRef name; // The bare-id used in the MLIR source.
283 llvm::SMLoc loc; // The location of the first occurence.
284 ID id; // The unique `VarInfo`-identifier.
285 std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
286 VarKind kind; // The kind of variable.
287
288public:
289 constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
290 std::optional<Var::Num> n = {})
291 : name(name), loc(loc), id(id), num(n), kind(vk) {
292 assert(!name.empty() && "null StringRef");
293 assert(loc.isValid() && "null SMLoc");
294 assert(isWF(vk) && "unknown VarKind");
295 assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
296 }
297
298 constexpr StringRef getName() const { return name; }
299 constexpr llvm::SMLoc getLoc() const { return loc; }
300 Location getLocation(AsmParser &parser) const {
301 return parser.getEncodedSourceLoc(loc);
302 }
303 constexpr ID getID() const { return id; }
304 constexpr VarKind getKind() const { return kind; }
305 constexpr std::optional<Var::Num> getNum() const { return num; }
306 constexpr bool hasNum() const { return num.has_value(); }
307 void setNum(Var::Num n);
308 constexpr Var getVar() const {
309 assert(hasNum());
310 return Var(kind, *num);
311 }
312};
313
314//===----------------------------------------------------------------------===//
/// How `VarEnv::lookupOrCreate` may treat the creation of a variable.
enum class Policy {
  MustNot, ///< Creating is not allowed; a missing variable yields nullopt.
  May,     ///< Creating is allowed but not required.
  Must,    ///< Creating is required; an existing variable yields nullopt.
};
316
317//===----------------------------------------------------------------------===//
318class VarEnv final {
319 /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
320 VarKindArray<Var::Num> nextNum;
321 /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
322 SmallVector<VarInfo> vars;
323 /// Map from variable names to their `VarInfo::ID`.
324 llvm::StringMap<VarInfo::ID> ids;
325
326 VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
327
328public:
329 VarEnv() : nextNum(0) {}
330
331 /// Gets the underlying storage for the `VarInfo` identified by
332 /// the `VarInfo::ID`.
333 ///
334 /// NOTE: The returned reference can become dangling if the `VarEnv`
335 /// object is mutated during the lifetime of the pointer. Therefore,
336 /// client code should not store the reference nor otherwise allow it
337 /// to live too long.
338 VarInfo const &access(VarInfo::ID id) const {
339 // `SmallVector::operator[]` already asserts the index is in-bounds.
340 return vars[llvm::to_underlying(E: id)];
341 }
342 VarInfo const *access(std::optional<VarInfo::ID> oid) const {
343 return oid ? &access(id: *oid) : nullptr;
344 }
345
346private:
347 VarInfo &access(VarInfo::ID id) {
348 return const_cast<VarInfo &>(std::as_const(t&: *this).access(id));
349 }
350 VarInfo *access(std::optional<VarInfo::ID> oid) {
351 return const_cast<VarInfo *>(std::as_const(t&: *this).access(oid));
352 }
353
354public:
355 /// Looks up the variable with the given name.
356 std::optional<VarInfo::ID> lookup(StringRef name) const;
357
358 /// Creates a new currently-unbound variable. When a variable
359 /// of that name already exists: if `verifyUsage` is true, then will assert
360 /// that the variable has the same kind and a consistent location; otherwise,
361 /// when `verifyUsage` is false, this is a noop. Returns the identifier
362 /// for the variable with the given name, and a bool indicating whether
363 /// a new variable was created.
364 std::optional<std::pair<VarInfo::ID, bool>>
365 create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
366
367 /// Looks up or creates a variable according to the given
368 /// `Policy`. Returns nullopt in one of two circumstances:
369 /// (1) the policy says we `Must` create, yet the variable already exists;
370 /// (2) the policy says we `MustNot` create, yet no such variable exists.
371 /// Otherwise, if the variable already exists then it is validated against
372 /// the given kind and location to ensure consistency.
373 std::optional<std::pair<VarInfo::ID, bool>>
374 lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
375 VarKind vk);
376
377 /// Binds the given variable to the next free `Var::Num` for its `VarKind`.
378 Var bindVar(VarInfo::ID id);
379
380 /// Creates a new variable of the given kind and immediately binds it.
381 /// This should only be used whenever the variable is known to be unused
382 /// and therefore does not have a name.
383 Var bindUnusedVar(VarKind vk);
384
385 InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
386
387 /// Returns the current ranks of bound variables. This method should
388 /// only be used after the environment is "finished", since binding new
389 /// variables will (semantically) invalidate any previously returned `Ranks`.
390 Ranks getRanks() const { return Ranks(nextNum); }
391
392 /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
393 /// failure if the variable is not bound.
394 Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
395};
396
397//===----------------------------------------------------------------------===//
398
399} // namespace ir_detail
400} // namespace sparse_tensor
401} // namespace mlir
402
403#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
404

source code of mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h