1 | //===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
10 | #include "mlir/Dialect/Transform/IR/Utils.h" |
11 | #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" |
12 | #include "mlir/IR/AsmState.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "mlir/IR/Diagnostics.h" |
15 | #include "mlir/IR/DialectRegistry.h" |
16 | #include "mlir/IR/MLIRContext.h" |
17 | #include "mlir/InitAllDialects.h" |
18 | #include "mlir/InitAllExtensions.h" |
19 | #include "mlir/InitAllPasses.h" |
20 | #include "mlir/Parser/Parser.h" |
21 | #include "mlir/Support/FileUtilities.h" |
22 | #include "mlir/Tools/mlir-opt/MlirOptMain.h" |
23 | #include "llvm/Support/CommandLine.h" |
24 | #include "llvm/Support/InitLLVM.h" |
25 | #include "llvm/Support/SourceMgr.h" |
26 | #include "llvm/Support/ToolOutputFile.h" |
27 | #include <cstdlib> |
28 | |
29 | namespace { |
30 | |
31 | using namespace llvm; |
32 | |
33 | /// Structure containing command line options for the tool, these will get |
34 | /// initialized when an instance is created. |
35 | struct MlirTransformOptCLOptions { |
36 | cl::opt<bool> allowUnregisteredDialects{ |
37 | "allow-unregistered-dialect" , |
38 | cl::desc("Allow operations coming from an unregistered dialect" ), |
39 | cl::init(Val: false)}; |
40 | |
41 | cl::opt<mlir::SourceMgrDiagnosticVerifierHandler::Level> verifyDiagnostics{ |
42 | "verify-diagnostics" , llvm::cl::ValueOptional, |
43 | cl::desc("Check that emitted diagnostics match expected-* lines on the " |
44 | "corresponding line" ), |
45 | cl::values( |
46 | clEnumValN( |
47 | mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "all" , |
48 | "Check all diagnostics (expected, unexpected, near-misses)" ), |
49 | // Implicit value: when passed with no arguments, e.g. |
50 | // `--verify-diagnostics` or `--verify-diagnostics=`. |
51 | clEnumValN( |
52 | mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "" , |
53 | "Check all diagnostics (expected, unexpected, near-misses)" ), |
54 | clEnumValN( |
55 | mlir::SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected, |
56 | "only-expected" , "Check only expected diagnostics" ))}; |
57 | |
58 | cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>" ), |
59 | cl::init(Val: "-" )}; |
60 | |
61 | cl::opt<std::string> outputFilename{"o" , cl::desc("Output filename" ), |
62 | cl::value_desc("filename" ), |
63 | cl::init(Val: "-" )}; |
64 | |
65 | cl::opt<std::string> transformMainFilename{ |
66 | "transform" , |
67 | cl::desc("File containing entry point of the transform script, if " |
68 | "different from the input file" ), |
69 | cl::value_desc("filename" ), cl::init(Val: "" )}; |
70 | |
71 | cl::list<std::string> transformLibraryFilenames{ |
72 | "transform-library" , cl::desc("File(s) containing definitions of " |
73 | "additional transform script symbols" )}; |
74 | |
75 | cl::opt<std::string> transformEntryPoint{ |
76 | "transform-entry-point" , |
77 | cl::desc("Name of the entry point transform symbol" ), |
78 | cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName |
79 | .str())}; |
80 | |
81 | cl::opt<bool> disableExpensiveChecks{ |
82 | "disable-expensive-checks" , |
83 | cl::desc("Disables potentially expensive checks in the transform " |
84 | "interpreter, providing more speed at the expense of " |
85 | "potential memory problems and silent corruptions" ), |
86 | cl::init(Val: false)}; |
87 | |
88 | cl::opt<bool> dumpLibraryModule{ |
89 | "dump-library-module" , |
90 | cl::desc("Prints the combined library module before the output" ), |
91 | cl::init(Val: false)}; |
92 | }; |
93 | } // namespace |
94 | |
95 | /// "Managed" static instance of the command-line options structure. This makes |
96 | /// them locally-scoped and explicitly initialized/deinitialized. While this is |
97 | /// not strictly necessary in the tool source file that is not being used as a |
98 | /// library (where the options would pollute the global list of options), it is |
99 | /// good practice to follow this. |
100 | static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions; |
101 | |
102 | /// Explicitly registers command-line options. |
103 | static void registerCLOptions() { *clOptions; } |
104 | |
105 | namespace { |
106 | /// A wrapper class for source managers diagnostic. This provides both unique |
107 | /// ownership and virtual function-like overload for a pair of |
108 | /// inheritance-related classes that do not use virtual functions. |
109 | class DiagnosticHandlerWrapper { |
110 | public: |
111 | /// Kind of the diagnostic handler to use. |
112 | enum class Kind { EmitDiagnostics, VerifyDiagnostics }; |
113 | |
114 | /// Constructs the diagnostic handler of the specified kind of the given |
115 | /// source manager and context. |
116 | DiagnosticHandlerWrapper( |
117 | Kind kind, llvm::SourceMgr &mgr, mlir::MLIRContext *context, |
118 | std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> level = |
119 | {}) { |
120 | if (kind == Kind::EmitDiagnostics) { |
121 | handler = new mlir::SourceMgrDiagnosticHandler(mgr, context); |
122 | } else { |
123 | assert(level.has_value() && "expected level" ); |
124 | handler = |
125 | new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context, *level); |
126 | } |
127 | } |
128 | |
129 | /// This object is non-copyable but movable. |
130 | DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete; |
131 | DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default; |
132 | DiagnosticHandlerWrapper & |
133 | operator=(const DiagnosticHandlerWrapper &) = delete; |
134 | DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default; |
135 | |
136 | /// Verifies the captured "expected-*" diagnostics if required. |
137 | llvm::LogicalResult verify() const { |
138 | if (auto *ptr = |
139 | dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>(Val: handler)) { |
140 | return ptr->verify(); |
141 | } |
142 | return mlir::success(); |
143 | } |
144 | |
145 | /// Destructs the object of the same type as allocated. |
146 | ~DiagnosticHandlerWrapper() { |
147 | if (auto *ptr = dyn_cast<mlir::SourceMgrDiagnosticHandler *>(Val&: handler)) { |
148 | delete ptr; |
149 | } else { |
150 | delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(Val&: handler); |
151 | } |
152 | } |
153 | |
154 | private: |
155 | /// Internal storage is a type-safe union. |
156 | llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *, |
157 | mlir::SourceMgrDiagnosticVerifierHandler *> |
158 | handler; |
159 | }; |
160 | |
161 | /// MLIR has deeply rooted expectations that the LLVM source manager contains |
162 | /// exactly one buffer, until at least the lexer level. This class wraps |
163 | /// multiple LLVM source managers each managing a buffer to match MLIR's |
164 | /// expectations while still providing a centralized handling mechanism. |
165 | class TransformSourceMgr { |
166 | public: |
167 | /// Constructs the source manager indicating whether diagnostic messages will |
168 | /// be verified later on. |
169 | explicit TransformSourceMgr( |
170 | std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> |
171 | verifyDiagnostics) |
172 | : verifyDiagnostics(verifyDiagnostics) {} |
173 | |
174 | /// Deconstructs the source manager. Note that `checkResults` must have been |
175 | /// called on this instance before deconstructing it. |
176 | ~TransformSourceMgr() { |
177 | assert(resultChecked && "must check the result of diagnostic handlers by " |
178 | "running TransformSourceMgr::checkResult" ); |
179 | } |
180 | |
181 | /// Parses the given buffer and creates the top-level operation of the kind |
182 | /// specified as template argument in the given context. Additional parsing |
183 | /// options may be provided. |
184 | template <typename OpTy = mlir::Operation *> |
185 | mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer, |
186 | mlir::MLIRContext &context, |
187 | const mlir::ParserConfig &config) { |
188 | // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows |
189 | // the code below to capture a reference to the source manager in such a way |
190 | // that it is not invalidated when the vector contents is eventually |
191 | // reallocated. |
192 | llvm::SourceMgr &mgr = |
193 | *sourceMgrs.emplace_back(Args: std::make_unique<llvm::SourceMgr>()); |
194 | mgr.AddNewSourceBuffer(F: std::move(buffer), IncludeLoc: llvm::SMLoc()); |
195 | |
196 | // Choose the type of diagnostic handler depending on whether diagnostic |
197 | // verification needs to happen and store it. |
198 | if (verifyDiagnostics) { |
199 | diagHandlers.emplace_back( |
200 | Args: DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, Args&: mgr, Args: &context, |
201 | Args: verifyDiagnostics); |
202 | } else { |
203 | diagHandlers.emplace_back(Args: DiagnosticHandlerWrapper::Kind::EmitDiagnostics, |
204 | Args&: mgr, Args: &context); |
205 | } |
206 | |
207 | // Defer to MLIR's parser. |
208 | return mlir::parseSourceFile<OpTy>(mgr, config); |
209 | } |
210 | |
211 | /// If diagnostic message verification has been requested upon construction of |
212 | /// this source manager, performs the verification, reports errors and returns |
213 | /// the result of the verification. Otherwise passes through the given value. |
214 | llvm::LogicalResult checkResult(llvm::LogicalResult result) { |
215 | resultChecked = true; |
216 | if (!verifyDiagnostics) |
217 | return result; |
218 | |
219 | return mlir::failure(IsFailure: llvm::any_of(Range&: diagHandlers, P: [](const auto &handler) { |
220 | return mlir::failed(Result: handler.verify()); |
221 | })); |
222 | } |
223 | |
224 | private: |
225 | /// Indicates whether diagnostic message verification is requested. |
226 | const std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> |
227 | verifyDiagnostics; |
228 | |
229 | /// Indicates that diagnostic message verification has taken place, and the |
230 | /// deconstruction is therefore safe. |
231 | bool resultChecked = false; |
232 | |
233 | /// Storage for per-buffer source managers and diagnostic handlers. These are |
234 | /// wrapped into unique pointers in order to make it safe to capture |
235 | /// references to these objects: if the vector is reallocated, the unique |
236 | /// pointer objects are moved by the pointer addresses won't change. Also, for |
237 | /// handlers, this allows to store the pointer to the base class. |
238 | SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs; |
239 | SmallVector<DiagnosticHandlerWrapper> diagHandlers; |
240 | }; |
241 | } // namespace |
242 | |
243 | /// Trivial wrapper around `applyTransforms` that doesn't support extra mapping |
244 | /// and doesn't enforce the entry point transform ops being top-level. |
245 | static llvm::LogicalResult |
246 | applyTransforms(mlir::Operation *payloadRoot, |
247 | mlir::transform::TransformOpInterface transformRoot, |
248 | const mlir::transform::TransformOptions &options) { |
249 | return applyTransforms(payloadRoot, transformRoot, {}, options, |
250 | /*enforceToplevelTransformOp=*/false); |
251 | } |
252 | |
253 | /// Applies transforms indicated in the transform dialect script to the input |
254 | /// buffer. The transform script may be embedded in the input buffer or as a |
255 | /// separate buffer. The transform script may have external symbols, the |
256 | /// definitions of which must be provided in transform library buffers. If the |
257 | /// application is successful, prints the transformed input buffer into the |
258 | /// given output stream. Additional configuration options are derived from |
259 | /// command-line options. |
260 | static llvm::LogicalResult processPayloadBuffer( |
261 | raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer, |
262 | std::unique_ptr<llvm::MemoryBuffer> transformBuffer, |
263 | MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries, |
264 | mlir::DialectRegistry ®istry) { |
265 | |
266 | // Initialize the MLIR context, and various configurations. |
267 | mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED); |
268 | context.allowUnregisteredDialects(allow: clOptions->allowUnregisteredDialects); |
269 | mlir::ParserConfig config(&context); |
270 | TransformSourceMgr sourceMgr( |
271 | /*verifyDiagnostics=*/clOptions->verifyDiagnostics.getNumOccurrences() |
272 | ? std::optional{clOptions->verifyDiagnostics.getValue()} |
273 | : std::nullopt); |
274 | |
275 | // Parse the input buffer that will be used as transform payload. |
276 | mlir::OwningOpRef<mlir::Operation *> payloadRoot = |
277 | sourceMgr.parseBuffer(buffer: std::move(inputBuffer), context, config); |
278 | if (!payloadRoot) |
279 | return sourceMgr.checkResult(result: mlir::failure()); |
280 | |
281 | // Identify the module containing the transform script entry point. This may |
282 | // be the same module as the input or a separate module. In the former case, |
283 | // make a copy of the module so it can be modified freely. Modification may |
284 | // happen in the script itself (at which point it could be rewriting itself |
285 | // during interpretation, leading to tricky memory errors) or by embedding |
286 | // library modules in the script. |
287 | mlir::OwningOpRef<mlir::ModuleOp> transformRoot; |
288 | if (transformBuffer) { |
289 | transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>( |
290 | buffer: std::move(transformBuffer), context, config); |
291 | if (!transformRoot) |
292 | return sourceMgr.checkResult(result: mlir::failure()); |
293 | } else { |
294 | transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone()); |
295 | } |
296 | |
297 | // Parse and merge the libraries into the main transform module. |
298 | for (auto &&transformLibrary : transformLibraries) { |
299 | mlir::OwningOpRef<mlir::ModuleOp> libraryModule = |
300 | sourceMgr.parseBuffer<mlir::ModuleOp>(buffer: std::move(transformLibrary), |
301 | context, config); |
302 | |
303 | if (!libraryModule || |
304 | mlir::failed(Result: mlir::transform::detail::mergeSymbolsInto( |
305 | target: *transformRoot, other: std::move(libraryModule)))) |
306 | return sourceMgr.checkResult(result: mlir::failure()); |
307 | } |
308 | |
309 | // If requested, dump the combined transform module. |
310 | if (clOptions->dumpLibraryModule) |
311 | transformRoot->dump(); |
312 | |
313 | // Find the entry point symbol. Even if it had originally been in the payload |
314 | // module, it was cloned into the transform module so only look there. |
315 | mlir::transform::TransformOpInterface entryPoint = |
316 | mlir::transform::detail::findTransformEntryPoint( |
317 | *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint); |
318 | if (!entryPoint) |
319 | return sourceMgr.checkResult(result: mlir::failure()); |
320 | |
321 | // Apply the requested transformations. |
322 | mlir::transform::TransformOptions transformOptions; |
323 | transformOptions.enableExpensiveChecks(enable: !clOptions->disableExpensiveChecks); |
324 | if (mlir::failed(Result: applyTransforms(*payloadRoot, entryPoint, transformOptions))) |
325 | return sourceMgr.checkResult(result: mlir::failure()); |
326 | |
327 | // Print the transformed result and check the captured diagnostics if |
328 | // requested. |
329 | payloadRoot->print(os); |
330 | return sourceMgr.checkResult(result: mlir::success()); |
331 | } |
332 | |
333 | /// Tool entry point. |
334 | static llvm::LogicalResult runMain(int argc, char **argv) { |
335 | // Register all upstream dialects and extensions. Specific uses are advised |
336 | // not to register all dialects indiscriminately but rather hand-pick what is |
337 | // necessary for their use case. |
338 | mlir::DialectRegistry registry; |
339 | mlir::registerAllDialects(registry); |
340 | mlir::registerAllExtensions(registry); |
341 | mlir::registerAllPasses(); |
342 | |
343 | // Explicitly register the transform dialect. This is not strictly necessary |
344 | // since it has been already registered as part of the upstream dialect list, |
345 | // but useful for example purposes for cases when dialects to register are |
346 | // hand-picked. The transform dialect must be registered. |
347 | registry.insert<mlir::transform::TransformDialect>(); |
348 | |
349 | // Register various command-line options. Note that the LLVM initializer |
350 | // object is a RAII that ensures correct deconstruction of command-line option |
351 | // objects inside ManagedStatic. |
352 | llvm::InitLLVM y(argc, argv); |
353 | mlir::registerAsmPrinterCLOptions(); |
354 | mlir::registerMLIRContextCLOptions(); |
355 | registerCLOptions(); |
356 | llvm::cl::ParseCommandLineOptions(argc, argv, |
357 | Overview: "Minimal Transform dialect driver\n" ); |
358 | |
359 | // Try opening the main input file. |
360 | std::string errorMessage; |
361 | std::unique_ptr<llvm::MemoryBuffer> payloadFile = |
362 | mlir::openInputFile(inputFilename: clOptions->payloadFilename, errorMessage: &errorMessage); |
363 | if (!payloadFile) { |
364 | llvm::errs() << errorMessage << "\n" ; |
365 | return mlir::failure(); |
366 | } |
367 | |
368 | // Try opening the output file. |
369 | std::unique_ptr<llvm::ToolOutputFile> outputFile = |
370 | mlir::openOutputFile(outputFilename: clOptions->outputFilename, errorMessage: &errorMessage); |
371 | if (!outputFile) { |
372 | llvm::errs() << errorMessage << "\n" ; |
373 | return mlir::failure(); |
374 | } |
375 | |
376 | // Try opening the main transform file if provided. |
377 | std::unique_ptr<llvm::MemoryBuffer> transformRootFile; |
378 | if (!clOptions->transformMainFilename.empty()) { |
379 | if (clOptions->transformMainFilename == clOptions->payloadFilename) { |
380 | llvm::errs() << "warning: " << clOptions->payloadFilename |
381 | << " is provided as both payload and transform file\n" ; |
382 | } else { |
383 | transformRootFile = |
384 | mlir::openInputFile(inputFilename: clOptions->transformMainFilename, errorMessage: &errorMessage); |
385 | if (!transformRootFile) { |
386 | llvm::errs() << errorMessage << "\n" ; |
387 | return mlir::failure(); |
388 | } |
389 | } |
390 | } |
391 | |
392 | // Try opening transform library files if provided. |
393 | SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries; |
394 | transformLibraries.reserve(N: clOptions->transformLibraryFilenames.size()); |
395 | for (llvm::StringRef filename : clOptions->transformLibraryFilenames) { |
396 | transformLibraries.emplace_back( |
397 | Args: mlir::openInputFile(inputFilename: filename, errorMessage: &errorMessage)); |
398 | if (!transformLibraries.back()) { |
399 | llvm::errs() << errorMessage << "\n" ; |
400 | return mlir::failure(); |
401 | } |
402 | } |
403 | |
404 | return processPayloadBuffer(os&: outputFile->os(), inputBuffer: std::move(payloadFile), |
405 | transformBuffer: std::move(transformRootFile), transformLibraries, |
406 | registry); |
407 | } |
408 | |
409 | int main(int argc, char **argv) { |
410 | return mlir::asMainReturnCode(r: runMain(argc, argv)); |
411 | } |
412 | |