1 | //===---- Query.cpp - -----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Query/Query.h" |
10 | #include "QueryParser.h" |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/IR/IRMapping.h" |
13 | #include "mlir/Query/Matcher/MatchFinder.h" |
14 | #include "mlir/Query/QuerySession.h" |
15 | #include "mlir/Support/LogicalResult.h" |
16 | #include "llvm/Support/SourceMgr.h" |
17 | #include "llvm/Support/raw_ostream.h" |
18 | |
19 | namespace mlir::query { |
20 | |
21 | QueryRef parse(llvm::StringRef line, const QuerySession &qs) { |
22 | return QueryParser::parse(line, qs); |
23 | } |
24 | |
25 | std::vector<llvm::LineEditor::Completion> |
26 | complete(llvm::StringRef line, size_t pos, const QuerySession &qs) { |
27 | return QueryParser::complete(line, pos, qs); |
28 | } |
29 | |
30 | static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op, |
31 | const std::string &binding) { |
32 | auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>(); |
33 | auto smloc = qs.getSourceManager().FindLocForLineAndColumn( |
34 | BufferID: qs.getBufferId(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn()); |
35 | qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note, |
36 | "\"" + binding + "\" binds here" ); |
37 | } |
38 | |
39 | // TODO: Extract into a helper function that can be reused outside query |
40 | // context. |
41 | static Operation *(std::vector<Operation *> &ops, |
42 | MLIRContext *context, |
43 | llvm::StringRef functionName) { |
44 | context->loadDialect<func::FuncDialect>(); |
45 | OpBuilder builder(context); |
46 | |
47 | // Collect data for function creation |
48 | std::vector<Operation *> slice; |
49 | std::vector<Value> values; |
50 | std::vector<Type> outputTypes; |
51 | |
52 | for (auto *op : ops) { |
53 | // Return op's operands are propagated, but the op itself isn't needed. |
54 | if (!isa<func::ReturnOp>(op)) |
55 | slice.push_back(x: op); |
56 | |
57 | // All results are returned by the extracted function. |
58 | outputTypes.insert(position: outputTypes.end(), first: op->getResults().getTypes().begin(), |
59 | last: op->getResults().getTypes().end()); |
60 | |
61 | // Track all values that need to be taken as input to function. |
62 | values.insert(position: values.end(), first: op->getOperands().begin(), |
63 | last: op->getOperands().end()); |
64 | } |
65 | |
66 | // Create the function |
67 | FunctionType funcType = |
68 | builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes); |
69 | auto loc = builder.getUnknownLoc(); |
70 | func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType); |
71 | |
72 | builder.setInsertionPointToEnd(funcOp.addEntryBlock()); |
73 | |
74 | // Map original values to function arguments |
75 | IRMapping mapper; |
76 | for (const auto &arg : llvm::enumerate(First&: values)) |
77 | mapper.map(arg.value(), funcOp.getArgument(arg.index())); |
78 | |
79 | // Clone operations and build function body |
80 | std::vector<Operation *> clonedOps; |
81 | std::vector<Value> clonedVals; |
82 | for (Operation *slicedOp : slice) { |
83 | Operation *clonedOp = |
84 | clonedOps.emplace_back(args: builder.clone(op&: *slicedOp, mapper)); |
85 | clonedVals.insert(position: clonedVals.end(), first: clonedOp->result_begin(), |
86 | last: clonedOp->result_end()); |
87 | } |
88 | // Add return operation |
89 | builder.create<func::ReturnOp>(loc, clonedVals); |
90 | |
91 | // Remove unused function arguments |
92 | size_t currentIndex = 0; |
93 | while (currentIndex < funcOp.getNumArguments()) { |
94 | if (funcOp.getArgument(currentIndex).use_empty()) |
95 | funcOp.eraseArgument(currentIndex); |
96 | else |
97 | ++currentIndex; |
98 | } |
99 | |
100 | return funcOp; |
101 | } |
102 | |
103 | Query::~Query() = default; |
104 | |
105 | mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os, |
106 | QuerySession &qs) const { |
107 | os << errStr << "\n" ; |
108 | return mlir::failure(); |
109 | } |
110 | |
111 | mlir::LogicalResult NoOpQuery::run(llvm::raw_ostream &os, |
112 | QuerySession &qs) const { |
113 | return mlir::success(); |
114 | } |
115 | |
116 | mlir::LogicalResult HelpQuery::run(llvm::raw_ostream &os, |
117 | QuerySession &qs) const { |
118 | os << "Available commands:\n\n" |
119 | " match MATCHER, m MATCHER " |
120 | "Match the mlir against the given matcher.\n" |
121 | " quit " |
122 | "Terminates the query session.\n\n" ; |
123 | return mlir::success(); |
124 | } |
125 | |
126 | mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os, |
127 | QuerySession &qs) const { |
128 | qs.terminate = true; |
129 | return mlir::success(); |
130 | } |
131 | |
132 | mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os, |
133 | QuerySession &qs) const { |
134 | Operation *rootOp = qs.getRootOp(); |
135 | int matchCount = 0; |
136 | std::vector<Operation *> matches = |
137 | matcher::MatchFinder().getMatches(root: rootOp, matcher); |
138 | |
139 | // An extract call is recognized by considering if the matcher has a name. |
140 | // TODO: Consider making the extract more explicit. |
141 | if (matcher.hasFunctionName()) { |
142 | auto functionName = matcher.getFunctionName(); |
143 | Operation *function = |
144 | extractFunction(ops&: matches, context: rootOp->getContext(), functionName); |
145 | os << "\n" << *function << "\n\n" ; |
146 | function->erase(); |
147 | return mlir::success(); |
148 | } |
149 | |
150 | os << "\n" ; |
151 | for (Operation *op : matches) { |
152 | os << "Match #" << ++matchCount << ":\n\n" ; |
153 | // Placeholder "root" binding for the initial draft. |
154 | printMatch(os, qs, op, binding: "root" ); |
155 | } |
156 | os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n" ); |
157 | |
158 | return mlir::success(); |
159 | } |
160 | |
161 | } // namespace mlir::query |
162 | |