1 | //===---- Query.cpp - -----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Query/Query.h" |
10 | #include "QueryParser.h" |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/IR/IRMapping.h" |
13 | #include "mlir/Query/Matcher/MatchFinder.h" |
14 | #include "mlir/Query/QuerySession.h" |
15 | #include "llvm/ADT/SetVector.h" |
16 | #include "llvm/Support/SourceMgr.h" |
17 | #include "llvm/Support/raw_ostream.h" |
18 | |
19 | namespace mlir::query { |
20 | |
21 | QueryRef parse(llvm::StringRef line, const QuerySession &qs) { |
22 | return QueryParser::parse(line, qs); |
23 | } |
24 | |
25 | std::vector<llvm::LineEditor::Completion> |
26 | complete(llvm::StringRef line, size_t pos, const QuerySession &qs) { |
27 | return QueryParser::complete(line, pos, qs); |
28 | } |
29 | |
30 | // TODO: Extract into a helper function that can be reused outside query |
31 | // context. |
32 | static Operation *(std::vector<Operation *> &ops, |
33 | MLIRContext *context, |
34 | llvm::StringRef functionName) { |
35 | context->loadDialect<func::FuncDialect>(); |
36 | OpBuilder builder(context); |
37 | |
38 | // Collect data for function creation |
39 | std::vector<Operation *> slice; |
40 | std::vector<Value> values; |
41 | std::vector<Type> outputTypes; |
42 | |
43 | for (auto *op : ops) { |
44 | // Return op's operands are propagated, but the op itself isn't needed. |
45 | if (!isa<func::ReturnOp>(op)) |
46 | slice.push_back(x: op); |
47 | |
48 | // All results are returned by the extracted function. |
49 | llvm::append_range(C&: outputTypes, R: op->getResults().getTypes()); |
50 | |
51 | // Track all values that need to be taken as input to function. |
52 | llvm::append_range(C&: values, R: op->getOperands()); |
53 | } |
54 | |
55 | // Create the function |
56 | FunctionType funcType = |
57 | builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes); |
58 | auto loc = builder.getUnknownLoc(); |
59 | func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType); |
60 | |
61 | builder.setInsertionPointToEnd(funcOp.addEntryBlock()); |
62 | |
63 | // Map original values to function arguments |
64 | IRMapping mapper; |
65 | for (const auto &arg : llvm::enumerate(First&: values)) |
66 | mapper.map(arg.value(), funcOp.getArgument(arg.index())); |
67 | |
68 | // Clone operations and build function body |
69 | std::vector<Operation *> clonedOps; |
70 | std::vector<Value> clonedVals; |
71 | for (Operation *slicedOp : slice) { |
72 | Operation *clonedOp = |
73 | clonedOps.emplace_back(args: builder.clone(op&: *slicedOp, mapper)); |
74 | clonedVals.insert(position: clonedVals.end(), first: clonedOp->result_begin(), |
75 | last: clonedOp->result_end()); |
76 | } |
77 | // Add return operation |
78 | builder.create<func::ReturnOp>(loc, clonedVals); |
79 | |
80 | // Remove unused function arguments |
81 | size_t currentIndex = 0; |
82 | while (currentIndex < funcOp.getNumArguments()) { |
83 | // Erase if possible. |
84 | if (funcOp.getArgument(currentIndex).use_empty()) |
85 | if (succeeded(funcOp.eraseArgument(currentIndex))) |
86 | continue; |
87 | ++currentIndex; |
88 | } |
89 | |
90 | return funcOp; |
91 | } |
92 | |
93 | Query::~Query() = default; |
94 | |
95 | LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { |
96 | os << errStr << "\n" ; |
97 | return mlir::failure(); |
98 | } |
99 | |
100 | LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { |
101 | return mlir::success(); |
102 | } |
103 | |
104 | LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { |
105 | os << "Available commands:\n\n" |
106 | " match MATCHER, m MATCHER " |
107 | "Match the mlir against the given matcher.\n" |
108 | " quit " |
109 | "Terminates the query session.\n\n" ; |
110 | return mlir::success(); |
111 | } |
112 | |
113 | LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { |
114 | qs.terminate = true; |
115 | return mlir::success(); |
116 | } |
117 | |
118 | LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { |
119 | Operation *rootOp = qs.getRootOp(); |
120 | int matchCount = 0; |
121 | matcher::MatchFinder finder; |
122 | auto matches = finder.collectMatches(root: rootOp, matcher: std::move(matcher)); |
123 | |
124 | // An extract call is recognized by considering if the matcher has a name. |
125 | // TODO: Consider making the extract more explicit. |
126 | if (matcher.hasFunctionName()) { |
127 | auto functionName = matcher.getFunctionName(); |
128 | std::vector<Operation *> flattenedMatches = |
129 | finder.flattenMatchedOps(matches); |
130 | Operation *function = |
131 | extractFunction(ops&: flattenedMatches, context: rootOp->getContext(), functionName); |
132 | os << "\n" << *function << "\n\n" ; |
133 | function->erase(); |
134 | return mlir::success(); |
135 | } |
136 | |
137 | os << "\n" ; |
138 | for (auto &results : matches) { |
139 | os << "Match #" << ++matchCount << ":\n\n" ; |
140 | for (auto op : results.matchedOps) { |
141 | if (op == results.rootOp) { |
142 | finder.printMatch(os, qs, op, binding: "root" ); |
143 | } else { |
144 | finder.printMatch(os, qs, op); |
145 | } |
146 | } |
147 | } |
148 | os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n" ); |
149 | return mlir::success(); |
150 | } |
151 | |
152 | } // namespace mlir::query |
153 | |