1 | //===- DimLvlMap.cpp ------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "DimLvlMap.h" |
10 | |
11 | using namespace mlir; |
12 | using namespace mlir::sparse_tensor; |
13 | using namespace mlir::sparse_tensor::ir_detail; |
14 | |
15 | //===----------------------------------------------------------------------===// |
16 | // `DimLvlExpr` implementation. |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | SymVar DimLvlExpr::castSymVar() const { |
20 | return SymVar(llvm::cast<AffineSymbolExpr>(Val: expr)); |
21 | } |
22 | |
23 | std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const { |
24 | if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(Val: expr)) |
25 | return SymVar(s); |
26 | return std::nullopt; |
27 | } |
28 | |
29 | Var DimLvlExpr::castDimLvlVar() const { |
30 | return Var(getAllowedVarKind(), llvm::cast<AffineDimExpr>(Val: expr)); |
31 | } |
32 | |
33 | std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const { |
34 | if (const auto x = dyn_cast_or_null<AffineDimExpr>(Val: expr)) |
35 | return Var(getAllowedVarKind(), x); |
36 | return std::nullopt; |
37 | } |
38 | |
39 | std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> |
40 | DimLvlExpr::unpackBinop() const { |
41 | const auto ak = getAffineKind(); |
42 | const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(Val: expr); |
43 | const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr); |
44 | const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr); |
45 | return {lhs, ak, rhs}; |
46 | } |
47 | |
48 | //===----------------------------------------------------------------------===// |
49 | // `DimSpec` implementation. |
50 | //===----------------------------------------------------------------------===// |
51 | |
52 | DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice) |
53 | : var(var), expr(expr), slice(slice) {} |
54 | |
55 | bool DimSpec::isValid(Ranks const &ranks) const { |
56 | // Nothing in `slice` needs additional validation. |
57 | // We explicitly consider null-expr to be vacuously valid. |
58 | return ranks.isValid(var) && (!expr || ranks.isValid(expr)); |
59 | } |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // `LvlSpec` implementation. |
63 | //===----------------------------------------------------------------------===// |
64 | |
65 | LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, LevelType type) |
66 | : var(var), expr(expr), type(type) { |
67 | assert(expr); |
68 | assert(isValidLT(type) && !isUndefLT(type)); |
69 | } |
70 | |
71 | bool LvlSpec::isValid(Ranks const &ranks) const { |
72 | // Nothing in `type` needs additional validation. |
73 | return ranks.isValid(var) && ranks.isValid(expr); |
74 | } |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // `DimLvlMap` implementation. |
78 | //===----------------------------------------------------------------------===// |
79 | |
80 | DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs, |
81 | ArrayRef<LvlSpec> lvlSpecs) |
82 | : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs), |
83 | mustPrintLvlVars(false) { |
84 | // First, check integrity of the variable-binding structure. |
85 | // NOTE: This establishes the invariant that calls to `VarSet::add` |
86 | // below cannot cause OOB errors. |
87 | assert(isWF()); |
88 | |
89 | VarSet usedVars(getRanks()); |
90 | for (const auto &dimSpec : dimSpecs) |
91 | if (!dimSpec.canElideExpr()) |
92 | usedVars.add(expr: dimSpec.getExpr()); |
93 | for (auto &lvlSpec : this->lvlSpecs) { |
94 | // Is this LvlVar used in any overt expression? |
95 | const bool isUsed = usedVars.contains(var: lvlSpec.getBoundVar()); |
96 | // This LvlVar can be elided iff it isn't overtly used. |
97 | lvlSpec.setElideVar(!isUsed); |
98 | // If any LvlVar cannot be elided, then must forward-declare all LvlVars. |
99 | mustPrintLvlVars = mustPrintLvlVars || isUsed; |
100 | } |
101 | } |
102 | |
103 | bool DimLvlMap::isWF() const { |
104 | const auto ranks = getRanks(); |
105 | unsigned dimNum = 0; |
106 | for (const auto &dimSpec : dimSpecs) |
107 | if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks)) |
108 | return false; |
109 | assert(dimNum == ranks.getDimRank()); |
110 | unsigned lvlNum = 0; |
111 | for (const auto &lvlSpec : lvlSpecs) |
112 | if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks)) |
113 | return false; |
114 | assert(lvlNum == ranks.getLvlRank()); |
115 | return true; |
116 | } |
117 | |
118 | AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const { |
119 | SmallVector<AffineExpr> lvlAffines; |
120 | lvlAffines.reserve(N: getLvlRank()); |
121 | for (const auto &lvlSpec : lvlSpecs) |
122 | lvlAffines.push_back(Elt: lvlSpec.getExpr().getAffineExpr()); |
123 | auto map = AffineMap::get(dimCount: getDimRank(), symbolCount: getSymRank(), results: lvlAffines, context); |
124 | return map; |
125 | } |
126 | |
127 | AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const { |
128 | SmallVector<AffineExpr> dimAffines; |
129 | dimAffines.reserve(N: getDimRank()); |
130 | for (const auto &dimSpec : dimSpecs) { |
131 | auto expr = dimSpec.getExpr().getAffineExpr(); |
132 | if (expr) { |
133 | dimAffines.push_back(Elt: expr); |
134 | } |
135 | } |
136 | auto map = AffineMap::get(dimCount: getLvlRank(), symbolCount: getSymRank(), results: dimAffines, context); |
137 | // If no lvlToDim map was passed in, returns a null AffineMap and infers it |
138 | // in SparseTensorEncodingAttr::parse. |
139 | if (dimAffines.empty()) |
140 | return AffineMap(); |
141 | return map; |
142 | } |
143 | |
144 | //===----------------------------------------------------------------------===// |
145 | |