1 | //===- MatchFinder.cpp - --------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains the method definitions for the `MatchFinder` class |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Query/Matcher/MatchFinder.h" |
14 | namespace mlir::query::matcher { |
15 | |
16 | MatchFinder::MatchResult::MatchResult(Operation *rootOp, |
17 | std::vector<Operation *> matchedOps) |
18 | : rootOp(rootOp), matchedOps(std::move(matchedOps)) {} |
19 | |
20 | std::vector<MatchFinder::MatchResult> |
21 | MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const { |
22 | std::vector<MatchResult> results; |
23 | llvm::SetVector<Operation *> tempStorage; |
24 | root->walk(callback: [&](Operation *subOp) { |
25 | if (matcher.match(op: subOp)) { |
26 | MatchResult match; |
27 | match.rootOp = subOp; |
28 | match.matchedOps.push_back(x: subOp); |
29 | results.push_back(x: std::move(match)); |
30 | } else if (matcher.match(op: subOp, matchedOps&: tempStorage)) { |
31 | results.emplace_back(args&: subOp, args: std::vector<Operation *>(tempStorage.begin(), |
32 | tempStorage.end())); |
33 | } |
34 | tempStorage.clear(); |
35 | }); |
36 | return results; |
37 | } |
38 | |
39 | void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs, |
40 | Operation *op) const { |
41 | auto fileLoc = cast<FileLineColLoc>(Val: op->getLoc()); |
42 | SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn( |
43 | BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn()); |
44 | llvm::SMDiagnostic diag = |
45 | qs.getSourceManager().GetMessage(Loc: smloc, Kind: llvm::SourceMgr::DK_Note, Msg: "" ); |
46 | diag.print(ProgName: "" , S&: os, ShowColors: true, ShowKindLabel: false, ShowLocation: true); |
47 | } |
48 | |
49 | void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs, |
50 | Operation *op, const std::string &binding) const { |
51 | auto fileLoc = cast<FileLineColLoc>(Val: op->getLoc()); |
52 | auto smloc = qs.getSourceManager().FindLocForLineAndColumn( |
53 | BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn()); |
54 | qs.getSourceManager().PrintMessage(OS&: os, Loc: smloc, Kind: llvm::SourceMgr::DK_Note, |
55 | Msg: "\"" + binding + "\" binds here" ); |
56 | } |
57 | |
58 | std::vector<Operation *> |
59 | MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const { |
60 | std::vector<Operation *> newVector; |
61 | for (auto &result : matches) { |
62 | newVector.insert(position: newVector.end(), first: result.matchedOps.begin(), |
63 | last: result.matchedOps.end()); |
64 | } |
65 | return newVector; |
66 | } |
67 | |
68 | } // namespace mlir::query::matcher |
69 | |