1//===---- Query.cpp - -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Query/Query.h"
10#include "QueryParser.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/Query/Matcher/MatchFinder.h"
14#include "mlir/Query/QuerySession.h"
15#include "mlir/Support/LogicalResult.h"
16#include "llvm/Support/SourceMgr.h"
17#include "llvm/Support/raw_ostream.h"
18
19namespace mlir::query {
20
21QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
22 return QueryParser::parse(line, qs);
23}
24
25std::vector<llvm::LineEditor::Completion>
26complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
27 return QueryParser::complete(line, pos, qs);
28}
29
30static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
31 const std::string &binding) {
32 auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
33 auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
34 BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn());
35 qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
36 "\"" + binding + "\" binds here");
37}
38
39// TODO: Extract into a helper function that can be reused outside query
40// context.
41static Operation *extractFunction(std::vector<Operation *> &ops,
42 MLIRContext *context,
43 llvm::StringRef functionName) {
44 context->loadDialect<func::FuncDialect>();
45 OpBuilder builder(context);
46
47 // Collect data for function creation
48 std::vector<Operation *> slice;
49 std::vector<Value> values;
50 std::vector<Type> outputTypes;
51
52 for (auto *op : ops) {
53 // Return op's operands are propagated, but the op itself isn't needed.
54 if (!isa<func::ReturnOp>(op))
55 slice.push_back(x: op);
56
57 // All results are returned by the extracted function.
58 outputTypes.insert(position: outputTypes.end(), first: op->getResults().getTypes().begin(),
59 last: op->getResults().getTypes().end());
60
61 // Track all values that need to be taken as input to function.
62 values.insert(position: values.end(), first: op->getOperands().begin(),
63 last: op->getOperands().end());
64 }
65
66 // Create the function
67 FunctionType funcType =
68 builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
69 auto loc = builder.getUnknownLoc();
70 func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
71
72 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
73
74 // Map original values to function arguments
75 IRMapping mapper;
76 for (const auto &arg : llvm::enumerate(First&: values))
77 mapper.map(arg.value(), funcOp.getArgument(arg.index()));
78
79 // Clone operations and build function body
80 std::vector<Operation *> clonedOps;
81 std::vector<Value> clonedVals;
82 for (Operation *slicedOp : slice) {
83 Operation *clonedOp =
84 clonedOps.emplace_back(args: builder.clone(op&: *slicedOp, mapper));
85 clonedVals.insert(position: clonedVals.end(), first: clonedOp->result_begin(),
86 last: clonedOp->result_end());
87 }
88 // Add return operation
89 builder.create<func::ReturnOp>(loc, clonedVals);
90
91 // Remove unused function arguments
92 size_t currentIndex = 0;
93 while (currentIndex < funcOp.getNumArguments()) {
94 if (funcOp.getArgument(currentIndex).use_empty())
95 funcOp.eraseArgument(currentIndex);
96 else
97 ++currentIndex;
98 }
99
100 return funcOp;
101}
102
// Out-of-line, defaulted destructor for the Query base class.
// NOTE(review): presumably defined here rather than in the header so the
// class's vtable is emitted in this translation unit — confirm against Query.h.
Query::~Query() = default;
104
105mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
106 QuerySession &qs) const {
107 os << errStr << "\n";
108 return mlir::failure();
109}
110
111mlir::LogicalResult NoOpQuery::run(llvm::raw_ostream &os,
112 QuerySession &qs) const {
113 return mlir::success();
114}
115
116mlir::LogicalResult HelpQuery::run(llvm::raw_ostream &os,
117 QuerySession &qs) const {
118 os << "Available commands:\n\n"
119 " match MATCHER, m MATCHER "
120 "Match the mlir against the given matcher.\n"
121 " quit "
122 "Terminates the query session.\n\n";
123 return mlir::success();
124}
125
126mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
127 QuerySession &qs) const {
128 qs.terminate = true;
129 return mlir::success();
130}
131
132mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
133 QuerySession &qs) const {
134 Operation *rootOp = qs.getRootOp();
135 int matchCount = 0;
136 std::vector<Operation *> matches =
137 matcher::MatchFinder().getMatches(root: rootOp, matcher);
138
139 // An extract call is recognized by considering if the matcher has a name.
140 // TODO: Consider making the extract more explicit.
141 if (matcher.hasFunctionName()) {
142 auto functionName = matcher.getFunctionName();
143 Operation *function =
144 extractFunction(ops&: matches, context: rootOp->getContext(), functionName);
145 os << "\n" << *function << "\n\n";
146 function->erase();
147 return mlir::success();
148 }
149
150 os << "\n";
151 for (Operation *op : matches) {
152 os << "Match #" << ++matchCount << ":\n\n";
153 // Placeholder "root" binding for the initial draft.
154 printMatch(os, qs, op, binding: "root");
155 }
156 os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
157
158 return mlir::success();
159}
160
161} // namespace mlir::query
162

source code of mlir/lib/Query/Query.cpp