1//===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Dialect/Transform/IR/Utils.h"
11#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
12#include "mlir/IR/AsmState.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Diagnostics.h"
15#include "mlir/IR/DialectRegistry.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/InitAllDialects.h"
18#include "mlir/InitAllExtensions.h"
19#include "mlir/InitAllPasses.h"
20#include "mlir/Parser/Parser.h"
21#include "mlir/Support/FileUtilities.h"
22#include "mlir/Tools/mlir-opt/MlirOptMain.h"
23#include "llvm/Support/CommandLine.h"
24#include "llvm/Support/InitLLVM.h"
25#include "llvm/Support/SourceMgr.h"
26#include "llvm/Support/ToolOutputFile.h"
27#include <cstdlib>
28
29namespace {
30
31using namespace llvm;
32
33/// Structure containing command line options for the tool, these will get
34/// initialized when an instance is created.
35struct MlirTransformOptCLOptions {
36 cl::opt<bool> allowUnregisteredDialects{
37 "allow-unregistered-dialect",
38 cl::desc("Allow operations coming from an unregistered dialect"),
39 cl::init(Val: false)};
40
41 cl::opt<bool> verifyDiagnostics{
42 "verify-diagnostics",
43 cl::desc("Check that emitted diagnostics match expected-* lines "
44 "on the corresponding line"),
45 cl::init(Val: false)};
46
47 cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
48 cl::init(Val: "-")};
49
50 cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"),
51 cl::value_desc("filename"),
52 cl::init(Val: "-")};
53
54 cl::opt<std::string> transformMainFilename{
55 "transform",
56 cl::desc("File containing entry point of the transform script, if "
57 "different from the input file"),
58 cl::value_desc("filename"), cl::init(Val: "")};
59
60 cl::list<std::string> transformLibraryFilenames{
61 "transform-library", cl::desc("File(s) containing definitions of "
62 "additional transform script symbols")};
63
64 cl::opt<std::string> transformEntryPoint{
65 "transform-entry-point",
66 cl::desc("Name of the entry point transform symbol"),
67 cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
68 .str())};
69
70 cl::opt<bool> disableExpensiveChecks{
71 "disable-expensive-checks",
72 cl::desc("Disables potentially expensive checks in the transform "
73 "interpreter, providing more speed at the expense of "
74 "potential memory problems and silent corruptions"),
75 cl::init(Val: false)};
76
77 cl::opt<bool> dumpLibraryModule{
78 "dump-library-module",
79 cl::desc("Prints the combined library module before the output"),
80 cl::init(Val: false)};
81};
82} // namespace
83
84/// "Managed" static instance of the command-line options structure. This makes
85/// them locally-scoped and explicitly initialized/deinitialized. While this is
86/// not strictly necessary in the tool source file that is not being used as a
87/// library (where the options would pollute the global list of options), it is
88/// good practice to follow this.
89static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions;
90
91/// Explicitly registers command-line options.
92static void registerCLOptions() { *clOptions; }
93
94namespace {
95/// A wrapper class for source managers diagnostic. This provides both unique
96/// ownership and virtual function-like overload for a pair of
97/// inheritance-related classes that do not use virtual functions.
98class DiagnosticHandlerWrapper {
99public:
100 /// Kind of the diagnostic handler to use.
101 enum class Kind { EmitDiagnostics, VerifyDiagnostics };
102
103 /// Constructs the diagnostic handler of the specified kind of the given
104 /// source manager and context.
105 DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr,
106 mlir::MLIRContext *context) {
107 if (kind == Kind::EmitDiagnostics)
108 handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
109 else
110 handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context);
111 }
112
113 /// This object is non-copyable but movable.
114 DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete;
115 DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default;
116 DiagnosticHandlerWrapper &
117 operator=(const DiagnosticHandlerWrapper &) = delete;
118 DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default;
119
120 /// Verifies the captured "expected-*" diagnostics if required.
121 mlir::LogicalResult verify() const {
122 if (auto *ptr =
123 handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) {
124 return ptr->verify();
125 }
126 return mlir::success();
127 }
128
129 /// Destructs the object of the same type as allocated.
130 ~DiagnosticHandlerWrapper() {
131 if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
132 delete ptr;
133 } else {
134 delete handler.get<mlir::SourceMgrDiagnosticVerifierHandler *>();
135 }
136 }
137
138private:
139 /// Internal storage is a type-safe union.
140 llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *,
141 mlir::SourceMgrDiagnosticVerifierHandler *>
142 handler;
143};
144
145/// MLIR has deeply rooted expectations that the LLVM source manager contains
146/// exactly one buffer, until at least the lexer level. This class wraps
147/// multiple LLVM source managers each managing a buffer to match MLIR's
148/// expectations while still providing a centralized handling mechanism.
149class TransformSourceMgr {
150public:
151 /// Constructs the source manager indicating whether diagnostic messages will
152 /// be verified later on.
153 explicit TransformSourceMgr(bool verifyDiagnostics)
154 : verifyDiagnostics(verifyDiagnostics) {}
155
156 /// Deconstructs the source manager. Note that `checkResults` must have been
157 /// called on this instance before deconstructing it.
158 ~TransformSourceMgr() {
159 assert(resultChecked && "must check the result of diagnostic handlers by "
160 "running TransformSourceMgr::checkResult");
161 }
162
163 /// Parses the given buffer and creates the top-level operation of the kind
164 /// specified as template argument in the given context. Additional parsing
165 /// options may be provided.
166 template <typename OpTy = mlir::Operation *>
167 mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer,
168 mlir::MLIRContext &context,
169 const mlir::ParserConfig &config) {
170 // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
171 // the code below to capture a reference to the source manager in such a way
172 // that it is not invalidated when the vector contents is eventually
173 // reallocated.
174 llvm::SourceMgr &mgr =
175 *sourceMgrs.emplace_back(Args: std::make_unique<llvm::SourceMgr>());
176 mgr.AddNewSourceBuffer(F: std::move(buffer), IncludeLoc: llvm::SMLoc());
177
178 // Choose the type of diagnostic handler depending on whether diagnostic
179 // verification needs to happen and store it.
180 if (verifyDiagnostics) {
181 diagHandlers.emplace_back(
182 Args: DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, Args&: mgr, Args: &context);
183 } else {
184 diagHandlers.emplace_back(Args: DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
185 Args&: mgr, Args: &context);
186 }
187
188 // Defer to MLIR's parser.
189 return mlir::parseSourceFile<OpTy>(mgr, config);
190 }
191
192 /// If diagnostic message verification has been requested upon construction of
193 /// this source manager, performs the verification, reports errors and returns
194 /// the result of the verification. Otherwise passes through the given value.
195 mlir::LogicalResult checkResult(mlir::LogicalResult result) {
196 resultChecked = true;
197 if (!verifyDiagnostics)
198 return result;
199
200 return mlir::failure(isFailure: llvm::any_of(Range&: diagHandlers, P: [](const auto &handler) {
201 return mlir::failed(result: handler.verify());
202 }));
203 }
204
205private:
206 /// Indicates whether diagnostic message verification is requested.
207 const bool verifyDiagnostics;
208
209 /// Indicates that diagnostic message verification has taken place, and the
210 /// deconstruction is therefore safe.
211 bool resultChecked = false;
212
213 /// Storage for per-buffer source managers and diagnostic handlers. These are
214 /// wrapped into unique pointers in order to make it safe to capture
215 /// references to these objects: if the vector is reallocated, the unique
216 /// pointer objects are moved by the pointer addresses won't change. Also, for
217 /// handlers, this allows to store the pointer to the base class.
218 SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs;
219 SmallVector<DiagnosticHandlerWrapper> diagHandlers;
220};
221} // namespace
222
223/// Trivial wrapper around `applyTransforms` that doesn't support extra mapping
224/// and doesn't enforce the entry point transform ops being top-level.
225static mlir::LogicalResult
226applyTransforms(mlir::Operation *payloadRoot,
227 mlir::transform::TransformOpInterface transformRoot,
228 const mlir::transform::TransformOptions &options) {
229 return applyTransforms(payloadRoot, transformRoot, {}, options,
230 /*enforceToplevelTransformOp=*/false);
231}
232
233/// Applies transforms indicated in the transform dialect script to the input
234/// buffer. The transform script may be embedded in the input buffer or as a
235/// separate buffer. The transform script may have external symbols, the
236/// definitions of which must be provided in transform library buffers. If the
237/// application is successful, prints the transformed input buffer into the
238/// given output stream. Additional configuration options are derived from
239/// command-line options.
240static mlir::LogicalResult processPayloadBuffer(
241 raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer,
242 std::unique_ptr<llvm::MemoryBuffer> transformBuffer,
243 MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries,
244 mlir::DialectRegistry &registry) {
245
246 // Initialize the MLIR context, and various configurations.
247 mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED);
248 context.allowUnregisteredDialects(allow: clOptions->allowUnregisteredDialects);
249 mlir::ParserConfig config(&context);
250 TransformSourceMgr sourceMgr(
251 /*verifyDiagnostics=*/clOptions->verifyDiagnostics);
252
253 // Parse the input buffer that will be used as transform payload.
254 mlir::OwningOpRef<mlir::Operation *> payloadRoot =
255 sourceMgr.parseBuffer(buffer: std::move(inputBuffer), context, config);
256 if (!payloadRoot)
257 return sourceMgr.checkResult(result: mlir::failure());
258
259 // Identify the module containing the transform script entry point. This may
260 // be the same module as the input or a separate module. In the former case,
261 // make a copy of the module so it can be modified freely. Modification may
262 // happen in the script itself (at which point it could be rewriting itself
263 // during interpretation, leading to tricky memory errors) or by embedding
264 // library modules in the script.
265 mlir::OwningOpRef<mlir::ModuleOp> transformRoot;
266 if (transformBuffer) {
267 transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>(
268 buffer: std::move(transformBuffer), context, config);
269 if (!transformRoot)
270 return sourceMgr.checkResult(result: mlir::failure());
271 } else {
272 transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone());
273 }
274
275 // Parse and merge the libraries into the main transform module.
276 for (auto &&transformLibrary : transformLibraries) {
277 mlir::OwningOpRef<mlir::ModuleOp> libraryModule =
278 sourceMgr.parseBuffer<mlir::ModuleOp>(buffer: std::move(transformLibrary),
279 context, config);
280
281 if (!libraryModule ||
282 mlir::failed(result: mlir::transform::detail::mergeSymbolsInto(
283 target: *transformRoot, other: std::move(libraryModule))))
284 return sourceMgr.checkResult(result: mlir::failure());
285 }
286
287 // If requested, dump the combined transform module.
288 if (clOptions->dumpLibraryModule)
289 transformRoot->dump();
290
291 // Find the entry point symbol. Even if it had originally been in the payload
292 // module, it was cloned into the transform module so only look there.
293 mlir::transform::TransformOpInterface entryPoint =
294 mlir::transform::detail::findTransformEntryPoint(
295 *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint);
296 if (!entryPoint)
297 return sourceMgr.checkResult(result: mlir::failure());
298
299 // Apply the requested transformations.
300 mlir::transform::TransformOptions transformOptions;
301 transformOptions.enableExpensiveChecks(enable: !clOptions->disableExpensiveChecks);
302 if (mlir::failed(result: applyTransforms(*payloadRoot, entryPoint, transformOptions)))
303 return sourceMgr.checkResult(result: mlir::failure());
304
305 // Print the transformed result and check the captured diagnostics if
306 // requested.
307 payloadRoot->print(os);
308 return sourceMgr.checkResult(result: mlir::success());
309}
310
311/// Tool entry point.
312static mlir::LogicalResult runMain(int argc, char **argv) {
313 // Register all upstream dialects and extensions. Specific uses are advised
314 // not to register all dialects indiscriminately but rather hand-pick what is
315 // necessary for their use case.
316 mlir::DialectRegistry registry;
317 mlir::registerAllDialects(registry);
318 mlir::registerAllExtensions(registry);
319 mlir::registerAllPasses();
320
321 // Explicitly register the transform dialect. This is not strictly necessary
322 // since it has been already registered as part of the upstream dialect list,
323 // but useful for example purposes for cases when dialects to register are
324 // hand-picked. The transform dialect must be registered.
325 registry.insert<mlir::transform::TransformDialect>();
326
327 // Register various command-line options. Note that the LLVM initializer
328 // object is a RAII that ensures correct deconstruction of command-line option
329 // objects inside ManagedStatic.
330 llvm::InitLLVM y(argc, argv);
331 mlir::registerAsmPrinterCLOptions();
332 mlir::registerMLIRContextCLOptions();
333 registerCLOptions();
334 llvm::cl::ParseCommandLineOptions(argc, argv,
335 Overview: "Minimal Transform dialect driver\n");
336
337 // Try opening the main input file.
338 std::string errorMessage;
339 std::unique_ptr<llvm::MemoryBuffer> payloadFile =
340 mlir::openInputFile(inputFilename: clOptions->payloadFilename, errorMessage: &errorMessage);
341 if (!payloadFile) {
342 llvm::errs() << errorMessage << "\n";
343 return mlir::failure();
344 }
345
346 // Try opening the output file.
347 std::unique_ptr<llvm::ToolOutputFile> outputFile =
348 mlir::openOutputFile(outputFilename: clOptions->outputFilename, errorMessage: &errorMessage);
349 if (!outputFile) {
350 llvm::errs() << errorMessage << "\n";
351 return mlir::failure();
352 }
353
354 // Try opening the main transform file if provided.
355 std::unique_ptr<llvm::MemoryBuffer> transformRootFile;
356 if (!clOptions->transformMainFilename.empty()) {
357 if (clOptions->transformMainFilename == clOptions->payloadFilename) {
358 llvm::errs() << "warning: " << clOptions->payloadFilename
359 << " is provided as both payload and transform file\n";
360 } else {
361 transformRootFile =
362 mlir::openInputFile(inputFilename: clOptions->transformMainFilename, errorMessage: &errorMessage);
363 if (!transformRootFile) {
364 llvm::errs() << errorMessage << "\n";
365 return mlir::failure();
366 }
367 }
368 }
369
370 // Try opening transform library files if provided.
371 SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries;
372 transformLibraries.reserve(N: clOptions->transformLibraryFilenames.size());
373 for (llvm::StringRef filename : clOptions->transformLibraryFilenames) {
374 transformLibraries.emplace_back(
375 Args: mlir::openInputFile(inputFilename: filename, errorMessage: &errorMessage));
376 if (!transformLibraries.back()) {
377 llvm::errs() << errorMessage << "\n";
378 return mlir::failure();
379 }
380 }
381
382 return processPayloadBuffer(os&: outputFile->os(), inputBuffer: std::move(payloadFile),
383 transformBuffer: std::move(transformRootFile), transformLibraries,
384 registry);
385}
386
387int main(int argc, char **argv) {
388 return mlir::asMainReturnCode(r: runMain(argc, argv));
389}
390

source code of mlir/examples/transform-opt/mlir-transform-opt.cpp