//===- AsmParserState.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/AsmParser/AsmParserState.h"
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/Operation.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/IR/Types.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "mlir/Support/LogicalResult.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/StringMap.h"
21#include "llvm/ADT/iterator.h"
22#include "llvm/Support/ErrorHandling.h"
23#include <cassert>
24#include <cctype>
25#include <memory>
26#include <utility>
27
28using namespace mlir;
29
//===----------------------------------------------------------------------===//
// AsmParserState::Impl
//===----------------------------------------------------------------------===//
33
34struct AsmParserState::Impl {
35 /// A map from a SymbolRefAttr to a range of uses.
36 using SymbolUseMap =
37 DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
38
39 struct PartialOpDef {
40 explicit PartialOpDef(const OperationName &opName) {
41 if (opName.hasTrait<OpTrait::SymbolTable>())
42 symbolTable = std::make_unique<SymbolUseMap>();
43 }
44
45 /// Return if this operation is a symbol table.
46 bool isSymbolTable() const { return symbolTable.get(); }
47
48 /// If this operation is a symbol table, the following contains symbol uses
49 /// within this operation.
50 std::unique_ptr<SymbolUseMap> symbolTable;
51 };
52
53 /// Resolve any symbol table uses in the IR.
54 void resolveSymbolUses();
55
56 /// A mapping from operations in the input source file to their parser state.
57 SmallVector<std::unique_ptr<OperationDefinition>> operations;
58 DenseMap<Operation *, unsigned> operationToIdx;
59
60 /// A mapping from blocks in the input source file to their parser state.
61 SmallVector<std::unique_ptr<BlockDefinition>> blocks;
62 DenseMap<Block *, unsigned> blocksToIdx;
63
64 /// A mapping from aliases in the input source file to their parser state.
65 SmallVector<std::unique_ptr<AttributeAliasDefinition>> attrAliases;
66 SmallVector<std::unique_ptr<TypeAliasDefinition>> typeAliases;
67 llvm::StringMap<unsigned> attrAliasToIdx;
68 llvm::StringMap<unsigned> typeAliasToIdx;
69
70 /// A set of value definitions that are placeholders for forward references.
71 /// This map should be empty if the parser finishes successfully.
72 DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
73
74 /// The symbol table operations within the IR.
75 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
76 symbolTableOperations;
77
78 /// A stack of partial operation definitions that have been started but not
79 /// yet finalized.
80 SmallVector<PartialOpDef> partialOperations;
81
82 /// A stack of symbol use scopes. This is used when collecting symbol table
83 /// uses during parsing.
84 SmallVector<SymbolUseMap *> symbolUseScopes;
85
86 /// A symbol table containing all of the symbol table operations in the IR.
87 SymbolTableCollection symbolTable;
88};
89
90void AsmParserState::Impl::resolveSymbolUses() {
91 SmallVector<Operation *> symbolOps;
92 for (auto &opAndUseMapIt : symbolTableOperations) {
93 for (auto &it : *opAndUseMapIt.second) {
94 symbolOps.clear();
95 if (failed(symbolTable.lookupSymbolIn(
96 opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
97 continue;
98
99 for (ArrayRef<SMRange> useRange : it.second) {
100 for (const auto &symIt : llvm::zip(t&: symbolOps, u&: useRange)) {
101 auto opIt = operationToIdx.find(Val: std::get<0>(t: symIt));
102 if (opIt != operationToIdx.end())
103 operations[opIt->second]->symbolUses.push_back(Elt: std::get<1>(t: symIt));
104 }
105 }
106 }
107 }
108}
109
//===----------------------------------------------------------------------===//
// AsmParserState
//===----------------------------------------------------------------------===//
113
114AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
115AsmParserState::~AsmParserState() = default;
116AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
117 impl = std::move(other.impl);
118 return *this;
119}
120
//===----------------------------------------------------------------------===//
// Access State
123
124auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
125 return llvm::make_pointee_range(Range: llvm::ArrayRef(impl->blocks));
126}
127
128auto AsmParserState::getBlockDef(Block *block) const
129 -> const BlockDefinition * {
130 auto it = impl->blocksToIdx.find(Val: block);
131 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
132}
133
134auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
135 return llvm::make_pointee_range(Range: llvm::ArrayRef(impl->operations));
136}
137
138auto AsmParserState::getOpDef(Operation *op) const
139 -> const OperationDefinition * {
140 auto it = impl->operationToIdx.find(Val: op);
141 return it == impl->operationToIdx.end() ? nullptr
142 : &*impl->operations[it->second];
143}
144
145auto AsmParserState::getAttributeAliasDefs() const
146 -> iterator_range<AttributeDefIterator> {
147 return llvm::make_pointee_range(Range: ArrayRef(impl->attrAliases));
148}
149
150auto AsmParserState::getAttributeAliasDef(StringRef name) const
151 -> const AttributeAliasDefinition * {
152 auto it = impl->attrAliasToIdx.find(Key: name);
153 return it == impl->attrAliasToIdx.end() ? nullptr
154 : &*impl->attrAliases[it->second];
155}
156
157auto AsmParserState::getTypeAliasDefs() const
158 -> iterator_range<TypeDefIterator> {
159 return llvm::make_pointee_range(Range: ArrayRef(impl->typeAliases));
160}
161
162auto AsmParserState::getTypeAliasDef(StringRef name) const
163 -> const TypeAliasDefinition * {
164 auto it = impl->typeAliasToIdx.find(Key: name);
165 return it == impl->typeAliasToIdx.end() ? nullptr
166 : &*impl->typeAliases[it->second];
167}
168
/// Scan the contents of a string token starting at `curPtr` (the character
/// after the opening quote) and return where the token ends:
///  - one past a terminal character (`"`, `\n`, `\v` or `\f`);
///  - one past the backslash of an unrecognized escape sequence;
///  - at the NUL terminator if the buffer ends first (never past it).
/// Recognized escapes are `\"`, `\\`, `\n`, `\t` and two hex digits.
static const char *lexLocStringTok(const char *curPtr) {
  for (;;) {
    const char c = *curPtr++;
    switch (c) {
    case '\0':
      // End of buffer: step back so the result stays on the terminator.
      return curPtr - 1;
    case '"':
    case '\n':
    case '\v':
    case '\f':
      return curPtr;
    case '\\': {
      const char next = *curPtr;
      if (next == '"' || next == '\\' || next == 'n' || next == 't') {
        ++curPtr;
      } else if (std::isxdigit(static_cast<unsigned char>(next)) &&
                 std::isxdigit(static_cast<unsigned char>(curPtr[1]))) {
        curPtr += 2;
      } else {
        return curPtr;
      }
      break;
    }
    default:
      break;
    }
  }
}
194
195SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
196 if (!loc.isValid())
197 return SMRange();
198 const char *curPtr = loc.getPointer();
199
200 // Check if this is a string token.
201 if (*curPtr == '"') {
202 curPtr = lexLocStringTok(curPtr: curPtr + 1);
203
204 // Otherwise, default to handling an identifier.
205 } else {
206 // Return if the given character is a valid identifier character.
207 auto isIdentifierChar = [](char c) {
208 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
209 };
210
211 while (*curPtr && isIdentifierChar(*(++curPtr)))
212 continue;
213 }
214
215 return SMRange(loc, SMLoc::getFromPointer(Ptr: curPtr));
216}
217
//===----------------------------------------------------------------------===//
// Populate State
220
221void AsmParserState::initialize(Operation *topLevelOp) {
222 startOperationDefinition(opName: topLevelOp->getName());
223
224 // If the top-level operation is a symbol table, push a new symbol scope.
225 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
226 if (partialOpDef.isSymbolTable())
227 impl->symbolUseScopes.push_back(Elt: partialOpDef.symbolTable.get());
228}
229
230void AsmParserState::finalize(Operation *topLevelOp) {
231 assert(!impl->partialOperations.empty() &&
232 "expected valid partial operation definition");
233 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
234
235 // If this operation is a symbol table, resolve any symbol uses.
236 if (partialOpDef.isSymbolTable()) {
237 impl->symbolTableOperations.emplace_back(
238 Args&: topLevelOp, Args: std::move(partialOpDef.symbolTable));
239 }
240 impl->resolveSymbolUses();
241}
242
243void AsmParserState::startOperationDefinition(const OperationName &opName) {
244 impl->partialOperations.emplace_back(Args: opName);
245}
246
247void AsmParserState::finalizeOperationDefinition(
248 Operation *op, SMRange nameLoc, SMLoc endLoc,
249 ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
250 assert(!impl->partialOperations.empty() &&
251 "expected valid partial operation definition");
252 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
253
254 // Build the full operation definition.
255 std::unique_ptr<OperationDefinition> def =
256 std::make_unique<OperationDefinition>(args&: op, args&: nameLoc, args&: endLoc);
257 for (auto &resultGroup : resultGroups)
258 def->resultGroups.emplace_back(Args: resultGroup.first,
259 Args: convertIdLocToRange(loc: resultGroup.second));
260 impl->operationToIdx.try_emplace(Key: op, Args: impl->operations.size());
261 impl->operations.emplace_back(Args: std::move(def));
262
263 // If this operation is a symbol table, resolve any symbol uses.
264 if (partialOpDef.isSymbolTable()) {
265 impl->symbolTableOperations.emplace_back(
266 Args&: op, Args: std::move(partialOpDef.symbolTable));
267 }
268}
269
270void AsmParserState::startRegionDefinition() {
271 assert(!impl->partialOperations.empty() &&
272 "expected valid partial operation definition");
273
274 // If the parent operation of this region is a symbol table, we also push a
275 // new symbol scope.
276 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
277 if (partialOpDef.isSymbolTable())
278 impl->symbolUseScopes.push_back(Elt: partialOpDef.symbolTable.get());
279}
280
281void AsmParserState::finalizeRegionDefinition() {
282 assert(!impl->partialOperations.empty() &&
283 "expected valid partial operation definition");
284
285 // If the parent operation of this region is a symbol table, pop the symbol
286 // scope for this region.
287 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
288 if (partialOpDef.isSymbolTable())
289 impl->symbolUseScopes.pop_back();
290}
291
292void AsmParserState::addDefinition(Block *block, SMLoc location) {
293 auto it = impl->blocksToIdx.find(Val: block);
294 if (it == impl->blocksToIdx.end()) {
295 impl->blocksToIdx.try_emplace(Key: block, Args: impl->blocks.size());
296 impl->blocks.emplace_back(Args: std::make_unique<BlockDefinition>(
297 args&: block, args: convertIdLocToRange(loc: location)));
298 return;
299 }
300
301 // If an entry already exists, this was a forward declaration that now has a
302 // proper definition.
303 impl->blocks[it->second]->definition.loc = convertIdLocToRange(loc: location);
304}
305
306void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
307 auto it = impl->blocksToIdx.find(Val: blockArg.getOwner());
308 assert(it != impl->blocksToIdx.end() &&
309 "expected owner block to have an entry");
310 BlockDefinition &def = *impl->blocks[it->second];
311 unsigned argIdx = blockArg.getArgNumber();
312
313 if (def.arguments.size() <= argIdx)
314 def.arguments.resize(N: argIdx + 1);
315 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(loc: location));
316}
317
318void AsmParserState::addAttrAliasDefinition(StringRef name, SMRange location,
319 Attribute value) {
320 auto [it, inserted] =
321 impl->attrAliasToIdx.try_emplace(Key: name, Args: impl->attrAliases.size());
322 // Location aliases may be referenced before they are defined.
323 if (inserted) {
324 impl->attrAliases.push_back(
325 Elt: std::make_unique<AttributeAliasDefinition>(args&: name, args&: location, args&: value));
326 } else {
327 AttributeAliasDefinition &attr = *impl->attrAliases[it->second];
328 attr.definition.loc = location;
329 attr.value = value;
330 }
331}
332
333void AsmParserState::addTypeAliasDefinition(StringRef name, SMRange location,
334 Type value) {
335 [[maybe_unused]] auto [it, inserted] =
336 impl->typeAliasToIdx.try_emplace(Key: name, Args: impl->typeAliases.size());
337 assert(inserted && "unexpected attribute alias redefinition");
338 impl->typeAliases.push_back(
339 Elt: std::make_unique<TypeAliasDefinition>(args&: name, args&: location, args&: value));
340}
341
342void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
343 // Handle the case where the value is an operation result.
344 if (OpResult result = dyn_cast<OpResult>(Val&: value)) {
345 // Check to see if a definition for the parent operation has been recorded.
346 // If one hasn't, we treat the provided value as a placeholder value that
347 // will be refined further later.
348 Operation *parentOp = result.getOwner();
349 auto existingIt = impl->operationToIdx.find(Val: parentOp);
350 if (existingIt == impl->operationToIdx.end()) {
351 impl->placeholderValueUses[value].append(in_start: locations.begin(),
352 in_end: locations.end());
353 return;
354 }
355
356 // If a definition does exist, locate the value's result group and add the
357 // use. The result groups are ordered by increasing start index, so we just
358 // need to find the last group that has a smaller/equal start index.
359 unsigned resultNo = result.getResultNumber();
360 OperationDefinition &def = *impl->operations[existingIt->second];
361 for (auto &resultGroup : llvm::reverse(C&: def.resultGroups)) {
362 if (resultNo >= resultGroup.startIndex) {
363 for (SMLoc loc : locations)
364 resultGroup.definition.uses.push_back(Elt: convertIdLocToRange(loc));
365 return;
366 }
367 }
368 llvm_unreachable("expected valid result group for value use");
369 }
370
371 // Otherwise, this is a block argument.
372 BlockArgument arg = cast<BlockArgument>(Val&: value);
373 auto existingIt = impl->blocksToIdx.find(Val: arg.getOwner());
374 assert(existingIt != impl->blocksToIdx.end() &&
375 "expected valid block definition for block argument");
376 BlockDefinition &blockDef = *impl->blocks[existingIt->second];
377 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
378 for (SMLoc loc : locations)
379 argDef.uses.emplace_back(Args: convertIdLocToRange(loc));
380}
381
382void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
383 auto it = impl->blocksToIdx.find(Val: block);
384 if (it == impl->blocksToIdx.end()) {
385 it = impl->blocksToIdx.try_emplace(Key: block, Args: impl->blocks.size()).first;
386 impl->blocks.emplace_back(Args: std::make_unique<BlockDefinition>(args&: block));
387 }
388
389 BlockDefinition &def = *impl->blocks[it->second];
390 for (SMLoc loc : locations)
391 def.definition.uses.push_back(Elt: convertIdLocToRange(loc));
392}
393
394void AsmParserState::addUses(SymbolRefAttr refAttr,
395 ArrayRef<SMRange> locations) {
396 // Ignore this symbol if no scopes are active.
397 if (impl->symbolUseScopes.empty())
398 return;
399
400 assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
401 "expected the same number of references as provided locations");
402 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
403 locations.end());
404}
405
406void AsmParserState::addAttrAliasUses(StringRef name, SMRange location) {
407 auto it = impl->attrAliasToIdx.find(Key: name);
408 // Location aliases may be referenced before they are defined.
409 if (it == impl->attrAliasToIdx.end()) {
410 it = impl->attrAliasToIdx.try_emplace(Key: name, Args: impl->attrAliases.size()).first;
411 impl->attrAliases.push_back(
412 Elt: std::make_unique<AttributeAliasDefinition>(args&: name));
413 }
414 AttributeAliasDefinition &def = *impl->attrAliases[it->second];
415 def.definition.uses.push_back(Elt: location);
416}
417
418void AsmParserState::addTypeAliasUses(StringRef name, SMRange location) {
419 auto it = impl->typeAliasToIdx.find(Key: name);
420 // Location aliases may be referenced before they are defined.
421 assert(it != impl->typeAliasToIdx.end() &&
422 "expected valid type alias definition");
423 TypeAliasDefinition &def = *impl->typeAliases[it->second];
424 def.definition.uses.push_back(Elt: location);
425}
426
427void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
428 auto it = impl->placeholderValueUses.find(Val: oldValue);
429 assert(it != impl->placeholderValueUses.end() &&
430 "expected `oldValue` to be a placeholder");
431 addUses(value: newValue, locations: it->second);
432 impl->placeholderValueUses.erase(Val: oldValue);
433}
434

// Source: mlir/lib/AsmParser/AsmParserState.cpp