1//===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Dialect/Transform/IR/Utils.h"
11#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
12#include "mlir/IR/AsmState.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Diagnostics.h"
15#include "mlir/IR/DialectRegistry.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/InitAllDialects.h"
18#include "mlir/InitAllExtensions.h"
19#include "mlir/InitAllPasses.h"
20#include "mlir/Parser/Parser.h"
21#include "mlir/Support/FileUtilities.h"
22#include "mlir/Tools/mlir-opt/MlirOptMain.h"
23#include "llvm/Support/CommandLine.h"
24#include "llvm/Support/InitLLVM.h"
25#include "llvm/Support/SourceMgr.h"
26#include "llvm/Support/ToolOutputFile.h"
27#include <cstdlib>
28
29namespace {
30
31using namespace llvm;
32
33/// Structure containing command line options for the tool, these will get
34/// initialized when an instance is created.
35struct MlirTransformOptCLOptions {
36 cl::opt<bool> allowUnregisteredDialects{
37 "allow-unregistered-dialect",
38 cl::desc("Allow operations coming from an unregistered dialect"),
39 cl::init(Val: false)};
40
41 cl::opt<mlir::SourceMgrDiagnosticVerifierHandler::Level> verifyDiagnostics{
42 "verify-diagnostics", llvm::cl::ValueOptional,
43 cl::desc("Check that emitted diagnostics match expected-* lines on the "
44 "corresponding line"),
45 cl::values(
46 clEnumValN(
47 mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "all",
48 "Check all diagnostics (expected, unexpected, near-misses)"),
49 // Implicit value: when passed with no arguments, e.g.
50 // `--verify-diagnostics` or `--verify-diagnostics=`.
51 clEnumValN(
52 mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "",
53 "Check all diagnostics (expected, unexpected, near-misses)"),
54 clEnumValN(
55 mlir::SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
56 "only-expected", "Check only expected diagnostics"))};
57
58 cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
59 cl::init(Val: "-")};
60
61 cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"),
62 cl::value_desc("filename"),
63 cl::init(Val: "-")};
64
65 cl::opt<std::string> transformMainFilename{
66 "transform",
67 cl::desc("File containing entry point of the transform script, if "
68 "different from the input file"),
69 cl::value_desc("filename"), cl::init(Val: "")};
70
71 cl::list<std::string> transformLibraryFilenames{
72 "transform-library", cl::desc("File(s) containing definitions of "
73 "additional transform script symbols")};
74
75 cl::opt<std::string> transformEntryPoint{
76 "transform-entry-point",
77 cl::desc("Name of the entry point transform symbol"),
78 cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
79 .str())};
80
81 cl::opt<bool> disableExpensiveChecks{
82 "disable-expensive-checks",
83 cl::desc("Disables potentially expensive checks in the transform "
84 "interpreter, providing more speed at the expense of "
85 "potential memory problems and silent corruptions"),
86 cl::init(Val: false)};
87
88 cl::opt<bool> dumpLibraryModule{
89 "dump-library-module",
90 cl::desc("Prints the combined library module before the output"),
91 cl::init(Val: false)};
92};
93} // namespace
94
95/// "Managed" static instance of the command-line options structure. This makes
96/// them locally-scoped and explicitly initialized/deinitialized. While this is
97/// not strictly necessary in the tool source file that is not being used as a
98/// library (where the options would pollute the global list of options), it is
99/// good practice to follow this.
100static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions;
101
102/// Explicitly registers command-line options.
103static void registerCLOptions() { *clOptions; }
104
105namespace {
106/// A wrapper class for source managers diagnostic. This provides both unique
107/// ownership and virtual function-like overload for a pair of
108/// inheritance-related classes that do not use virtual functions.
109class DiagnosticHandlerWrapper {
110public:
111 /// Kind of the diagnostic handler to use.
112 enum class Kind { EmitDiagnostics, VerifyDiagnostics };
113
114 /// Constructs the diagnostic handler of the specified kind of the given
115 /// source manager and context.
116 DiagnosticHandlerWrapper(
117 Kind kind, llvm::SourceMgr &mgr, mlir::MLIRContext *context,
118 std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> level =
119 {}) {
120 if (kind == Kind::EmitDiagnostics) {
121 handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
122 } else {
123 assert(level.has_value() && "expected level");
124 handler =
125 new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context, *level);
126 }
127 }
128
129 /// This object is non-copyable but movable.
130 DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete;
131 DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default;
132 DiagnosticHandlerWrapper &
133 operator=(const DiagnosticHandlerWrapper &) = delete;
134 DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default;
135
136 /// Verifies the captured "expected-*" diagnostics if required.
137 llvm::LogicalResult verify() const {
138 if (auto *ptr =
139 dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>(Val: handler)) {
140 return ptr->verify();
141 }
142 return mlir::success();
143 }
144
145 /// Destructs the object of the same type as allocated.
146 ~DiagnosticHandlerWrapper() {
147 if (auto *ptr = dyn_cast<mlir::SourceMgrDiagnosticHandler *>(Val&: handler)) {
148 delete ptr;
149 } else {
150 delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(Val&: handler);
151 }
152 }
153
154private:
155 /// Internal storage is a type-safe union.
156 llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *,
157 mlir::SourceMgrDiagnosticVerifierHandler *>
158 handler;
159};
160
161/// MLIR has deeply rooted expectations that the LLVM source manager contains
162/// exactly one buffer, until at least the lexer level. This class wraps
163/// multiple LLVM source managers each managing a buffer to match MLIR's
164/// expectations while still providing a centralized handling mechanism.
165class TransformSourceMgr {
166public:
167 /// Constructs the source manager indicating whether diagnostic messages will
168 /// be verified later on.
169 explicit TransformSourceMgr(
170 std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
171 verifyDiagnostics)
172 : verifyDiagnostics(verifyDiagnostics) {}
173
174 /// Deconstructs the source manager. Note that `checkResults` must have been
175 /// called on this instance before deconstructing it.
176 ~TransformSourceMgr() {
177 assert(resultChecked && "must check the result of diagnostic handlers by "
178 "running TransformSourceMgr::checkResult");
179 }
180
181 /// Parses the given buffer and creates the top-level operation of the kind
182 /// specified as template argument in the given context. Additional parsing
183 /// options may be provided.
184 template <typename OpTy = mlir::Operation *>
185 mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer,
186 mlir::MLIRContext &context,
187 const mlir::ParserConfig &config) {
188 // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
189 // the code below to capture a reference to the source manager in such a way
190 // that it is not invalidated when the vector contents is eventually
191 // reallocated.
192 llvm::SourceMgr &mgr =
193 *sourceMgrs.emplace_back(Args: std::make_unique<llvm::SourceMgr>());
194 mgr.AddNewSourceBuffer(F: std::move(buffer), IncludeLoc: llvm::SMLoc());
195
196 // Choose the type of diagnostic handler depending on whether diagnostic
197 // verification needs to happen and store it.
198 if (verifyDiagnostics) {
199 diagHandlers.emplace_back(
200 Args: DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, Args&: mgr, Args: &context,
201 Args: verifyDiagnostics);
202 } else {
203 diagHandlers.emplace_back(Args: DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
204 Args&: mgr, Args: &context);
205 }
206
207 // Defer to MLIR's parser.
208 return mlir::parseSourceFile<OpTy>(mgr, config);
209 }
210
211 /// If diagnostic message verification has been requested upon construction of
212 /// this source manager, performs the verification, reports errors and returns
213 /// the result of the verification. Otherwise passes through the given value.
214 llvm::LogicalResult checkResult(llvm::LogicalResult result) {
215 resultChecked = true;
216 if (!verifyDiagnostics)
217 return result;
218
219 return mlir::failure(IsFailure: llvm::any_of(Range&: diagHandlers, P: [](const auto &handler) {
220 return mlir::failed(Result: handler.verify());
221 }));
222 }
223
224private:
225 /// Indicates whether diagnostic message verification is requested.
226 const std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
227 verifyDiagnostics;
228
229 /// Indicates that diagnostic message verification has taken place, and the
230 /// deconstruction is therefore safe.
231 bool resultChecked = false;
232
233 /// Storage for per-buffer source managers and diagnostic handlers. These are
234 /// wrapped into unique pointers in order to make it safe to capture
235 /// references to these objects: if the vector is reallocated, the unique
236 /// pointer objects are moved by the pointer addresses won't change. Also, for
237 /// handlers, this allows to store the pointer to the base class.
238 SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs;
239 SmallVector<DiagnosticHandlerWrapper> diagHandlers;
240};
241} // namespace
242
243/// Trivial wrapper around `applyTransforms` that doesn't support extra mapping
244/// and doesn't enforce the entry point transform ops being top-level.
245static llvm::LogicalResult
246applyTransforms(mlir::Operation *payloadRoot,
247 mlir::transform::TransformOpInterface transformRoot,
248 const mlir::transform::TransformOptions &options) {
249 return applyTransforms(payloadRoot, transformRoot, {}, options,
250 /*enforceToplevelTransformOp=*/false);
251}
252
253/// Applies transforms indicated in the transform dialect script to the input
254/// buffer. The transform script may be embedded in the input buffer or as a
255/// separate buffer. The transform script may have external symbols, the
256/// definitions of which must be provided in transform library buffers. If the
257/// application is successful, prints the transformed input buffer into the
258/// given output stream. Additional configuration options are derived from
259/// command-line options.
260static llvm::LogicalResult processPayloadBuffer(
261 raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer,
262 std::unique_ptr<llvm::MemoryBuffer> transformBuffer,
263 MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries,
264 mlir::DialectRegistry &registry) {
265
266 // Initialize the MLIR context, and various configurations.
267 mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED);
268 context.allowUnregisteredDialects(allow: clOptions->allowUnregisteredDialects);
269 mlir::ParserConfig config(&context);
270 TransformSourceMgr sourceMgr(
271 /*verifyDiagnostics=*/clOptions->verifyDiagnostics.getNumOccurrences()
272 ? std::optional{clOptions->verifyDiagnostics.getValue()}
273 : std::nullopt);
274
275 // Parse the input buffer that will be used as transform payload.
276 mlir::OwningOpRef<mlir::Operation *> payloadRoot =
277 sourceMgr.parseBuffer(buffer: std::move(inputBuffer), context, config);
278 if (!payloadRoot)
279 return sourceMgr.checkResult(result: mlir::failure());
280
281 // Identify the module containing the transform script entry point. This may
282 // be the same module as the input or a separate module. In the former case,
283 // make a copy of the module so it can be modified freely. Modification may
284 // happen in the script itself (at which point it could be rewriting itself
285 // during interpretation, leading to tricky memory errors) or by embedding
286 // library modules in the script.
287 mlir::OwningOpRef<mlir::ModuleOp> transformRoot;
288 if (transformBuffer) {
289 transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>(
290 buffer: std::move(transformBuffer), context, config);
291 if (!transformRoot)
292 return sourceMgr.checkResult(result: mlir::failure());
293 } else {
294 transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone());
295 }
296
297 // Parse and merge the libraries into the main transform module.
298 for (auto &&transformLibrary : transformLibraries) {
299 mlir::OwningOpRef<mlir::ModuleOp> libraryModule =
300 sourceMgr.parseBuffer<mlir::ModuleOp>(buffer: std::move(transformLibrary),
301 context, config);
302
303 if (!libraryModule ||
304 mlir::failed(Result: mlir::transform::detail::mergeSymbolsInto(
305 target: *transformRoot, other: std::move(libraryModule))))
306 return sourceMgr.checkResult(result: mlir::failure());
307 }
308
309 // If requested, dump the combined transform module.
310 if (clOptions->dumpLibraryModule)
311 transformRoot->dump();
312
313 // Find the entry point symbol. Even if it had originally been in the payload
314 // module, it was cloned into the transform module so only look there.
315 mlir::transform::TransformOpInterface entryPoint =
316 mlir::transform::detail::findTransformEntryPoint(
317 *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint);
318 if (!entryPoint)
319 return sourceMgr.checkResult(result: mlir::failure());
320
321 // Apply the requested transformations.
322 mlir::transform::TransformOptions transformOptions;
323 transformOptions.enableExpensiveChecks(enable: !clOptions->disableExpensiveChecks);
324 if (mlir::failed(Result: applyTransforms(*payloadRoot, entryPoint, transformOptions)))
325 return sourceMgr.checkResult(result: mlir::failure());
326
327 // Print the transformed result and check the captured diagnostics if
328 // requested.
329 payloadRoot->print(os);
330 return sourceMgr.checkResult(result: mlir::success());
331}
332
333/// Tool entry point.
334static llvm::LogicalResult runMain(int argc, char **argv) {
335 // Register all upstream dialects and extensions. Specific uses are advised
336 // not to register all dialects indiscriminately but rather hand-pick what is
337 // necessary for their use case.
338 mlir::DialectRegistry registry;
339 mlir::registerAllDialects(registry);
340 mlir::registerAllExtensions(registry);
341 mlir::registerAllPasses();
342
343 // Explicitly register the transform dialect. This is not strictly necessary
344 // since it has been already registered as part of the upstream dialect list,
345 // but useful for example purposes for cases when dialects to register are
346 // hand-picked. The transform dialect must be registered.
347 registry.insert<mlir::transform::TransformDialect>();
348
349 // Register various command-line options. Note that the LLVM initializer
350 // object is a RAII that ensures correct deconstruction of command-line option
351 // objects inside ManagedStatic.
352 llvm::InitLLVM y(argc, argv);
353 mlir::registerAsmPrinterCLOptions();
354 mlir::registerMLIRContextCLOptions();
355 registerCLOptions();
356 llvm::cl::ParseCommandLineOptions(argc, argv,
357 Overview: "Minimal Transform dialect driver\n");
358
359 // Try opening the main input file.
360 std::string errorMessage;
361 std::unique_ptr<llvm::MemoryBuffer> payloadFile =
362 mlir::openInputFile(inputFilename: clOptions->payloadFilename, errorMessage: &errorMessage);
363 if (!payloadFile) {
364 llvm::errs() << errorMessage << "\n";
365 return mlir::failure();
366 }
367
368 // Try opening the output file.
369 std::unique_ptr<llvm::ToolOutputFile> outputFile =
370 mlir::openOutputFile(outputFilename: clOptions->outputFilename, errorMessage: &errorMessage);
371 if (!outputFile) {
372 llvm::errs() << errorMessage << "\n";
373 return mlir::failure();
374 }
375
376 // Try opening the main transform file if provided.
377 std::unique_ptr<llvm::MemoryBuffer> transformRootFile;
378 if (!clOptions->transformMainFilename.empty()) {
379 if (clOptions->transformMainFilename == clOptions->payloadFilename) {
380 llvm::errs() << "warning: " << clOptions->payloadFilename
381 << " is provided as both payload and transform file\n";
382 } else {
383 transformRootFile =
384 mlir::openInputFile(inputFilename: clOptions->transformMainFilename, errorMessage: &errorMessage);
385 if (!transformRootFile) {
386 llvm::errs() << errorMessage << "\n";
387 return mlir::failure();
388 }
389 }
390 }
391
392 // Try opening transform library files if provided.
393 SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries;
394 transformLibraries.reserve(N: clOptions->transformLibraryFilenames.size());
395 for (llvm::StringRef filename : clOptions->transformLibraryFilenames) {
396 transformLibraries.emplace_back(
397 Args: mlir::openInputFile(inputFilename: filename, errorMessage: &errorMessage));
398 if (!transformLibraries.back()) {
399 llvm::errs() << errorMessage << "\n";
400 return mlir::failure();
401 }
402 }
403
404 return processPayloadBuffer(os&: outputFile->os(), inputBuffer: std::move(payloadFile),
405 transformBuffer: std::move(transformRootFile), transformLibraries,
406 registry);
407}
408
409int main(int argc, char **argv) {
410 return mlir::asMainReturnCode(r: runMain(argc, argv));
411}
412

// Source: mlir/examples/transform-opt/mlir-transform-opt.cpp