| 1 | //===- Var.cpp ------------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "Var.h" |
| 10 | #include "DimLvlMap.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | using namespace mlir::sparse_tensor; |
| 14 | using namespace mlir::sparse_tensor::ir_detail; |
| 15 | |
| 16 | //===----------------------------------------------------------------------===// |
| 17 | // `VarKind` helpers. |
| 18 | //===----------------------------------------------------------------------===// |
| 19 | |
| 20 | /// For use in foreach loops. |
| 21 | static constexpr const VarKind everyVarKind[] = { |
| 22 | VarKind::Dimension, VarKind::Symbol, VarKind::Level}; |
| 23 | |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | // `Var` implementation. |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | |
| 28 | std::string Var::str() const { |
| 29 | std::string str; |
| 30 | llvm::raw_string_ostream os(str); |
| 31 | print(os); |
| 32 | return str; |
| 33 | } |
| 34 | |
| 35 | void Var::print(AsmPrinter &printer) const { print(os&: printer.getStream()); } |
| 36 | |
| 37 | void Var::print(llvm::raw_ostream &os) const { |
| 38 | os << toChar(vk: getKind()) << getNum(); |
| 39 | } |
| 40 | |
| 41 | void Var::dump() const { |
| 42 | print(os&: llvm::errs()); |
| 43 | llvm::errs() << "\n" ; |
| 44 | } |
| 45 | |
| 46 | //===----------------------------------------------------------------------===// |
| 47 | // `Ranks` implementation. |
| 48 | //===----------------------------------------------------------------------===// |
| 49 | |
| 50 | bool Ranks::operator==(Ranks const &other) const { |
| 51 | for (const auto vk : everyVarKind) |
| 52 | if (getRank(vk) != other.getRank(vk)) |
| 53 | return false; |
| 54 | return true; |
| 55 | } |
| 56 | |
| 57 | bool Ranks::isValid(DimLvlExpr expr) const { |
| 58 | assert(expr); |
| 59 | // Compute the maximum identifiers for symbol-vars and dim/lvl-vars |
| 60 | // (each `DimLvlExpr` only allows one kind of non-symbol variable). |
| 61 | int64_t maxSym = -1, maxVar = -1; |
| 62 | mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>(exprsList: {{expr.getAffineExpr()}}, |
| 63 | maxDim&: maxVar, maxSym); |
| 64 | return maxSym < getSymRank() && maxVar < getRank(vk: expr.getAllowedVarKind()); |
| 65 | } |
| 66 | |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | // `VarSet` implementation. |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | |
| 71 | VarSet::VarSet(Ranks const &ranks) { |
| 72 | for (const auto vk : everyVarKind) |
| 73 | impl[vk] = llvm::SmallBitVector(ranks.getRank(vk)); |
| 74 | assert(getRanks() == ranks); |
| 75 | } |
| 76 | |
| 77 | bool VarSet::contains(Var var) const { |
| 78 | // NOTE: We make sure to return false on OOB, for consistency with |
| 79 | // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`. |
| 80 | // However beware that, as always with silencing OOB, this can hide |
| 81 | // bugs in client code. |
| 82 | const llvm::SmallBitVector &bits = impl[var.getKind()]; |
| 83 | const auto num = var.getNum(); |
| 84 | return num < bits.size() && bits[num]; |
| 85 | } |
| 86 | |
| 87 | void VarSet::add(Var var) { |
| 88 | // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB. |
| 89 | impl[var.getKind()][var.getNum()] = true; |
| 90 | } |
| 91 | |
| 92 | void VarSet::add(VarSet const &other) { |
| 93 | // NOTE: `SmallBitVector::operator&=` will implicitly resize |
| 94 | // the bitvector (unlike `BitVector::operator&=`), so we add an |
| 95 | // assertion against OOB for consistency with the implementation |
| 96 | // of `VarSet::add(Var)`. |
| 97 | for (const auto vk : everyVarKind) { |
| 98 | assert(impl[vk].size() >= other.impl[vk].size()); |
| 99 | impl[vk] &= other.impl[vk]; |
| 100 | } |
| 101 | } |
| 102 | |
| 103 | void VarSet::add(DimLvlExpr expr) { |
| 104 | if (!expr) |
| 105 | return; |
| 106 | switch (expr.getAffineKind()) { |
| 107 | case AffineExprKind::Constant: |
| 108 | return; |
| 109 | case AffineExprKind::SymbolId: |
| 110 | add(var: expr.castSymVar()); |
| 111 | return; |
| 112 | case AffineExprKind::DimId: |
| 113 | add(var: expr.castDimLvlVar()); |
| 114 | return; |
| 115 | case AffineExprKind::Add: |
| 116 | case AffineExprKind::Mul: |
| 117 | case AffineExprKind::Mod: |
| 118 | case AffineExprKind::FloorDiv: |
| 119 | case AffineExprKind::CeilDiv: { |
| 120 | const auto [lhs, op, rhs] = expr.unpackBinop(); |
| 121 | (void)op; |
| 122 | add(expr: lhs); |
| 123 | add(expr: rhs); |
| 124 | return; |
| 125 | } |
| 126 | } |
| 127 | llvm_unreachable("unknown AffineExprKind" ); |
| 128 | } |
| 129 | |
| 130 | //===----------------------------------------------------------------------===// |
| 131 | // `VarInfo` implementation. |
| 132 | //===----------------------------------------------------------------------===// |
| 133 | |
| 134 | void VarInfo::setNum(Var::Num n) { |
| 135 | assert(!hasNum() && "Var::Num is already set" ); |
| 136 | assert(Var::isWF_Num(n) && "Var::Num is too large" ); |
| 137 | num = n; |
| 138 | } |
| 139 | |
| 140 | //===----------------------------------------------------------------------===// |
| 141 | // `VarEnv` implementation. |
| 142 | //===----------------------------------------------------------------------===// |
| 143 | |
| 144 | /// Helper function for `assertUsageConsistency` to better handle SMLoc |
| 145 | /// mismatches. |
| 146 | LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc |
| 147 | minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { |
| 148 | const auto loc1 = dyn_cast<FileLineColLoc>(Val: parser.getEncodedSourceLoc(loc: sm1)); |
| 149 | assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`" ); |
| 150 | const auto loc2 = dyn_cast<FileLineColLoc>(Val: parser.getEncodedSourceLoc(loc: sm2)); |
| 151 | assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`" ); |
| 152 | if (loc1.getFilename() != loc2.getFilename()) |
| 153 | return SMLoc(); |
| 154 | const auto pair1 = std::make_pair(x: loc1.getLine(), y: loc1.getColumn()); |
| 155 | const auto pair2 = std::make_pair(x: loc2.getLine(), y: loc2.getColumn()); |
| 156 | return pair1 <= pair2 ? sm1 : sm2; |
| 157 | } |
| 158 | |
| 159 | bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) { |
| 160 | const auto &var = env.access(id); |
| 161 | return (var.getName() == name && var.getID() == id); |
| 162 | } |
| 163 | |
| 164 | bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, |
| 165 | VarKind vk) { |
| 166 | const auto &var = env.access(id); |
| 167 | return var.getKind() == vk; |
| 168 | } |
| 169 | |
| 170 | std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const { |
| 171 | const auto iter = ids.find(Key: name); |
| 172 | if (iter == ids.end()) |
| 173 | return std::nullopt; |
| 174 | const auto id = iter->second; |
| 175 | if (!isInternalConsistent(env: *this, id, name)) |
| 176 | return std::nullopt; |
| 177 | return id; |
| 178 | } |
| 179 | |
| 180 | std::optional<std::pair<VarInfo::ID, bool>> |
| 181 | VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) { |
| 182 | const auto &[iter, didInsert] = ids.try_emplace(Key: name, Args: nextID()); |
| 183 | const auto id = iter->second; |
| 184 | if (didInsert) { |
| 185 | vars.emplace_back(Args: id, Args&: name, Args&: loc, Args&: vk); |
| 186 | } else { |
| 187 | if (!isInternalConsistent(env: *this, id, name)) |
| 188 | return std::nullopt; |
| 189 | if (verifyUsage) |
| 190 | if (!isUsageConsistent(env: *this, id, loc, vk)) |
| 191 | return std::nullopt; |
| 192 | } |
| 193 | return std::make_pair(x: id, y: didInsert); |
| 194 | } |
| 195 | |
| 196 | std::optional<std::pair<VarInfo::ID, bool>> |
| 197 | VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, |
| 198 | VarKind vk) { |
| 199 | switch (creationPolicy) { |
| 200 | case Policy::MustNot: { |
| 201 | const auto oid = lookup(name); |
| 202 | if (!oid) |
| 203 | return std::nullopt; // Doesn't exist, but must not create. |
| 204 | if (!isUsageConsistent(env: *this, id: *oid, loc, vk)) |
| 205 | return std::nullopt; |
| 206 | return std::make_pair(x: *oid, y: false); |
| 207 | } |
| 208 | case Policy::May: |
| 209 | return create(name, loc, vk, /*verifyUsage=*/true); |
| 210 | case Policy::Must: { |
| 211 | const auto res = create(name, loc, vk, /*verifyUsage=*/false); |
| 212 | const auto didCreate = res->second; |
| 213 | if (!didCreate) |
| 214 | return std::nullopt; // Already exists, but must create. |
| 215 | return res; |
| 216 | } |
| 217 | } |
| 218 | llvm_unreachable("unknown Policy" ); |
| 219 | } |
| 220 | |
| 221 | Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); } |
| 222 | Var VarEnv::bindVar(VarInfo::ID id) { |
| 223 | auto &info = access(id); |
| 224 | const auto var = bindUnusedVar(vk: info.getKind()); |
| 225 | info.setNum(var.getNum()); |
| 226 | return var; |
| 227 | } |
| 228 | |
| 229 | InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const { |
| 230 | for (const auto &var : vars) |
| 231 | if (!var.hasNum()) |
| 232 | return parser.emitError(loc: var.getLoc(), |
| 233 | message: "Unbound variable: " + var.getName()); |
| 234 | return {}; |
| 235 | } |
| 236 | |
| 237 | //===----------------------------------------------------------------------===// |
| 238 | |