| 1 | //===- MatchFinder.cpp - --------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains the method definitions for the `MatchFinder` class. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Query/Matcher/MatchFinder.h" |
| 14 | namespace mlir::query::matcher { |
| 15 | |
| 16 | MatchFinder::MatchResult::MatchResult(Operation *rootOp, |
| 17 | std::vector<Operation *> matchedOps) |
| 18 | : rootOp(rootOp), matchedOps(std::move(matchedOps)) {} |
| 19 | |
| 20 | std::vector<MatchFinder::MatchResult> |
| 21 | MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const { |
| 22 | std::vector<MatchResult> results; |
| 23 | llvm::SetVector<Operation *> tempStorage; |
| 24 | root->walk(callback: [&](Operation *subOp) { |
| 25 | if (matcher.match(op: subOp)) { |
| 26 | MatchResult match; |
| 27 | match.rootOp = subOp; |
| 28 | match.matchedOps.push_back(x: subOp); |
| 29 | results.push_back(x: std::move(match)); |
| 30 | } else if (matcher.match(op: subOp, matchedOps&: tempStorage)) { |
| 31 | results.emplace_back(args&: subOp, args: std::vector<Operation *>(tempStorage.begin(), |
| 32 | tempStorage.end())); |
| 33 | } |
| 34 | tempStorage.clear(); |
| 35 | }); |
| 36 | return results; |
| 37 | } |
| 38 | |
| 39 | void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs, |
| 40 | Operation *op) const { |
| 41 | auto fileLoc = cast<FileLineColLoc>(Val: op->getLoc()); |
| 42 | SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn( |
| 43 | BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn()); |
| 44 | llvm::SMDiagnostic diag = |
| 45 | qs.getSourceManager().GetMessage(Loc: smloc, Kind: llvm::SourceMgr::DK_Note, Msg: "" ); |
| 46 | diag.print(ProgName: "" , S&: os, ShowColors: true, ShowKindLabel: false, ShowLocation: true); |
| 47 | } |
| 48 | |
| 49 | void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs, |
| 50 | Operation *op, const std::string &binding) const { |
| 51 | auto fileLoc = cast<FileLineColLoc>(Val: op->getLoc()); |
| 52 | auto smloc = qs.getSourceManager().FindLocForLineAndColumn( |
| 53 | BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn()); |
| 54 | qs.getSourceManager().PrintMessage(OS&: os, Loc: smloc, Kind: llvm::SourceMgr::DK_Note, |
| 55 | Msg: "\"" + binding + "\" binds here" ); |
| 56 | } |
| 57 | |
| 58 | std::vector<Operation *> |
| 59 | MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const { |
| 60 | std::vector<Operation *> newVector; |
| 61 | for (auto &result : matches) { |
| 62 | newVector.insert(position: newVector.end(), first: result.matchedOps.begin(), |
| 63 | last: result.matchedOps.end()); |
| 64 | } |
| 65 | return newVector; |
| 66 | } |
| 67 | |
| 68 | } // namespace mlir::query::matcher |
| 69 | |