//===---- Query.cpp - -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Query/Query.h"
#include "QueryParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

19namespace mlir::query {
20
21QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
22 return QueryParser::parse(line, qs);
23}
24
25std::vector<llvm::LineEditor::Completion>
26complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
27 return QueryParser::complete(line, pos, qs);
28}
29
30// TODO: Extract into a helper function that can be reused outside query
31// context.
32static Operation *extractFunction(std::vector<Operation *> &ops,
33 MLIRContext *context,
34 llvm::StringRef functionName) {
35 context->loadDialect<func::FuncDialect>();
36 OpBuilder builder(context);
37
38 // Collect data for function creation
39 std::vector<Operation *> slice;
40 std::vector<Value> values;
41 std::vector<Type> outputTypes;
42
43 for (auto *op : ops) {
44 // Return op's operands are propagated, but the op itself isn't needed.
45 if (!isa<func::ReturnOp>(op))
46 slice.push_back(x: op);
47
48 // All results are returned by the extracted function.
49 llvm::append_range(C&: outputTypes, R: op->getResults().getTypes());
50
51 // Track all values that need to be taken as input to function.
52 llvm::append_range(C&: values, R: op->getOperands());
53 }
54
55 // Create the function
56 FunctionType funcType =
57 builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
58 auto loc = builder.getUnknownLoc();
59 func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
60
61 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
62
63 // Map original values to function arguments
64 IRMapping mapper;
65 for (const auto &arg : llvm::enumerate(First&: values))
66 mapper.map(arg.value(), funcOp.getArgument(arg.index()));
67
68 // Clone operations and build function body
69 std::vector<Operation *> clonedOps;
70 std::vector<Value> clonedVals;
71 for (Operation *slicedOp : slice) {
72 Operation *clonedOp =
73 clonedOps.emplace_back(args: builder.clone(op&: *slicedOp, mapper));
74 clonedVals.insert(position: clonedVals.end(), first: clonedOp->result_begin(),
75 last: clonedOp->result_end());
76 }
77 // Add return operation
78 builder.create<func::ReturnOp>(loc, clonedVals);
79
80 // Remove unused function arguments
81 size_t currentIndex = 0;
82 while (currentIndex < funcOp.getNumArguments()) {
83 // Erase if possible.
84 if (funcOp.getArgument(currentIndex).use_empty())
85 if (succeeded(funcOp.eraseArgument(currentIndex)))
86 continue;
87 ++currentIndex;
88 }
89
90 return funcOp;
91}
92
93Query::~Query() = default;
94
95LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
96 os << errStr << "\n";
97 return mlir::failure();
98}
99
100LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
101 return mlir::success();
102}
103
104LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
105 os << "Available commands:\n\n"
106 " match MATCHER, m MATCHER "
107 "Match the mlir against the given matcher.\n"
108 " quit "
109 "Terminates the query session.\n\n";
110 return mlir::success();
111}
112
113LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
114 qs.terminate = true;
115 return mlir::success();
116}
117
118LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
119 Operation *rootOp = qs.getRootOp();
120 int matchCount = 0;
121 matcher::MatchFinder finder;
122 auto matches = finder.collectMatches(root: rootOp, matcher: std::move(matcher));
123
124 // An extract call is recognized by considering if the matcher has a name.
125 // TODO: Consider making the extract more explicit.
126 if (matcher.hasFunctionName()) {
127 auto functionName = matcher.getFunctionName();
128 std::vector<Operation *> flattenedMatches =
129 finder.flattenMatchedOps(matches);
130 Operation *function =
131 extractFunction(ops&: flattenedMatches, context: rootOp->getContext(), functionName);
132 os << "\n" << *function << "\n\n";
133 function->erase();
134 return mlir::success();
135 }
136
137 os << "\n";
138 for (auto &results : matches) {
139 os << "Match #" << ++matchCount << ":\n\n";
140 for (auto op : results.matchedOps) {
141 if (op == results.rootOp) {
142 finder.printMatch(os, qs, op, binding: "root");
143 } else {
144 finder.printMatch(os, qs, op);
145 }
146 }
147 }
148 os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
149 return mlir::success();
150}
151
152} // namespace mlir::query
153

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Query/Query.cpp