1//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Main entry function for mlir-rewrite.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/AsmParser/AsmParser.h"
14#include "mlir/AsmParser/AsmParserState.h"
15#include "mlir/IR/AsmState.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/MLIRContext.h"
18#include "mlir/InitAllDialects.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Pass/PassManager.h"
21#include "mlir/Support/FileUtilities.h"
22#include "mlir/Tools/ParseUtilities.h"
23#include "llvm/ADT/RewriteBuffer.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/InitLLVM.h"
26#include "llvm/Support/LineIterator.h"
27#include "llvm/Support/Regex.h"
28#include "llvm/Support/SourceMgr.h"
29#include "llvm/Support/ToolOutputFile.h"
30
31using namespace mlir;
32
33namespace mlir {
34using OperationDefinition = AsmParserState::OperationDefinition;
35
36/// Return the source code associated with the OperationDefinition.
37SMRange getOpRange(const OperationDefinition &op) {
38 const char *startOp = op.scopeLoc.Start.getPointer();
39 const char *endOp = op.scopeLoc.End.getPointer();
40
41 for (const auto &res : op.resultGroups) {
42 SMRange range = res.definition.loc;
43 startOp = std::min(a: startOp, b: range.Start.getPointer());
44 }
45 return {SMLoc::getFromPointer(Ptr: startOp), SMLoc::getFromPointer(Ptr: endOp)};
46}
47
48/// Helper to simplify rewriting the source file.
49class RewritePad {
50public:
51 static std::unique_ptr<RewritePad> init(StringRef inputFilename,
52 StringRef outputFilename);
53
54 /// Return the context the file was parsed into.
55 MLIRContext *getContext() { return &context; }
56
57 /// Return the OperationDefinition's of the operation's parsed.
58 iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
59 return asmState.getOpDefs();
60 }
61
62 /// Insert the specified string at the specified location in the original
63 /// buffer.
64 void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
65 rewriteBuffer.InsertText(OrigOffset: pos.getPointer() - start, Str: str, InsertAfter: insertAfter);
66 }
67
68 /// Replace the range of the source text with the corresponding string in the
69 /// output.
70 void replaceRange(SMRange range, StringRef str) {
71 rewriteBuffer.ReplaceText(OrigOffset: range.Start.getPointer() - start,
72 OrigLength: range.End.getPointer() - range.Start.getPointer(),
73 NewStr: str);
74 }
75
76 /// Replace the range of the operation in the source text with the
77 /// corresponding string in the output.
78 void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
79 replaceRange(range: getOpRange(op: opDef), str: newDef);
80 }
81
82 /// Return the source string corresponding to the source range.
83 StringRef getSourceString(SMRange range) {
84 return StringRef(range.Start.getPointer(),
85 range.End.getPointer() - range.Start.getPointer());
86 }
87
88 /// Return the source string corresponding to operation definition.
89 StringRef getSourceString(const OperationDefinition &opDef) {
90 auto range = getOpRange(op: opDef);
91 return getSourceString(range);
92 }
93
94 /// Write to stream the result of applying all changes to the
95 /// original buffer.
96 /// Note that it isn't safe to use this function to overwrite memory mapped
97 /// files in-place (PR17960).
98 ///
99 /// The original buffer is not actually changed.
100 raw_ostream &write(raw_ostream &stream) const {
101 return rewriteBuffer.write(Stream&: stream);
102 }
103
104 /// Return lines that are purely comments.
105 SmallVector<SMRange> getSingleLineComments() {
106 unsigned curBuf = sourceMgr.getMainFileID();
107 const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(i: curBuf);
108 llvm::line_iterator lineIterator(*curMB);
109 SmallVector<SMRange> ret;
110 for (; !lineIterator.is_at_end(); ++lineIterator) {
111 StringRef trimmed = lineIterator->ltrim();
112 if (trimmed.starts_with(Prefix: "//")) {
113 ret.emplace_back(
114 Args: SMLoc::getFromPointer(Ptr: trimmed.data()),
115 Args: SMLoc::getFromPointer(Ptr: trimmed.data() + trimmed.size()));
116 }
117 }
118 return ret;
119 }
120
121 /// Return the IR from parsed file.
122 Block *getParsed() { return &parsedIR; }
123
124 /// Return the definition for the given operation, or nullptr if the given
125 /// operation does not have a definition.
126 const OperationDefinition &getOpDef(Operation *op) const {
127 return *asmState.getOpDef(op);
128 }
129
130private:
131 // The context and state required to parse.
132 MLIRContext context;
133 llvm::SourceMgr sourceMgr;
134 DialectRegistry registry;
135 FallbackAsmResourceMap fallbackResourceMap;
136
137 // Storage of textual parsing results.
138 AsmParserState asmState;
139
140 // Parsed IR.
141 Block parsedIR;
142
143 // The RewriteBuffer is doing most of the real work.
144 llvm::RewriteBuffer rewriteBuffer;
145
146 // Start of the original input, used to compute offset.
147 const char *start;
148};
149
150std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename,
151 StringRef outputFilename) {
152 std::unique_ptr<RewritePad> r = std::make_unique<RewritePad>();
153
154 // Register all the dialects needed.
155 registerAllDialects(registry&: r->registry);
156
157 // Set up the input file.
158 std::string errorMessage;
159 std::unique_ptr<llvm::MemoryBuffer> file =
160 openInputFile(inputFilename, errorMessage: &errorMessage);
161 if (!file) {
162 llvm::errs() << errorMessage << "\n";
163 return nullptr;
164 }
165 r->sourceMgr.AddNewSourceBuffer(F: std::move(file), IncludeLoc: SMLoc());
166
167 // Set up the MLIR context and error handling.
168 r->context.appendDialectRegistry(registry: r->registry);
169
170 // Record the start of the buffer to compute offsets with.
171 unsigned curBuf = r->sourceMgr.getMainFileID();
172 const llvm::MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(i: curBuf);
173 r->start = curMB->getBufferStart();
174 r->rewriteBuffer.Initialize(Input: curMB->getBuffer());
175
176 // Parse and populate the AsmParserState.
177 ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true,
178 &r->fallbackResourceMap);
179 // Always allow unregistered.
180 r->context.allowUnregisteredDialects(allow: true);
181 if (failed(Result: parseAsmSourceFile(sourceMgr: r->sourceMgr, block: &r->parsedIR, config: parseConfig,
182 asmState: &r->asmState)))
183 return nullptr;
184
185 return r;
186}
187
188/// Return the source code associated with the operation name.
189SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
190
191/// Return whether the operation was printed using generic syntax in original
192/// buffer.
193bool isGeneric(const OperationDefinition &op) {
194 return op.loc.Start.getPointer()[0] == '"';
195}
196
197inline int asMainReturnCode(LogicalResult r) {
198 return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
199}
200
201/// Reriter function to invoke.
202using RewriterFunction = std::function<mlir::LogicalResult(
203 mlir::RewritePad &rewriteState, llvm::raw_ostream &os)>;
204
205/// Structure to group information about a rewriter (argument to invoke via
206/// mlir-tblgen, description, and rewriter function).
207class RewriterInfo {
208public:
209 /// RewriterInfo constructor should not be invoked directly, instead use
210 /// RewriterRegistration or registerRewriter.
211 RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
212 : arg(arg), description(description), rewriter(std::move(rewriter)) {}
213
214 /// Invokes the rewriter and returns whether the rewriter failed.
215 LogicalResult invoke(mlir::RewritePad &rewriteState, raw_ostream &os) const {
216 assert(rewriter && "Cannot call rewriter with null rewriter");
217 return rewriter(rewriteState, os);
218 }
219
220 /// Returns the command line option that may be passed to 'mlir-rewrite' to
221 /// invoke this rewriter.
222 StringRef getRewriterArgument() const { return arg; }
223
224 /// Returns a description for the rewriter.
225 StringRef getRewriterDescription() const { return description; }
226
227private:
228 // The argument with which to invoke the rewriter via mlir-tblgen.
229 StringRef arg;
230
231 // Description of the rewriter.
232 StringRef description;
233
234 // Rewritererator function.
235 RewriterFunction rewriter;
236};
237
238static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
239
240/// Adds command line option for each registered rewriter.
241struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
242 RewriterNameParser(llvm::cl::Option &opt);
243
244 void printOptionInfo(const llvm::cl::Option &o,
245 size_t globalWidth) const override;
246};
247
248/// RewriterRegistration provides a global initializer that registers a rewriter
249/// function.
250struct RewriterRegistration {
251 RewriterRegistration(StringRef arg, StringRef description,
252 const RewriterFunction &function);
253};
254
255RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
256 const RewriterFunction &function) {
257 rewriterRegistry->emplace_back(args&: arg, args&: description, args: function);
258}
259
260RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
261 : llvm::cl::parser<const RewriterInfo *>(opt) {
262 for (const auto &kv : *rewriterRegistry) {
263 addLiteralOption(Name: kv.getRewriterArgument(), V: &kv,
264 HelpStr: kv.getRewriterDescription());
265 }
266}
267
268void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
269 size_t globalWidth) const {
270 RewriterNameParser *tp = const_cast<RewriterNameParser *>(this);
271 llvm::array_pod_sort(Start: tp->Values.begin(), End: tp->Values.end(),
272 Compare: [](const RewriterNameParser::OptionInfo *vT1,
273 const RewriterNameParser::OptionInfo *vT2) {
274 return vT1->Name.compare(RHS: vT2->Name);
275 });
276 using llvm::cl::parser;
277 parser<const RewriterInfo *>::printOptionInfo(O: o, GlobalWidth: globalWidth);
278}
279
280} // namespace mlir
281
282// TODO: Make these injectable too in non-global way.
283static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
284static llvm::cl::opt<std::string> simpleRenameOpName{
285 "simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
286 llvm::cl::cat(clSimpleRenameCategory)};
287static llvm::cl::opt<std::string> simpleRenameMatch{
288 "simple-rename-match", llvm::cl::desc("Match string for rename"),
289 llvm::cl::cat(clSimpleRenameCategory)};
290static llvm::cl::opt<std::string> simpleRenameReplace{
291 "simple-rename-replace", llvm::cl::desc("Replace string for rename"),
292 llvm::cl::cat(clSimpleRenameCategory)};
293
294// Rewriter that does simple renames.
295LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) {
296 StringRef opName = simpleRenameOpName;
297 StringRef match = simpleRenameMatch;
298 StringRef replace = simpleRenameReplace;
299 llvm::Regex regex(match);
300
301 rewriteState.getParsed()->walk(callback: [&](Operation *op) {
302 if (op->getName().getStringRef() != opName)
303 return;
304
305 const OperationDefinition &opDef = rewriteState.getOpDef(op);
306 SMRange range = getOpRange(op: opDef);
307 // This is a little bit overkill for simple.
308 std::string str = regex.sub(Repl: replace, String: rewriteState.getSourceString(range));
309 rewriteState.replaceRange(range, str);
310 });
311 return success();
312}
313
314static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
315 "Perform a simple rename",
316 simpleRename);
317
318// Rewriter that insert range markers.
319LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) {
320 for (const auto &it : rewriteState.getOpDefs()) {
321 auto [startOp, endOp] = getOpRange(op: it);
322
323 rewriteState.insertText(pos: startOp, str: "<");
324 rewriteState.insertText(pos: endOp, str: ">");
325
326 auto nameRange = getOpNameRange(op: it);
327
328 if (isGeneric(op: it)) {
329 rewriteState.insertText(pos: nameRange.Start, str: "[");
330 rewriteState.insertText(pos: nameRange.End, str: "]");
331 } else {
332 rewriteState.insertText(pos: nameRange.Start, str: "![");
333 rewriteState.insertText(pos: nameRange.End, str: "]!");
334 }
335 }
336
337 // Highlight all comment lines.
338 // TODO: Could be replaced if this is kept in memory.
339 for (auto commentLine : rewriteState.getSingleLineComments()) {
340 rewriteState.insertText(pos: commentLine.Start, str: "{");
341 rewriteState.insertText(pos: commentLine.End, str: "}");
342 }
343
344 return success();
345}
346
347static mlir::RewriterRegistration
348 rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
349
350int main(int argc, char **argv) {
351 llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
352 llvm::cl::desc("<input file>"),
353 llvm::cl::init(Val: "-"));
354
355 llvm::cl::opt<std::string> outputFilename(
356 "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
357 llvm::cl::init(Val: "-"));
358
359 llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
360 rewriter("", llvm::cl::desc("Rewriter to run"));
361
362 std::string helpHeader = "mlir-rewrite";
363
364 llvm::cl::ParseCommandLineOptions(argc, argv, Overview: helpHeader);
365
366 // If no rewriter has been selected, exit with error code. Could also just
367 // return but its unlikely this was intentionally being used as `cp`.
368 if (!rewriter) {
369 llvm::errs() << "No rewriter selected!\n";
370 return mlir::asMainReturnCode(r: mlir::failure());
371 }
372
373 // Set up rewrite buffer.
374 auto rewriterOr = RewritePad::init(inputFilename, outputFilename);
375 if (!rewriterOr)
376 return mlir::asMainReturnCode(r: mlir::failure());
377
378 // Set up the output file.
379 std::string errorMessage;
380 auto output = openOutputFile(outputFilename, errorMessage: &errorMessage);
381 if (!output) {
382 llvm::errs() << errorMessage << "\n";
383 return mlir::asMainReturnCode(r: mlir::failure());
384 }
385
386 LogicalResult result = rewriter->invoke(rewriteState&: *rewriterOr, os&: output->os());
387 if (succeeded(Result: result)) {
388 rewriterOr->write(stream&: output->os());
389 output->keep();
390 }
391 return mlir::asMainReturnCode(r: result);
392}
393

// Source: mlir/tools/mlir-rewrite/mlir-rewrite.cpp