1//===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "DimLvlMapParser.h"

#include <string>

using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
14
// Early-returns `failure()` from the enclosing function when `failed(RES)`
// holds. `RES` may be anything accepted by `failed()`; in this file that
// includes `ParseResult` and `FailureOr<T>` values.
// NOTE: the expansion is a braced `if` statement. Call sites in this file
// write it without a trailing semicolon.
#define FAILURE_IF_FAILED(RES) \
  if (failed(RES)) { \
    return failure(); \
  }
19
20/// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
21/// its `RES` parameter.
22static inline bool didntSucceed(OptionalParseResult res) {
23 return !res.has_value() || failed(result: *res);
24}
25
// Early-returns `failure()` from the enclosing function when the
// `OptionalParseResult` `RES` is empty or holds a failure. `RES` is
// evaluated once, via `didntSucceed`.
// NOTE: like `FAILURE_IF_FAILED`, this expands to a braced `if` statement and
// is used without a trailing semicolon.
#define FAILURE_IF_NULLOPT_OR_FAILED(RES) \
  if (didntSucceed(RES)) { \
    return failure(); \
  }
30
// When `COND` holds, emits the diagnostic `MSG` at `loc` and returns the
// resulting `InFlightDiagnostic` from the enclosing function. The enclosing
// function's return type must therefore be constructible from that
// diagnostic; in this file that covers `ParseResult`, `OptionalParseResult`
// and `FailureOr<T>`.
// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
#define ERROR_IF(COND, MSG) \
  if (COND) { \
    return parser.emitError(loc, MSG); \
  }
36
37//===----------------------------------------------------------------------===//
38// `DimLvlMapParser` implementation for variable parsing.
39//===----------------------------------------------------------------------===//
40
41// Our variation on `AffineParser::{parseBareIdExpr,parseIdentifierDefinition}`
42OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
43 Policy creationPolicy,
44 VarInfo::ID &varID,
45 bool &didCreate) {
46 // Save the current location so that we can have error messages point to
47 // the right place.
48 const auto loc = parser.getCurrentLocation();
49 StringRef name;
50 if (failed(result: parser.parseOptionalKeyword(keyword: &name))) {
51 ERROR_IF(!isOptional, "expected bare identifier")
52 return std::nullopt;
53 }
54
55 if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
56 varID = res->first;
57 didCreate = res->second;
58 return success();
59 }
60
61 switch (creationPolicy) {
62 case Policy::MustNot:
63 return parser.emitError(loc, message: "use of undeclared identifier '" + name + "'");
64 case Policy::May:
65 llvm_unreachable("got nullopt for Policy::May");
66 case Policy::Must:
67 return parser.emitError(loc, message: "redefinition of identifier '" + name + "'");
68 }
69 llvm_unreachable("unknown Policy");
70}
71
72FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
73 bool requireKnown) {
74 VarInfo::ID id;
75 bool didCreate;
76 const bool isOptional = false;
77 const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
78 const auto res = parseVar(vk, isOptional, creationPolicy, varID&: id, didCreate);
79 FAILURE_IF_NULLOPT_OR_FAILED(res)
80 assert(requireKnown ? !didCreate : true);
81 return id;
82}
83
84FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
85 bool requireKnown) {
86 const auto loc = parser.getCurrentLocation();
87 VarInfo::ID id;
88 bool didCreate;
89 const bool isOptional = false;
90 const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
91 const auto res = parseVar(vk, isOptional, creationPolicy, varID&: id, didCreate);
92 FAILURE_IF_NULLOPT_OR_FAILED(res)
93 assert(requireKnown ? !didCreate : didCreate);
94 bindVar(loc, id);
95 return id;
96}
97
98FailureOr<std::pair<Var, bool>>
99DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
100 const auto loc = parser.getCurrentLocation();
101 VarInfo::ID id;
102 bool didCreate;
103 const bool isOptional = true;
104 const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
105 const auto res = parseVar(vk, isOptional, creationPolicy, varID&: id, didCreate);
106 if (res.has_value()) {
107 FAILURE_IF_FAILED(*res)
108 assert(didCreate);
109 return std::make_pair(x: bindVar(loc, id), y: true);
110 }
111 assert(!didCreate);
112 return std::make_pair(x: env.bindUnusedVar(vk), y: false);
113}
114
115Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
116 MLIRContext *context = parser.getContext();
117 const auto var = env.bindVar(id);
118 const auto &info = std::as_const(t&: env).access(id);
119 const auto name = info.getName();
120 const auto num = *info.getNum();
121 switch (info.getKind()) {
122 case VarKind::Symbol: {
123 const auto affine = getAffineSymbolExpr(position: num, context);
124 dimsAndSymbols.emplace_back(Args: name, Args: affine);
125 lvlsAndSymbols.emplace_back(Args: name, Args: affine);
126 return var;
127 }
128 case VarKind::Dimension:
129 dimsAndSymbols.emplace_back(Args: name, Args: getAffineDimExpr(position: num, context));
130 return var;
131 case VarKind::Level:
132 lvlsAndSymbols.emplace_back(Args: name, Args: getAffineDimExpr(position: num, context));
133 return var;
134 }
135 llvm_unreachable("unknown VarKind");
136}
137
138//===----------------------------------------------------------------------===//
139// `DimLvlMapParser` implementation for `DimLvlMap` per se.
140//===----------------------------------------------------------------------===//
141
142FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
143 FAILURE_IF_FAILED(parseSymbolBindingList())
144 FAILURE_IF_FAILED(parseLvlVarBindingList())
145 FAILURE_IF_FAILED(parseDimSpecList())
146 FAILURE_IF_FAILED(parser.parseArrow())
147 FAILURE_IF_FAILED(parseLvlSpecList())
148 InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
149 if (failed(result: ifd))
150 return ifd;
151 return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
152}
153
154ParseResult DimLvlMapParser::parseSymbolBindingList() {
155 return parser.parseCommaSeparatedList(
156 delimiter: OpAsmParser::Delimiter::OptionalSquare,
157 parseElementFn: [this]() { return ParseResult(parseVarBinding(vk: VarKind::Symbol)); },
158 contextMessage: " in symbol binding list");
159}
160
161ParseResult DimLvlMapParser::parseLvlVarBindingList() {
162 return parser.parseCommaSeparatedList(
163 delimiter: OpAsmParser::Delimiter::OptionalBraces,
164 parseElementFn: [this]() { return ParseResult(parseVarBinding(vk: VarKind::Level)); },
165 contextMessage: " in level declaration list");
166}
167
168//===----------------------------------------------------------------------===//
169// `DimLvlMapParser` implementation for `DimSpec`.
170//===----------------------------------------------------------------------===//
171
172ParseResult DimLvlMapParser::parseDimSpecList() {
173 return parser.parseCommaSeparatedList(
174 delimiter: OpAsmParser::Delimiter::Paren,
175 parseElementFn: [this]() -> ParseResult { return parseDimSpec(); },
176 contextMessage: " in dimension-specifier list");
177}
178
179ParseResult DimLvlMapParser::parseDimSpec() {
180 // Parse the requisite dim-var binding.
181 const auto varID = parseVarBinding(vk: VarKind::Dimension);
182 FAILURE_IF_FAILED(varID)
183 const DimVar var = env.getVar(id: *varID).cast<DimVar>();
184
185 // Parse an optional dimension expression.
186 AffineExpr affine;
187 if (succeeded(result: parser.parseOptionalEqual())) {
188 // Parse the dim affine expr, with only any lvl-vars in scope.
189 FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
190 }
191 DimExpr expr{affine};
192
193 // Parse an optional slice.
194 SparseTensorDimSliceAttr slice;
195 if (succeeded(result: parser.parseOptionalColon())) {
196 const auto loc = parser.getCurrentLocation();
197 Attribute attr;
198 FAILURE_IF_FAILED(parser.parseAttribute(attr))
199 slice = llvm::dyn_cast<SparseTensorDimSliceAttr>(attr);
200 ERROR_IF(!slice, "expected SparseTensorDimSliceAttr")
201 }
202
203 dimSpecs.emplace_back(var, expr, slice);
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// `DimLvlMapParser` implementation for `LvlSpec`.
209//===----------------------------------------------------------------------===//
210
211ParseResult DimLvlMapParser::parseLvlSpecList() {
212 // This method currently only supports two syntaxes:
213 //
214 // (1) There are no forward-declarations, and no lvl-var bindings:
215 // (d0, d1) -> (d0 : dense, d1 : compressed)
216 // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
217 // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
218 // the level-rank is correct at the end of parsing.
219 //
220 // (2) There are forward-declarations, and every lvl-spec must have
221 // a lvl-var binding:
222 // {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
223 // However, this introduces duplicate information since the order of
224 // the lvl-vars in `parseLvlVarBindingList` must agree with their order
225 // in the list of lvl-specs. Therefore, `parseLvlSpec` will not call
226 // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
227 // and must also validate the consistency between the two lvl-var orders.
228 const auto declaredLvlRank = env.getRanks().getLvlRank();
229 const bool requireLvlVarBinding = declaredLvlRank != 0;
230 // Have `ERROR_IF` point to the start of the list.
231 const auto loc = parser.getCurrentLocation();
232 const auto res = parser.parseCommaSeparatedList(
233 delimiter: mlir::OpAsmParser::Delimiter::Paren,
234 parseElementFn: [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
235 contextMessage: " in level-specifier list");
236 FAILURE_IF_FAILED(res)
237 const auto specLvlRank = lvlSpecs.size();
238 ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
239 "Level-rank mismatch between forward-declarations and specifiers. "
240 "Declared " +
241 Twine(declaredLvlRank) + " level-variables; but got " +
242 Twine(specLvlRank) + " level-specifiers.")
243 return success();
244}
245
246static inline Twine nth(Var::Num n) {
247 switch (n) {
248 case 1:
249 return "1st";
250 case 2:
251 return "2nd";
252 default:
253 return Twine(n) + "th";
254 }
255}
256
257FailureOr<LvlVar>
258DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
259 // Nothing to parse, just bind an unnamed variable.
260 if (!requireLvlVarBinding)
261 return env.bindUnusedVar(vk: VarKind::Level).cast<LvlVar>();
262
263 const auto loc = parser.getCurrentLocation();
264 // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
265 // since the thing we're parsing is supposed to be a variable *binding*
266 // rather than a variable *use*. However, the call to `VarEnv::bindVar`
267 // (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
268 // already occured in `parseLvlVarBindingList`, and therefore we must
269 // use `parseVarUsage` here in order to operationally do the right thing.
270 const auto varID = parseVarUsage(vk: VarKind::Level, /*requireKnown=*/true);
271 FAILURE_IF_FAILED(varID)
272 const auto &info = std::as_const(t&: env).access(id: *varID);
273 const auto var = info.getVar().cast<LvlVar>();
274 const auto forwardNum = var.getNum();
275 const auto specNum = lvlSpecs.size();
276 ERROR_IF(forwardNum != specNum,
277 "Level-variable ordering mismatch. The variable '" + info.getName() +
278 "' was forward-declared as the " + nth(forwardNum) +
279 " level; but is bound by the " + nth(specNum) +
280 " specification.")
281 FAILURE_IF_FAILED(parser.parseEqual())
282 return var;
283}
284
285ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
286 // Parse the optional lvl-var binding. `requireLvlVarBinding`
287 // specifies whether that "optional" is actually Must or MustNot.
288 const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
289 FAILURE_IF_FAILED(varRes)
290 const LvlVar var = *varRes;
291
292 // Parse the lvl affine expr, with only the dim-vars in scope.
293 AffineExpr affine;
294 FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
295 LvlExpr expr{affine};
296
297 FAILURE_IF_FAILED(parser.parseColon())
298 const auto type = lvlTypeParser.parseLvlType(parser);
299 FAILURE_IF_FAILED(type)
300
301 lvlSpecs.emplace_back(Args: var, Args&: expr, Args: static_cast<LevelType>(*type));
302 return success();
303}
304
305//===----------------------------------------------------------------------===//
306

source code of mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp