1 | //===- Var.cpp ------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Var.h" |
10 | #include "DimLvlMap.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::sparse_tensor; |
14 | using namespace mlir::sparse_tensor::ir_detail; |
15 | |
16 | //===----------------------------------------------------------------------===// |
17 | // `VarKind` helpers. |
18 | //===----------------------------------------------------------------------===// |
19 | |
20 | /// For use in foreach loops. |
21 | static constexpr const VarKind everyVarKind[] = { |
22 | VarKind::Dimension, VarKind::Symbol, VarKind::Level}; |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // `Var` implementation. |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | std::string Var::str() const { |
29 | std::string str; |
30 | llvm::raw_string_ostream os(str); |
31 | print(os); |
32 | return os.str(); |
33 | } |
34 | |
35 | void Var::print(AsmPrinter &printer) const { print(os&: printer.getStream()); } |
36 | |
37 | void Var::print(llvm::raw_ostream &os) const { |
38 | os << toChar(vk: getKind()) << getNum(); |
39 | } |
40 | |
41 | void Var::dump() const { |
42 | print(os&: llvm::errs()); |
43 | llvm::errs() << "\n" ; |
44 | } |
45 | |
46 | //===----------------------------------------------------------------------===// |
47 | // `Ranks` implementation. |
48 | //===----------------------------------------------------------------------===// |
49 | |
50 | bool Ranks::operator==(Ranks const &other) const { |
51 | for (const auto vk : everyVarKind) |
52 | if (getRank(vk) != other.getRank(vk)) |
53 | return false; |
54 | return true; |
55 | } |
56 | |
57 | bool Ranks::isValid(DimLvlExpr expr) const { |
58 | assert(expr); |
59 | // Compute the maximum identifiers for symbol-vars and dim/lvl-vars |
60 | // (each `DimLvlExpr` only allows one kind of non-symbol variable). |
61 | int64_t maxSym = -1, maxVar = -1; |
62 | mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>(exprsList: {{expr.getAffineExpr()}}, |
63 | maxDim&: maxVar, maxSym); |
64 | return maxSym < getSymRank() && maxVar < getRank(vk: expr.getAllowedVarKind()); |
65 | } |
66 | |
67 | //===----------------------------------------------------------------------===// |
68 | // `VarSet` implementation. |
69 | //===----------------------------------------------------------------------===// |
70 | |
71 | VarSet::VarSet(Ranks const &ranks) { |
72 | for (const auto vk : everyVarKind) |
73 | impl[vk] = llvm::SmallBitVector(ranks.getRank(vk)); |
74 | assert(getRanks() == ranks); |
75 | } |
76 | |
77 | bool VarSet::contains(Var var) const { |
78 | // NOTE: We make sure to return false on OOB, for consistency with |
79 | // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`. |
80 | // However beware that, as always with silencing OOB, this can hide |
81 | // bugs in client code. |
82 | const llvm::SmallBitVector &bits = impl[var.getKind()]; |
83 | const auto num = var.getNum(); |
84 | return num < bits.size() && bits[num]; |
85 | } |
86 | |
87 | void VarSet::add(Var var) { |
88 | // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB. |
89 | impl[var.getKind()][var.getNum()] = true; |
90 | } |
91 | |
92 | void VarSet::add(VarSet const &other) { |
93 | // NOTE: `SmallBitVector::operator&=` will implicitly resize |
94 | // the bitvector (unlike `BitVector::operator&=`), so we add an |
95 | // assertion against OOB for consistency with the implementation |
96 | // of `VarSet::add(Var)`. |
97 | for (const auto vk : everyVarKind) { |
98 | assert(impl[vk].size() >= other.impl[vk].size()); |
99 | impl[vk] &= other.impl[vk]; |
100 | } |
101 | } |
102 | |
103 | void VarSet::add(DimLvlExpr expr) { |
104 | if (!expr) |
105 | return; |
106 | switch (expr.getAffineKind()) { |
107 | case AffineExprKind::Constant: |
108 | return; |
109 | case AffineExprKind::SymbolId: |
110 | add(var: expr.castSymVar()); |
111 | return; |
112 | case AffineExprKind::DimId: |
113 | add(var: expr.castDimLvlVar()); |
114 | return; |
115 | case AffineExprKind::Add: |
116 | case AffineExprKind::Mul: |
117 | case AffineExprKind::Mod: |
118 | case AffineExprKind::FloorDiv: |
119 | case AffineExprKind::CeilDiv: { |
120 | const auto [lhs, op, rhs] = expr.unpackBinop(); |
121 | (void)op; |
122 | add(expr: lhs); |
123 | add(expr: rhs); |
124 | return; |
125 | } |
126 | } |
127 | llvm_unreachable("unknown AffineExprKind" ); |
128 | } |
129 | |
130 | //===----------------------------------------------------------------------===// |
131 | // `VarInfo` implementation. |
132 | //===----------------------------------------------------------------------===// |
133 | |
134 | void VarInfo::setNum(Var::Num n) { |
135 | assert(!hasNum() && "Var::Num is already set" ); |
136 | assert(Var::isWF_Num(n) && "Var::Num is too large" ); |
137 | num = n; |
138 | } |
139 | |
140 | //===----------------------------------------------------------------------===// |
141 | // `VarEnv` implementation. |
142 | //===----------------------------------------------------------------------===// |
143 | |
144 | /// Helper function for `assertUsageConsistency` to better handle SMLoc |
145 | /// mismatches. |
146 | LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc |
147 | minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { |
148 | const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(loc: sm1)); |
149 | assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`" ); |
150 | const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(loc: sm2)); |
151 | assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`" ); |
152 | if (loc1.getFilename() != loc2.getFilename()) |
153 | return SMLoc(); |
154 | const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn()); |
155 | const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn()); |
156 | return pair1 <= pair2 ? sm1 : sm2; |
157 | } |
158 | |
159 | bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) { |
160 | const auto &var = env.access(id); |
161 | return (var.getName() == name && var.getID() == id); |
162 | } |
163 | |
164 | bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, |
165 | VarKind vk) { |
166 | const auto &var = env.access(id); |
167 | return var.getKind() == vk; |
168 | } |
169 | |
170 | std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const { |
171 | const auto iter = ids.find(Key: name); |
172 | if (iter == ids.end()) |
173 | return std::nullopt; |
174 | const auto id = iter->second; |
175 | if (!isInternalConsistent(env: *this, id, name)) |
176 | return std::nullopt; |
177 | return id; |
178 | } |
179 | |
180 | std::optional<std::pair<VarInfo::ID, bool>> |
181 | VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) { |
182 | const auto &[iter, didInsert] = ids.try_emplace(Key: name, Args: nextID()); |
183 | const auto id = iter->second; |
184 | if (didInsert) { |
185 | vars.emplace_back(Args: id, Args&: name, Args&: loc, Args&: vk); |
186 | } else { |
187 | if (!isInternalConsistent(env: *this, id, name)) |
188 | return std::nullopt; |
189 | if (verifyUsage) |
190 | if (!isUsageConsistent(env: *this, id, loc, vk)) |
191 | return std::nullopt; |
192 | } |
193 | return std::make_pair(x: id, y: didInsert); |
194 | } |
195 | |
196 | std::optional<std::pair<VarInfo::ID, bool>> |
197 | VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, |
198 | VarKind vk) { |
199 | switch (creationPolicy) { |
200 | case Policy::MustNot: { |
201 | const auto oid = lookup(name); |
202 | if (!oid) |
203 | return std::nullopt; // Doesn't exist, but must not create. |
204 | if (!isUsageConsistent(env: *this, id: *oid, loc, vk)) |
205 | return std::nullopt; |
206 | return std::make_pair(x: *oid, y: false); |
207 | } |
208 | case Policy::May: |
209 | return create(name, loc, vk, /*verifyUsage=*/true); |
210 | case Policy::Must: { |
211 | const auto res = create(name, loc, vk, /*verifyUsage=*/false); |
212 | const auto didCreate = res->second; |
213 | if (!didCreate) |
214 | return std::nullopt; // Already exists, but must create. |
215 | return res; |
216 | } |
217 | } |
218 | llvm_unreachable("unknown Policy" ); |
219 | } |
220 | |
221 | Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); } |
222 | Var VarEnv::bindVar(VarInfo::ID id) { |
223 | auto &info = access(id); |
224 | const auto var = bindUnusedVar(vk: info.getKind()); |
225 | info.setNum(var.getNum()); |
226 | return var; |
227 | } |
228 | |
229 | InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const { |
230 | for (const auto &var : vars) |
231 | if (!var.hasNum()) |
232 | return parser.emitError(loc: var.getLoc(), |
233 | message: "Unbound variable: " + var.getName()); |
234 | return {}; |
235 | } |
236 | |
237 | //===----------------------------------------------------------------------===// |
238 | |