1//===- AsmParserState.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/AsmParser/AsmParserState.h"
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/Operation.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/IR/Types.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/ADT/StringMap.h"
20#include "llvm/ADT/iterator.h"
21#include "llvm/Support/ErrorHandling.h"
22#include <cassert>
23#include <cctype>
24#include <memory>
25#include <utility>
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// AsmParserState::Impl
31//===----------------------------------------------------------------------===//
32
33struct AsmParserState::Impl {
34 /// A map from a SymbolRefAttr to a range of uses.
35 using SymbolUseMap =
36 DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
37
38 struct PartialOpDef {
39 explicit PartialOpDef(const OperationName &opName) {
40 if (opName.hasTrait<OpTrait::SymbolTable>())
41 symbolTable = std::make_unique<SymbolUseMap>();
42 }
43
44 /// Return if this operation is a symbol table.
45 bool isSymbolTable() const { return symbolTable.get(); }
46
47 /// If this operation is a symbol table, the following contains symbol uses
48 /// within this operation.
49 std::unique_ptr<SymbolUseMap> symbolTable;
50 };
51
52 /// Resolve any symbol table uses in the IR.
53 void resolveSymbolUses();
54
55 /// A mapping from operations in the input source file to their parser state.
56 SmallVector<std::unique_ptr<OperationDefinition>> operations;
57 DenseMap<Operation *, unsigned> operationToIdx;
58
59 /// A mapping from blocks in the input source file to their parser state.
60 SmallVector<std::unique_ptr<BlockDefinition>> blocks;
61 DenseMap<Block *, unsigned> blocksToIdx;
62
63 /// A mapping from aliases in the input source file to their parser state.
64 SmallVector<std::unique_ptr<AttributeAliasDefinition>> attrAliases;
65 SmallVector<std::unique_ptr<TypeAliasDefinition>> typeAliases;
66 llvm::StringMap<unsigned> attrAliasToIdx;
67 llvm::StringMap<unsigned> typeAliasToIdx;
68
69 /// A set of value definitions that are placeholders for forward references.
70 /// This map should be empty if the parser finishes successfully.
71 DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
72
73 /// The symbol table operations within the IR.
74 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
75 symbolTableOperations;
76
77 /// A stack of partial operation definitions that have been started but not
78 /// yet finalized.
79 SmallVector<PartialOpDef> partialOperations;
80
81 /// A stack of symbol use scopes. This is used when collecting symbol table
82 /// uses during parsing.
83 SmallVector<SymbolUseMap *> symbolUseScopes;
84
85 /// A symbol table containing all of the symbol table operations in the IR.
86 SymbolTableCollection symbolTable;
87};
88
89void AsmParserState::Impl::resolveSymbolUses() {
90 SmallVector<Operation *> symbolOps;
91 for (auto &opAndUseMapIt : symbolTableOperations) {
92 for (auto &it : *opAndUseMapIt.second) {
93 symbolOps.clear();
94 if (failed(symbolTable.lookupSymbolIn(
95 opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
96 continue;
97
98 for (ArrayRef<SMRange> useRange : it.second) {
99 for (const auto &symIt : llvm::zip(t&: symbolOps, u&: useRange)) {
100 auto opIt = operationToIdx.find(Val: std::get<0>(t: symIt));
101 if (opIt != operationToIdx.end())
102 operations[opIt->second]->symbolUses.push_back(Elt: std::get<1>(t: symIt));
103 }
104 }
105 }
106 }
107}
108
109//===----------------------------------------------------------------------===//
110// AsmParserState
111//===----------------------------------------------------------------------===//
112
113AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
114AsmParserState::~AsmParserState() = default;
115AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
116 impl = std::move(other.impl);
117 return *this;
118}
119
120//===----------------------------------------------------------------------===//
121// Access State
122//===----------------------------------------------------------------------===//
123
124auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
125 return llvm::make_pointee_range(Range: llvm::ArrayRef(impl->blocks));
126}
127
128auto AsmParserState::getBlockDef(Block *block) const
129 -> const BlockDefinition * {
130 auto it = impl->blocksToIdx.find(Val: block);
131 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
132}
133
134auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
135 return llvm::make_pointee_range(Range: llvm::ArrayRef(impl->operations));
136}
137
138auto AsmParserState::getOpDef(Operation *op) const
139 -> const OperationDefinition * {
140 auto it = impl->operationToIdx.find(Val: op);
141 return it == impl->operationToIdx.end() ? nullptr
142 : &*impl->operations[it->second];
143}
144
145auto AsmParserState::getAttributeAliasDefs() const
146 -> iterator_range<AttributeDefIterator> {
147 return llvm::make_pointee_range(Range: ArrayRef(impl->attrAliases));
148}
149
150auto AsmParserState::getAttributeAliasDef(StringRef name) const
151 -> const AttributeAliasDefinition * {
152 auto it = impl->attrAliasToIdx.find(Key: name);
153 return it == impl->attrAliasToIdx.end() ? nullptr
154 : &*impl->attrAliases[it->second];
155}
156
157auto AsmParserState::getTypeAliasDefs() const
158 -> iterator_range<TypeDefIterator> {
159 return llvm::make_pointee_range(Range: ArrayRef(impl->typeAliases));
160}
161
162auto AsmParserState::getTypeAliasDef(StringRef name) const
163 -> const TypeAliasDefinition * {
164 auto it = impl->typeAliasToIdx.find(Key: name);
165 return it == impl->typeAliasToIdx.end() ? nullptr
166 : &*impl->typeAliases[it->second];
167}
168
/// Scan a string token whose contents (the characters after the opening `"`)
/// begin at `curPtr`.
///
/// Returns a pointer one past the first terminating character (`"`, `\n`,
/// `\v` or `\f`), or one past a backslash that starts an unrecognized escape.
/// If the buffer's terminating `\0` is reached first, the returned pointer
/// points at that `\0` rather than past the end of the buffer.
static const char *lexLocStringTok(const char *curPtr) {
  for (;;) {
    const char ch = *curPtr++;
    switch (ch) {
    case '\0':
      // End of buffer: step back so the result does not point past it.
      return curPtr - 1;
    case '"':
    case '\n':
    case '\v':
    case '\f':
      return curPtr;
    case '\\': {
      // Accept a few known single-character escapes and `\xx` hex escapes.
      const char next = *curPtr;
      if (next == '"' || next == '\\' || next == 'n' || next == 't') {
        ++curPtr;
      } else if (std::isxdigit(static_cast<unsigned char>(next)) &&
                 std::isxdigit(static_cast<unsigned char>(curPtr[1]))) {
        curPtr += 2;
      } else {
        return curPtr;
      }
      break;
    }
    default:
      break;
    }
  }
}
194
195SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
196 if (!loc.isValid())
197 return SMRange();
198 const char *curPtr = loc.getPointer();
199
200 // Check if this is a string token.
201 if (*curPtr == '"') {
202 curPtr = lexLocStringTok(curPtr: curPtr + 1);
203
204 // Otherwise, default to handling an identifier.
205 } else {
206 // Return if the given character is a valid identifier character.
207 auto isIdentifierChar = [](char c) {
208 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
209 };
210
211 while (*curPtr && isIdentifierChar(*(++curPtr)))
212 continue;
213 }
214
215 return SMRange(loc, SMLoc::getFromPointer(Ptr: curPtr));
216}
217
218//===----------------------------------------------------------------------===//
219// Populate State
220//===----------------------------------------------------------------------===//
221
222void AsmParserState::initialize(Operation *topLevelOp) {
223 startOperationDefinition(opName: topLevelOp->getName());
224
225 // If the top-level operation is a symbol table, push a new symbol scope.
226 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
227 if (partialOpDef.isSymbolTable())
228 impl->symbolUseScopes.push_back(Elt: partialOpDef.symbolTable.get());
229}
230
231void AsmParserState::finalize(Operation *topLevelOp) {
232 assert(!impl->partialOperations.empty() &&
233 "expected valid partial operation definition");
234 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
235
236 // If this operation is a symbol table, resolve any symbol uses.
237 if (partialOpDef.isSymbolTable()) {
238 impl->symbolTableOperations.emplace_back(
239 Args&: topLevelOp, Args: std::move(partialOpDef.symbolTable));
240 }
241 impl->resolveSymbolUses();
242}
243
244void AsmParserState::startOperationDefinition(const OperationName &opName) {
245 impl->partialOperations.emplace_back(Args: opName);
246}
247
248void AsmParserState::finalizeOperationDefinition(
249 Operation *op, SMRange nameLoc, SMLoc endLoc,
250 ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
251 assert(!impl->partialOperations.empty() &&
252 "expected valid partial operation definition");
253 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
254
255 // Build the full operation definition.
256 std::unique_ptr<OperationDefinition> def =
257 std::make_unique<OperationDefinition>(args&: op, args&: nameLoc, args&: endLoc);
258 for (auto &resultGroup : resultGroups)
259 def->resultGroups.emplace_back(Args: resultGroup.first,
260 Args: convertIdLocToRange(loc: resultGroup.second));
261 impl->operationToIdx.try_emplace(Key: op, Args: impl->operations.size());
262 impl->operations.emplace_back(Args: std::move(def));
263
264 // If this operation is a symbol table, resolve any symbol uses.
265 if (partialOpDef.isSymbolTable()) {
266 impl->symbolTableOperations.emplace_back(
267 Args&: op, Args: std::move(partialOpDef.symbolTable));
268 }
269}
270
271void AsmParserState::startRegionDefinition() {
272 assert(!impl->partialOperations.empty() &&
273 "expected valid partial operation definition");
274
275 // If the parent operation of this region is a symbol table, we also push a
276 // new symbol scope.
277 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
278 if (partialOpDef.isSymbolTable())
279 impl->symbolUseScopes.push_back(Elt: partialOpDef.symbolTable.get());
280}
281
282void AsmParserState::finalizeRegionDefinition() {
283 assert(!impl->partialOperations.empty() &&
284 "expected valid partial operation definition");
285
286 // If the parent operation of this region is a symbol table, pop the symbol
287 // scope for this region.
288 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
289 if (partialOpDef.isSymbolTable())
290 impl->symbolUseScopes.pop_back();
291}
292
293void AsmParserState::addDefinition(Block *block, SMLoc location) {
294 auto [it, inserted] =
295 impl->blocksToIdx.try_emplace(Key: block, Args: impl->blocks.size());
296 if (inserted) {
297 impl->blocks.emplace_back(Args: std::make_unique<BlockDefinition>(
298 args&: block, args: convertIdLocToRange(loc: location)));
299 return;
300 }
301
302 // If an entry already exists, this was a forward declaration that now has a
303 // proper definition.
304 impl->blocks[it->second]->definition.loc = convertIdLocToRange(loc: location);
305}
306
307void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
308 auto it = impl->blocksToIdx.find(Val: blockArg.getOwner());
309 assert(it != impl->blocksToIdx.end() &&
310 "expected owner block to have an entry");
311 BlockDefinition &def = *impl->blocks[it->second];
312 unsigned argIdx = blockArg.getArgNumber();
313
314 if (def.arguments.size() <= argIdx)
315 def.arguments.resize(N: argIdx + 1);
316 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(loc: location));
317}
318
319void AsmParserState::addAttrAliasDefinition(StringRef name, SMRange location,
320 Attribute value) {
321 auto [it, inserted] =
322 impl->attrAliasToIdx.try_emplace(Key: name, Args: impl->attrAliases.size());
323 // Location aliases may be referenced before they are defined.
324 if (inserted) {
325 impl->attrAliases.push_back(
326 Elt: std::make_unique<AttributeAliasDefinition>(args&: name, args&: location, args&: value));
327 } else {
328 AttributeAliasDefinition &attr = *impl->attrAliases[it->second];
329 attr.definition.loc = location;
330 attr.value = value;
331 }
332}
333
334void AsmParserState::addTypeAliasDefinition(StringRef name, SMRange location,
335 Type value) {
336 [[maybe_unused]] auto [it, inserted] =
337 impl->typeAliasToIdx.try_emplace(Key: name, Args: impl->typeAliases.size());
338 assert(inserted && "unexpected attribute alias redefinition");
339 impl->typeAliases.push_back(
340 Elt: std::make_unique<TypeAliasDefinition>(args&: name, args&: location, args&: value));
341}
342
343void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
344 // Handle the case where the value is an operation result.
345 if (OpResult result = dyn_cast<OpResult>(Val&: value)) {
346 // Check to see if a definition for the parent operation has been recorded.
347 // If one hasn't, we treat the provided value as a placeholder value that
348 // will be refined further later.
349 Operation *parentOp = result.getOwner();
350 auto existingIt = impl->operationToIdx.find(Val: parentOp);
351 if (existingIt == impl->operationToIdx.end()) {
352 impl->placeholderValueUses[value].append(in_start: locations.begin(),
353 in_end: locations.end());
354 return;
355 }
356
357 // If a definition does exist, locate the value's result group and add the
358 // use. The result groups are ordered by increasing start index, so we just
359 // need to find the last group that has a smaller/equal start index.
360 unsigned resultNo = result.getResultNumber();
361 OperationDefinition &def = *impl->operations[existingIt->second];
362 for (auto &resultGroup : llvm::reverse(C&: def.resultGroups)) {
363 if (resultNo >= resultGroup.startIndex) {
364 for (SMLoc loc : locations)
365 resultGroup.definition.uses.push_back(Elt: convertIdLocToRange(loc));
366 return;
367 }
368 }
369 llvm_unreachable("expected valid result group for value use");
370 }
371
372 // Otherwise, this is a block argument.
373 BlockArgument arg = cast<BlockArgument>(Val&: value);
374 auto existingIt = impl->blocksToIdx.find(Val: arg.getOwner());
375 assert(existingIt != impl->blocksToIdx.end() &&
376 "expected valid block definition for block argument");
377 BlockDefinition &blockDef = *impl->blocks[existingIt->second];
378 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
379 for (SMLoc loc : locations)
380 argDef.uses.emplace_back(Args: convertIdLocToRange(loc));
381}
382
383void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
384 auto [it, inserted] =
385 impl->blocksToIdx.try_emplace(Key: block, Args: impl->blocks.size());
386 if (inserted)
387 impl->blocks.emplace_back(Args: std::make_unique<BlockDefinition>(args&: block));
388
389 BlockDefinition &def = *impl->blocks[it->second];
390 for (SMLoc loc : locations)
391 def.definition.uses.push_back(Elt: convertIdLocToRange(loc));
392}
393
394void AsmParserState::addUses(SymbolRefAttr refAttr,
395 ArrayRef<SMRange> locations) {
396 // Ignore this symbol if no scopes are active.
397 if (impl->symbolUseScopes.empty())
398 return;
399
400 assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
401 "expected the same number of references as provided locations");
402 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
403 locations.end());
404}
405
406void AsmParserState::addAttrAliasUses(StringRef name, SMRange location) {
407 auto it = impl->attrAliasToIdx.find(Key: name);
408 // Location aliases may be referenced before they are defined.
409 if (it == impl->attrAliasToIdx.end()) {
410 it = impl->attrAliasToIdx.try_emplace(Key: name, Args: impl->attrAliases.size()).first;
411 impl->attrAliases.push_back(
412 Elt: std::make_unique<AttributeAliasDefinition>(args&: name));
413 }
414 AttributeAliasDefinition &def = *impl->attrAliases[it->second];
415 def.definition.uses.push_back(Elt: location);
416}
417
418void AsmParserState::addTypeAliasUses(StringRef name, SMRange location) {
419 auto it = impl->typeAliasToIdx.find(Key: name);
420 // Location aliases may be referenced before they are defined.
421 assert(it != impl->typeAliasToIdx.end() &&
422 "expected valid type alias definition");
423 TypeAliasDefinition &def = *impl->typeAliases[it->second];
424 def.definition.uses.push_back(Elt: location);
425}
426
427void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
428 auto it = impl->placeholderValueUses.find(Val: oldValue);
429 assert(it != impl->placeholderValueUses.end() &&
430 "expected `oldValue` to be a placeholder");
431 addUses(value: newValue, locations: it->second);
432 impl->placeholderValueUses.erase(Val: oldValue);
433}
434

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/AsmParser/AsmParserState.cpp