1//===---- Query.cpp - -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Query/Query.h"
10#include "QueryParser.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Verifier.h"
14#include "mlir/Query/Matcher/MatchFinder.h"
15#include "mlir/Query/QuerySession.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/Support/SourceMgr.h"
18#include "llvm/Support/raw_ostream.h"
19
20namespace mlir::query {
21
22QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
23 return QueryParser::parse(line, qs);
24}
25
26std::vector<llvm::LineEditor::Completion>
27complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
28 return QueryParser::complete(line, pos, qs);
29}
30
31// TODO: Extract into a helper function that can be reused outside query
32// context.
33static Operation *extractFunction(std::vector<Operation *> &ops,
34 MLIRContext *context,
35 llvm::StringRef functionName) {
36 context->loadDialect<func::FuncDialect>();
37 OpBuilder builder(context);
38
39 // Collect data for function creation
40 std::vector<Operation *> slice;
41 std::vector<Value> values;
42 std::vector<Type> outputTypes;
43
44 for (auto *op : ops) {
45 // Return op's operands are propagated, but the op itself isn't needed.
46 if (!isa<func::ReturnOp>(Val: op))
47 slice.push_back(x: op);
48
49 // All results are returned by the extracted function.
50 llvm::append_range(C&: outputTypes, R: op->getResults().getTypes());
51
52 // Track all values that need to be taken as input to function.
53 llvm::append_range(C&: values, R: op->getOperands());
54 }
55
56 // Create the function
57 FunctionType funcType =
58 builder.getFunctionType(inputs: TypeRange(ValueRange(values)), results: outputTypes);
59 auto loc = builder.getUnknownLoc();
60 func::FuncOp funcOp = func::FuncOp::create(location: loc, name: functionName, type: funcType);
61
62 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
63
64 // Map original values to function arguments
65 IRMapping mapper;
66 for (const auto &arg : llvm::enumerate(First&: values))
67 mapper.map(from: arg.value(), to: funcOp.getArgument(idx: arg.index()));
68
69 // Clone operations and build function body
70 std::vector<Operation *> clonedOps;
71 std::vector<Value> clonedVals;
72 // TODO: Handle extraction of operations with compute payloads defined via
73 // regions.
74 for (Operation *slicedOp : slice) {
75 Operation *clonedOp =
76 clonedOps.emplace_back(args: builder.clone(op&: *slicedOp, mapper));
77 clonedVals.insert(position: clonedVals.end(), first: clonedOp->result_begin(),
78 last: clonedOp->result_end());
79 }
80 // Add return operation
81 builder.create<func::ReturnOp>(location: loc, args&: clonedVals);
82
83 // Remove unused function arguments
84 size_t currentIndex = 0;
85 while (currentIndex < funcOp.getNumArguments()) {
86 // Erase if possible.
87 if (funcOp.getArgument(idx: currentIndex).use_empty())
88 if (succeeded(Result: funcOp.eraseArgument(argIndex: currentIndex)))
89 continue;
90 ++currentIndex;
91 }
92
93 return funcOp;
94}
95
96Query::~Query() = default;
97
98LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
99 os << errStr << "\n";
100 return mlir::failure();
101}
102
103LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
104 return mlir::success();
105}
106
107LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
108 os << "Available commands:\n\n"
109 " match MATCHER, m MATCHER "
110 "Match the mlir against the given matcher.\n"
111 " quit "
112 "Terminates the query session.\n\n";
113 return mlir::success();
114}
115
116LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
117 qs.terminate = true;
118 return mlir::success();
119}
120
121LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
122 Operation *rootOp = qs.getRootOp();
123 int matchCount = 0;
124 matcher::MatchFinder finder;
125 auto matches = finder.collectMatches(root: rootOp, matcher: std::move(matcher));
126
127 // An extract call is recognized by considering if the matcher has a name.
128 // TODO: Consider making the extract more explicit.
129 if (matcher.hasFunctionName()) {
130 auto functionName = matcher.getFunctionName();
131 std::vector<Operation *> flattenedMatches =
132 finder.flattenMatchedOps(matches);
133 Operation *function =
134 extractFunction(ops&: flattenedMatches, context: rootOp->getContext(), functionName);
135 if (failed(Result: verify(op: function)))
136 return mlir::failure();
137 os << "\n" << *function << "\n\n";
138 function->erase();
139 return mlir::success();
140 }
141
142 os << "\n";
143 for (auto &results : matches) {
144 os << "Match #" << ++matchCount << ":\n\n";
145 for (auto op : results.matchedOps) {
146 if (op == results.rootOp) {
147 finder.printMatch(os, qs, op, binding: "root");
148 } else {
149 finder.printMatch(os, qs, op);
150 }
151 }
152 }
153 os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
154 return mlir::success();
155}
156
157} // namespace mlir::query
158

source code of mlir/lib/Query/Query.cpp