//===- MatchFinder.cpp - --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the method definitions for the `MatchFinder` class.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Query/Matcher/MatchFinder.h"
14namespace mlir::query::matcher {
15
16MatchFinder::MatchResult::MatchResult(Operation *rootOp,
17 std::vector<Operation *> matchedOps)
18 : rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
19
20std::vector<MatchFinder::MatchResult>
21MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
22 std::vector<MatchResult> results;
23 llvm::SetVector<Operation *> tempStorage;
24 root->walk(callback: [&](Operation *subOp) {
25 if (matcher.match(op: subOp)) {
26 MatchResult match;
27 match.rootOp = subOp;
28 match.matchedOps.push_back(x: subOp);
29 results.push_back(x: std::move(match));
30 } else if (matcher.match(op: subOp, matchedOps&: tempStorage)) {
31 results.emplace_back(args&: subOp, args: std::vector<Operation *>(tempStorage.begin(),
32 tempStorage.end()));
33 }
34 tempStorage.clear();
35 });
36 return results;
37}
38
39void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
40 Operation *op) const {
41 auto fileLoc = cast<FileLineColLoc>(Val: op->getLoc());
42 SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
43 BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn());
44 llvm::SMDiagnostic diag =
45 qs.getSourceManager().GetMessage(Loc: smloc, Kind: llvm::SourceMgr::DK_Note, Msg: "");
46 diag.print(ProgName: "", S&: os, ShowColors: true, ShowKindLabel: false, ShowLocation: true);
47}
48
49void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
50 Operation *op, const std::string &binding) const {
51 auto fileLoc = cast<FileLineColLoc>(Val: op->getLoc());
52 auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
53 BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn());
54 qs.getSourceManager().PrintMessage(OS&: os, Loc: smloc, Kind: llvm::SourceMgr::DK_Note,
55 Msg: "\"" + binding + "\" binds here");
56}
57
58std::vector<Operation *>
59MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
60 std::vector<Operation *> newVector;
61 for (auto &result : matches) {
62 newVector.insert(position: newVector.end(), first: result.matchedOps.begin(),
63 last: result.matchedOps.end());
64 }
65 return newVector;
66}
67
68} // namespace mlir::query::matcher
69

// Source code of mlir/lib/Query/Matcher/MatchFinder.cpp