1 | //===- Var.h ----------------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H |
10 | #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H |
11 | |
12 | #include "TemplateExtras.h" |
13 | |
14 | #include "mlir/IR/OpImplementation.h" |
15 | #include "llvm/ADT/EnumeratedArray.h" |
16 | #include "llvm/ADT/STLForwardCompat.h" |
17 | #include "llvm/ADT/SmallBitVector.h" |
18 | #include "llvm/ADT/StringMap.h" |
19 | |
20 | namespace mlir { |
21 | namespace sparse_tensor { |
22 | namespace ir_detail { |
23 | |
//===----------------------------------------------------------------------===//
/// Enumerates the three kinds of variable that a `Var` may denote.
///
/// NOTE: The integer values backing these enumerators are an implementation
/// detail and must not be relied upon as part of the API.  Throughout the
/// API we list the kinds in the canonical order `{Symbol,Dimension,Level}`,
/// which intentionally differs from their numeric order
/// (Dimension=0, Symbol=1, Level=2).
enum class VarKind {
  Symbol = 1,
  Dimension = 0,
  Level = 2,
};

/// Returns true iff `vk` is one of the three declared enumerators, i.e.
/// the value is "well-formed".  Any other value (e.g. one produced by
/// casting an arbitrary integer to `VarKind`) yields false.
[[nodiscard]] constexpr bool isWF(VarKind vk) {
  switch (vk) {
  case VarKind::Symbol:
  case VarKind::Dimension:
  case VarKind::Level:
    return true;
  }
  return false;
}
38 | |
39 | /// Gets the ASCII character used as the prefix when printing `Var`. |
40 | constexpr char toChar(VarKind vk) { |
41 | // If `isWF(vk)` then this computation's intermediate results are always |
42 | // in the range [-44..126] (where that lower bound is under worst-case |
43 | // rearranging of the expression); and `int_fast8_t` is the fastest type |
44 | // which can support that range without over-/underflow. |
45 | const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(E: vk)); |
46 | return static_cast<char>(100 + vk_ * (26 - vk_ * 11)); |
47 | } |
48 | static_assert(toChar(vk: VarKind::Symbol) == 's' && |
49 | toChar(vk: VarKind::Dimension) == 'd' && |
50 | toChar(vk: VarKind::Level) == 'l'); |
51 | |
52 | //===----------------------------------------------------------------------===// |
53 | /// The type of arrays indexed by `VarKind`. |
54 | template <typename T> |
55 | using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>; |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | /// A concrete variable, to be used in our variant of `AffineExpr`. |
59 | /// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI |
60 | /// support for subclasses with a fixed `VarKind`. |
61 | class Var { |
62 | public: |
63 | /// Typedef for the type of variable numbers. |
64 | using Num = unsigned; |
65 | |
66 | private: |
67 | /// Typedef for the underlying storage of `Var::Impl`. |
68 | using Storage = unsigned; |
69 | |
70 | /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`. |
71 | /// Two low-order bits are reserved for storing the `VarKind`, |
72 | /// and one high-order bit is reserved for future use (e.g., to support |
73 | /// `DenseMapInfo<Var>` while maintaining the usual numeric values for |
74 | /// "empty" and "tombstone"). |
75 | static constexpr Num kMaxNum = |
76 | static_cast<Num>(std::numeric_limits<Storage>::max() >> 3); |
77 | |
78 | public: |
79 | /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`. |
80 | // |
81 | // This must be public for `VarInfo` to use it (whereas we don't want |
82 | // to expose the `impl` field via friendship). |
83 | [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; } |
84 | |
85 | protected: |
86 | /// The underlying implementation of `Var`. Note that this must be kept |
87 | /// distinct from `Var` itself, since we want to ensure that the RTTI |
88 | /// methods will select the `U(Var::Impl)` ctor rather than selecting |
89 | /// the `U(Var::Num)` ctor. |
90 | class Impl final { |
91 | Storage data; |
92 | |
93 | public: |
94 | constexpr Impl(VarKind vk, Num n) |
95 | : data((static_cast<Storage>(n) << 2) | |
96 | static_cast<Storage>(llvm::to_underlying(E: vk))) { |
97 | assert(isWF(vk) && "unknown VarKind" ); |
98 | assert(isWF_Num(n) && "Var::Num is too large" ); |
99 | } |
100 | constexpr bool operator==(Impl other) const { return data == other.data; } |
101 | constexpr bool operator!=(Impl other) const { return !(*this == other); } |
102 | constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); } |
103 | constexpr Num getNum() const { return static_cast<Num>(data >> 2); } |
104 | }; |
105 | static_assert(IsZeroCostAbstraction<Impl>); |
106 | |
107 | private: |
108 | Impl impl; |
109 | |
110 | protected: |
111 | /// Protected ctor for the RTTI methods to use. |
112 | constexpr explicit Var(Impl impl) : impl(impl) {} |
113 | |
114 | public: |
115 | constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {} |
116 | Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {} |
117 | Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) { |
118 | assert(vk != VarKind::Symbol); |
119 | } |
120 | |
121 | constexpr bool operator==(Var other) const { return impl == other.impl; } |
122 | constexpr bool operator!=(Var other) const { return !(*this == other); } |
123 | |
124 | constexpr VarKind getKind() const { return impl.getKind(); } |
125 | constexpr Num getNum() const { return impl.getNum(); } |
126 | |
127 | template <typename U> |
128 | constexpr bool isa() const; |
129 | template <typename U> |
130 | constexpr U cast() const; |
131 | template <typename U> |
132 | constexpr std::optional<U> dyn_cast() const; |
133 | |
134 | std::string str() const; |
135 | void print(llvm::raw_ostream &os) const; |
136 | void print(AsmPrinter &printer) const; |
137 | void dump() const; |
138 | }; |
139 | static_assert(IsZeroCostAbstraction<Var>); |
140 | |
141 | class SymVar final : public Var { |
142 | using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. |
143 | public: |
144 | static constexpr VarKind Kind = VarKind::Symbol; |
145 | static constexpr bool classof(Var const *var) { |
146 | return var->getKind() == Kind; |
147 | } |
148 | constexpr SymVar(Num sym) : Var(Kind, sym) {} |
149 | SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {} |
150 | }; |
151 | static_assert(IsZeroCostAbstraction<SymVar>); |
152 | |
153 | class DimVar final : public Var { |
154 | using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. |
155 | public: |
156 | static constexpr VarKind Kind = VarKind::Dimension; |
157 | static constexpr bool classof(Var const *var) { |
158 | return var->getKind() == Kind; |
159 | } |
160 | constexpr DimVar(Num dim) : Var(Kind, dim) {} |
161 | DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {} |
162 | }; |
163 | static_assert(IsZeroCostAbstraction<DimVar>); |
164 | |
165 | class LvlVar final : public Var { |
166 | using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. |
167 | public: |
168 | static constexpr VarKind Kind = VarKind::Level; |
169 | static constexpr bool classof(Var const *var) { |
170 | return var->getKind() == Kind; |
171 | } |
172 | constexpr LvlVar(Num lvl) : Var(Kind, lvl) {} |
173 | LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {} |
174 | }; |
175 | static_assert(IsZeroCostAbstraction<LvlVar>); |
176 | |
177 | template <typename U> |
178 | constexpr bool Var::isa() const { |
179 | if constexpr (std::is_same_v<U, SymVar>) |
180 | return getKind() == VarKind::Symbol; |
181 | if constexpr (std::is_same_v<U, DimVar>) |
182 | return getKind() == VarKind::Dimension; |
183 | if constexpr (std::is_same_v<U, LvlVar>) |
184 | return getKind() == VarKind::Level; |
185 | } |
186 | |
187 | template <typename U> |
188 | constexpr U Var::cast() const { |
189 | assert(isa<U>()); |
190 | // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)` |
191 | return U(impl); |
192 | } |
193 | |
194 | template <typename U> |
195 | constexpr std::optional<U> Var::dyn_cast() const { |
196 | // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)` |
197 | return isa<U>() ? std::make_optional(U(impl)) : std::nullopt; |
198 | } |
199 | |
200 | //===----------------------------------------------------------------------===// |
201 | // Forward-decl so that we can declare methods of `Ranks` and `VarSet`. |
202 | class DimLvlExpr; |
203 | |
204 | //===----------------------------------------------------------------------===// |
205 | class Ranks final { |
206 | // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr. |
207 | unsigned impl[3]; |
208 | |
209 | static constexpr unsigned to_index(VarKind vk) { |
210 | assert(isWF(vk) && "unknown VarKind" ); |
211 | return static_cast<unsigned>(llvm::to_underlying(E: vk)); |
212 | } |
213 | |
214 | public: |
215 | constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank) |
216 | : impl() { |
217 | impl[to_index(vk: VarKind::Symbol)] = symRank; |
218 | impl[to_index(vk: VarKind::Dimension)] = dimRank; |
219 | impl[to_index(vk: VarKind::Level)] = lvlRank; |
220 | } |
221 | Ranks(VarKindArray<unsigned> const &ranks) |
222 | : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension], |
223 | ranks[VarKind::Level]) {} |
224 | |
225 | bool operator==(Ranks const &other) const; |
226 | bool operator!=(Ranks const &other) const { return !(*this == other); } |
227 | |
228 | constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; } |
229 | constexpr unsigned getSymRank() const { return getRank(vk: VarKind::Symbol); } |
230 | constexpr unsigned getDimRank() const { return getRank(vk: VarKind::Dimension); } |
231 | constexpr unsigned getLvlRank() const { return getRank(vk: VarKind::Level); } |
232 | |
233 | [[nodiscard]] constexpr bool isValid(Var var) const { |
234 | return var.getNum() < getRank(vk: var.getKind()); |
235 | } |
236 | [[nodiscard]] bool isValid(DimLvlExpr expr) const; |
237 | }; |
238 | static_assert(IsZeroCostAbstraction<Ranks>); |
239 | |
240 | //===----------------------------------------------------------------------===// |
241 | /// Efficient representation of a set of `Var`. |
242 | class VarSet final { |
243 | VarKindArray<llvm::SmallBitVector> impl; |
244 | |
245 | public: |
246 | explicit VarSet(Ranks const &ranks); |
247 | |
248 | unsigned getRank(VarKind vk) const { return impl[vk].size(); } |
249 | unsigned getSymRank() const { return getRank(vk: VarKind::Symbol); } |
250 | unsigned getDimRank() const { return getRank(vk: VarKind::Dimension); } |
251 | unsigned getLvlRank() const { return getRank(vk: VarKind::Level); } |
252 | Ranks getRanks() const { |
253 | return Ranks(getSymRank(), getDimRank(), getLvlRank()); |
254 | } |
255 | /// For the `contains` method: if variables occurring in |
256 | /// the method parameter are OOB for the `VarSet`, then these methods will |
257 | /// always return false. |
258 | bool contains(Var var) const; |
259 | |
260 | /// For the `add` methods: OOB parameters cause undefined behavior. |
261 | /// Currently the `add` methods will raise an assertion error. |
262 | void add(Var var); |
263 | void add(VarSet const &vars); |
264 | void add(DimLvlExpr expr); |
265 | }; |
266 | |
267 | //===----------------------------------------------------------------------===// |
268 | /// A record of metadata for/about a variable, used by `VarEnv`. |
269 | /// The principal goal of this record is to enable `VarEnv` to be used for |
270 | /// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to |
271 | /// remain unknown, since each record is instead identified by `VarInfo::ID`. |
272 | /// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever |
273 | /// order it likes, irrespective of the binding order (`Var::Num`) of the |
274 | /// associated variable. |
275 | class VarInfo final { |
276 | public: |
277 | /// Newtype for unique identifiers of `VarInfo` records, to ensure |
278 | /// they aren't confused with `Var::Num`. |
279 | enum class ID : unsigned {}; |
280 | |
281 | private: |
282 | StringRef name; // The bare-id used in the MLIR source. |
283 | llvm::SMLoc loc; // The location of the first occurence. |
284 | ID id; // The unique `VarInfo`-identifier. |
285 | std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved). |
286 | VarKind kind; // The kind of variable. |
287 | |
288 | public: |
289 | constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk, |
290 | std::optional<Var::Num> n = {}) |
291 | : name(name), loc(loc), id(id), num(n), kind(vk) { |
292 | assert(!name.empty() && "null StringRef" ); |
293 | assert(loc.isValid() && "null SMLoc" ); |
294 | assert(isWF(vk) && "unknown VarKind" ); |
295 | assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large" ); |
296 | } |
297 | |
298 | constexpr StringRef getName() const { return name; } |
299 | constexpr llvm::SMLoc getLoc() const { return loc; } |
300 | Location getLocation(AsmParser &parser) const { |
301 | return parser.getEncodedSourceLoc(loc); |
302 | } |
303 | constexpr ID getID() const { return id; } |
304 | constexpr VarKind getKind() const { return kind; } |
305 | constexpr std::optional<Var::Num> getNum() const { return num; } |
306 | constexpr bool hasNum() const { return num.has_value(); } |
307 | void setNum(Var::Num n); |
308 | constexpr Var getVar() const { |
309 | assert(hasNum()); |
310 | return Var(kind, *num); |
311 | } |
312 | }; |
313 | |
//===----------------------------------------------------------------------===//
/// Creation policy for `VarEnv::lookupOrCreate`: whether the call must not,
/// may, or must create a new variable.
enum class Policy {
  MustNot, // Fails when no variable of the given name exists yet.
  May,     // Looks up an existing variable, or creates one.
  Must,    // Fails when a variable of the given name already exists.
};
316 | |
317 | //===----------------------------------------------------------------------===// |
318 | class VarEnv final { |
319 | /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`. |
320 | VarKindArray<Var::Num> nextNum; |
321 | /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects. |
322 | SmallVector<VarInfo> vars; |
323 | /// Map from variable names to their `VarInfo::ID`. |
324 | llvm::StringMap<VarInfo::ID> ids; |
325 | |
326 | VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); } |
327 | |
328 | public: |
329 | VarEnv() : nextNum(0) {} |
330 | |
331 | /// Gets the underlying storage for the `VarInfo` identified by |
332 | /// the `VarInfo::ID`. |
333 | /// |
334 | /// NOTE: The returned reference can become dangling if the `VarEnv` |
335 | /// object is mutated during the lifetime of the pointer. Therefore, |
336 | /// client code should not store the reference nor otherwise allow it |
337 | /// to live too long. |
338 | VarInfo const &access(VarInfo::ID id) const { |
339 | // `SmallVector::operator[]` already asserts the index is in-bounds. |
340 | return vars[llvm::to_underlying(E: id)]; |
341 | } |
342 | VarInfo const *access(std::optional<VarInfo::ID> oid) const { |
343 | return oid ? &access(id: *oid) : nullptr; |
344 | } |
345 | |
346 | private: |
347 | VarInfo &access(VarInfo::ID id) { |
348 | return const_cast<VarInfo &>(std::as_const(t&: *this).access(id)); |
349 | } |
350 | VarInfo *access(std::optional<VarInfo::ID> oid) { |
351 | return const_cast<VarInfo *>(std::as_const(t&: *this).access(oid)); |
352 | } |
353 | |
354 | public: |
355 | /// Looks up the variable with the given name. |
356 | std::optional<VarInfo::ID> lookup(StringRef name) const; |
357 | |
358 | /// Creates a new currently-unbound variable. When a variable |
359 | /// of that name already exists: if `verifyUsage` is true, then will assert |
360 | /// that the variable has the same kind and a consistent location; otherwise, |
361 | /// when `verifyUsage` is false, this is a noop. Returns the identifier |
362 | /// for the variable with the given name, and a bool indicating whether |
363 | /// a new variable was created. |
364 | std::optional<std::pair<VarInfo::ID, bool>> |
365 | create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false); |
366 | |
367 | /// Looks up or creates a variable according to the given |
368 | /// `Policy`. Returns nullopt in one of two circumstances: |
369 | /// (1) the policy says we `Must` create, yet the variable already exists; |
370 | /// (2) the policy says we `MustNot` create, yet no such variable exists. |
371 | /// Otherwise, if the variable already exists then it is validated against |
372 | /// the given kind and location to ensure consistency. |
373 | std::optional<std::pair<VarInfo::ID, bool>> |
374 | lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, |
375 | VarKind vk); |
376 | |
377 | /// Binds the given variable to the next free `Var::Num` for its `VarKind`. |
378 | Var bindVar(VarInfo::ID id); |
379 | |
380 | /// Creates a new variable of the given kind and immediately binds it. |
381 | /// This should only be used whenever the variable is known to be unused |
382 | /// and therefore does not have a name. |
383 | Var bindUnusedVar(VarKind vk); |
384 | |
385 | InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const; |
386 | |
387 | /// Returns the current ranks of bound variables. This method should |
388 | /// only be used after the environment is "finished", since binding new |
389 | /// variables will (semantically) invalidate any previously returned `Ranks`. |
390 | Ranks getRanks() const { return Ranks(nextNum); } |
391 | |
392 | /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion |
393 | /// failure if the variable is not bound. |
394 | Var getVar(VarInfo::ID id) const { return access(id).getVar(); } |
395 | }; |
396 | |
397 | //===----------------------------------------------------------------------===// |
398 | |
399 | } // namespace ir_detail |
400 | } // namespace sparse_tensor |
401 | } // namespace mlir |
402 | |
403 | #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H |
404 | |