1//===- Var.cpp ------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Var.h"
10#include "DimLvlMap.h"
11
12using namespace mlir;
13using namespace mlir::sparse_tensor;
14using namespace mlir::sparse_tensor::ir_detail;
15
16//===----------------------------------------------------------------------===//
17// `VarKind` helpers.
18//===----------------------------------------------------------------------===//
19
20/// For use in foreach loops.
21static constexpr const VarKind everyVarKind[] = {
22 VarKind::Dimension, VarKind::Symbol, VarKind::Level};
23
24//===----------------------------------------------------------------------===//
25// `Var` implementation.
26//===----------------------------------------------------------------------===//
27
28std::string Var::str() const {
29 std::string str;
30 llvm::raw_string_ostream os(str);
31 print(os);
32 return os.str();
33}
34
35void Var::print(AsmPrinter &printer) const { print(os&: printer.getStream()); }
36
37void Var::print(llvm::raw_ostream &os) const {
38 os << toChar(vk: getKind()) << getNum();
39}
40
41void Var::dump() const {
42 print(os&: llvm::errs());
43 llvm::errs() << "\n";
44}
45
46//===----------------------------------------------------------------------===//
47// `Ranks` implementation.
48//===----------------------------------------------------------------------===//
49
50bool Ranks::operator==(Ranks const &other) const {
51 for (const auto vk : everyVarKind)
52 if (getRank(vk) != other.getRank(vk))
53 return false;
54 return true;
55}
56
57bool Ranks::isValid(DimLvlExpr expr) const {
58 assert(expr);
59 // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
60 // (each `DimLvlExpr` only allows one kind of non-symbol variable).
61 int64_t maxSym = -1, maxVar = -1;
62 mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>(exprsList: {{expr.getAffineExpr()}},
63 maxDim&: maxVar, maxSym);
64 return maxSym < getSymRank() && maxVar < getRank(vk: expr.getAllowedVarKind());
65}
66
67//===----------------------------------------------------------------------===//
68// `VarSet` implementation.
69//===----------------------------------------------------------------------===//
70
71VarSet::VarSet(Ranks const &ranks) {
72 for (const auto vk : everyVarKind)
73 impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
74 assert(getRanks() == ranks);
75}
76
77bool VarSet::contains(Var var) const {
78 // NOTE: We make sure to return false on OOB, for consistency with
79 // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
80 // However beware that, as always with silencing OOB, this can hide
81 // bugs in client code.
82 const llvm::SmallBitVector &bits = impl[var.getKind()];
83 const auto num = var.getNum();
84 return num < bits.size() && bits[num];
85}
86
87void VarSet::add(Var var) {
88 // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
89 impl[var.getKind()][var.getNum()] = true;
90}
91
92void VarSet::add(VarSet const &other) {
93 // NOTE: `SmallBitVector::operator&=` will implicitly resize
94 // the bitvector (unlike `BitVector::operator&=`), so we add an
95 // assertion against OOB for consistency with the implementation
96 // of `VarSet::add(Var)`.
97 for (const auto vk : everyVarKind) {
98 assert(impl[vk].size() >= other.impl[vk].size());
99 impl[vk] &= other.impl[vk];
100 }
101}
102
103void VarSet::add(DimLvlExpr expr) {
104 if (!expr)
105 return;
106 switch (expr.getAffineKind()) {
107 case AffineExprKind::Constant:
108 return;
109 case AffineExprKind::SymbolId:
110 add(var: expr.castSymVar());
111 return;
112 case AffineExprKind::DimId:
113 add(var: expr.castDimLvlVar());
114 return;
115 case AffineExprKind::Add:
116 case AffineExprKind::Mul:
117 case AffineExprKind::Mod:
118 case AffineExprKind::FloorDiv:
119 case AffineExprKind::CeilDiv: {
120 const auto [lhs, op, rhs] = expr.unpackBinop();
121 (void)op;
122 add(expr: lhs);
123 add(expr: rhs);
124 return;
125 }
126 }
127 llvm_unreachable("unknown AffineExprKind");
128}
129
130//===----------------------------------------------------------------------===//
131// `VarInfo` implementation.
132//===----------------------------------------------------------------------===//
133
/// Records `n` as this variable's number.
/// Asserts that no number has been set yet and that `n` passes
/// `Var::isWF_Num`.
void VarInfo::setNum(Var::Num n) {
  assert(!hasNum() && "Var::Num is already set");
  assert(Var::isWF_Num(n) && "Var::Num is too large");
  num = n;
}
139
140//===----------------------------------------------------------------------===//
141// `VarEnv` implementation.
142//===----------------------------------------------------------------------===//
143
144/// Helper function for `assertUsageConsistency` to better handle SMLoc
145/// mismatches.
146LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
147minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
148 const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(loc: sm1));
149 assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
150 const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(loc: sm2));
151 assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
152 if (loc1.getFilename() != loc2.getFilename())
153 return SMLoc();
154 const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
155 const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
156 return pair1 <= pair2 ? sm1 : sm2;
157}
158
159bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
160 const auto &var = env.access(id);
161 return (var.getName() == name && var.getID() == id);
162}
163
164bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
165 VarKind vk) {
166 const auto &var = env.access(id);
167 return var.getKind() == vk;
168}
169
170std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
171 const auto iter = ids.find(Key: name);
172 if (iter == ids.end())
173 return std::nullopt;
174 const auto id = iter->second;
175 if (!isInternalConsistent(env: *this, id, name))
176 return std::nullopt;
177 return id;
178}
179
180std::optional<std::pair<VarInfo::ID, bool>>
181VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
182 const auto &[iter, didInsert] = ids.try_emplace(Key: name, Args: nextID());
183 const auto id = iter->second;
184 if (didInsert) {
185 vars.emplace_back(Args: id, Args&: name, Args&: loc, Args&: vk);
186 } else {
187 if (!isInternalConsistent(env: *this, id, name))
188 return std::nullopt;
189 if (verifyUsage)
190 if (!isUsageConsistent(env: *this, id, loc, vk))
191 return std::nullopt;
192 }
193 return std::make_pair(x: id, y: didInsert);
194}
195
196std::optional<std::pair<VarInfo::ID, bool>>
197VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
198 VarKind vk) {
199 switch (creationPolicy) {
200 case Policy::MustNot: {
201 const auto oid = lookup(name);
202 if (!oid)
203 return std::nullopt; // Doesn't exist, but must not create.
204 if (!isUsageConsistent(env: *this, id: *oid, loc, vk))
205 return std::nullopt;
206 return std::make_pair(x: *oid, y: false);
207 }
208 case Policy::May:
209 return create(name, loc, vk, /*verifyUsage=*/true);
210 case Policy::Must: {
211 const auto res = create(name, loc, vk, /*verifyUsage=*/false);
212 const auto didCreate = res->second;
213 if (!didCreate)
214 return std::nullopt; // Already exists, but must create.
215 return res;
216 }
217 }
218 llvm_unreachable("unknown Policy");
219}
220
221Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
222Var VarEnv::bindVar(VarInfo::ID id) {
223 auto &info = access(id);
224 const auto var = bindUnusedVar(vk: info.getKind());
225 info.setNum(var.getNum());
226 return var;
227}
228
229InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
230 for (const auto &var : vars)
231 if (!var.hasNum())
232 return parser.emitError(loc: var.getLoc(),
233 message: "Unbound variable: " + var.getName());
234 return {};
235}
236
237//===----------------------------------------------------------------------===//
238

source code of mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp