1 | //===- TransformInterpreterUtils.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Lightweight transform dialect interpreter utilities. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" |
14 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
15 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
16 | #include "mlir/Dialect/Transform/IR/Utils.h" |
17 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
18 | #include "mlir/IR/BuiltinOps.h" |
19 | #include "mlir/IR/Verifier.h" |
20 | #include "mlir/IR/Visitors.h" |
21 | #include "mlir/Interfaces/FunctionInterfaces.h" |
22 | #include "mlir/Parser/Parser.h" |
23 | #include "mlir/Support/FileUtilities.h" |
24 | #include "llvm/ADT/StringRef.h" |
25 | #include "llvm/Support/Casting.h" |
26 | #include "llvm/Support/Debug.h" |
27 | #include "llvm/Support/FileSystem.h" |
28 | #include "llvm/Support/SourceMgr.h" |
29 | #include "llvm/Support/raw_ostream.h" |
30 | |
31 | using namespace mlir; |
32 | |
33 | #define DEBUG_TYPE "transform-dialect-interpreter-utils" |
34 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
35 | |
36 | /// Expands the given list of `paths` to a list of `.mlir` files. |
37 | /// |
38 | /// Each entry in `paths` may either be a regular file, in which case it ends up |
39 | /// in the result list, or a directory, in which case all (regular) `.mlir` |
40 | /// files in that directory are added. Any other file types lead to a failure. |
41 | LogicalResult transform::detail::expandPathsToMLIRFiles( |
42 | ArrayRef<std::string> paths, MLIRContext *context, |
43 | SmallVectorImpl<std::string> &fileNames) { |
44 | for (const std::string &path : paths) { |
45 | auto loc = FileLineColLoc::get(context, path, 0, 0); |
46 | |
47 | if (llvm::sys::fs::is_regular_file(Path: path)) { |
48 | LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n" ); |
49 | fileNames.push_back(Elt: path); |
50 | continue; |
51 | } |
52 | |
53 | if (!llvm::sys::fs::is_directory(Path: path)) { |
54 | return emitError(loc) |
55 | << "'" << path << "' is neither a file nor a directory" ; |
56 | } |
57 | |
58 | LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n" ); |
59 | |
60 | std::error_code ec; |
61 | for (llvm::sys::fs::directory_iterator it(path, ec), itEnd; |
62 | it != itEnd && !ec; it.increment(ec)) { |
63 | const std::string &fileName = it->path(); |
64 | |
65 | if (it->type() != llvm::sys::fs::file_type::regular_file && |
66 | it->type() != llvm::sys::fs::file_type::symlink_file) { |
67 | LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName |
68 | << "'\n" ); |
69 | continue; |
70 | } |
71 | |
72 | if (!StringRef(fileName).ends_with(Suffix: ".mlir" )) { |
73 | LLVM_DEBUG(DBGS() << " Skipping '" << fileName |
74 | << "' because it does not end with '.mlir'\n" ); |
75 | continue; |
76 | } |
77 | |
78 | LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n" ); |
79 | fileNames.push_back(Elt: fileName); |
80 | } |
81 | |
82 | if (ec) |
83 | return emitError(loc) << "error while opening files in '" << path |
84 | << "': " << ec.message(); |
85 | } |
86 | |
87 | return success(); |
88 | } |
89 | |
90 | LogicalResult transform::detail::parseTransformModuleFromFile( |
91 | MLIRContext *context, llvm::StringRef transformFileName, |
92 | OwningOpRef<ModuleOp> &transformModule) { |
93 | if (transformFileName.empty()) { |
94 | LLVM_DEBUG( |
95 | DBGS() << "no transform file name specified, assuming the transform " |
96 | "module is embedded in the IR next to the top-level\n" ); |
97 | return success(); |
98 | } |
99 | // Parse transformFileName content into a ModuleOp. |
100 | std::string errorMessage; |
101 | auto memoryBuffer = mlir::openInputFile(inputFilename: transformFileName, errorMessage: &errorMessage); |
102 | if (!memoryBuffer) { |
103 | return emitError(FileLineColLoc::get( |
104 | StringAttr::get(context, transformFileName), 0, 0)) |
105 | << "failed to open transform file: " << errorMessage; |
106 | } |
107 | // Tell sourceMgr about this buffer, the parser will pick it up. |
108 | llvm::SourceMgr sourceMgr; |
109 | sourceMgr.AddNewSourceBuffer(F: std::move(memoryBuffer), IncludeLoc: llvm::SMLoc()); |
110 | transformModule = |
111 | OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, config: context)); |
112 | if (!transformModule) { |
113 | // Failed to parse the transform module. |
114 | // Don't need to emit an error here as the parsing should have already done |
115 | // that. |
116 | return failure(); |
117 | } |
118 | return mlir::verify(op: *transformModule); |
119 | } |
120 | |
121 | ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) { |
122 | return context->getOrLoadDialect<transform::TransformDialect>() |
123 | ->getLibraryModule(); |
124 | } |
125 | |
126 | transform::TransformOpInterface |
127 | transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module, |
128 | StringRef entryPoint) { |
129 | SmallVector<Operation *, 2> l{root}; |
130 | if (module) |
131 | l.push_back(Elt: module); |
132 | for (Operation *op : l) { |
133 | transform::TransformOpInterface transform = nullptr; |
134 | op->walk<WalkOrder::PreOrder>( |
135 | callback: [&](transform::NamedSequenceOp namedSequenceOp) { |
136 | if (namedSequenceOp.getSymName() == entryPoint) { |
137 | transform = cast<transform::TransformOpInterface>( |
138 | namedSequenceOp.getOperation()); |
139 | return WalkResult::interrupt(); |
140 | } |
141 | return WalkResult::advance(); |
142 | }); |
143 | if (transform) |
144 | return transform; |
145 | } |
146 | auto diag = root->emitError() |
147 | << "could not find a nested named sequence with name: " |
148 | << entryPoint; |
149 | return nullptr; |
150 | } |
151 | |
152 | LogicalResult transform::detail::assembleTransformLibraryFromPaths( |
153 | MLIRContext *context, ArrayRef<std::string> transformLibraryPaths, |
154 | OwningOpRef<ModuleOp> &transformModule) { |
155 | // Assemble list of library files. |
156 | SmallVector<std::string> libraryFileNames; |
157 | if (failed(result: detail::expandPathsToMLIRFiles(paths: transformLibraryPaths, context, |
158 | fileNames&: libraryFileNames))) |
159 | return failure(); |
160 | |
161 | // Parse modules from library files. |
162 | SmallVector<OwningOpRef<ModuleOp>> parsedLibraries; |
163 | for (const std::string &libraryFileName : libraryFileNames) { |
164 | OwningOpRef<ModuleOp> parsedLibrary; |
165 | if (failed(result: transform::detail::parseTransformModuleFromFile( |
166 | context, transformFileName: libraryFileName, transformModule&: parsedLibrary))) |
167 | return failure(); |
168 | parsedLibraries.push_back(Elt: std::move(parsedLibrary)); |
169 | } |
170 | |
171 | // Merge parsed libraries into one module. |
172 | auto loc = FileLineColLoc::get(context, "<shared-library-module>" , 0, 0); |
173 | OwningOpRef<ModuleOp> mergedParsedLibraries = |
174 | ModuleOp::create(loc, "__transform" ); |
175 | { |
176 | mergedParsedLibraries.get()->setAttr("transform.with_named_sequence" , |
177 | UnitAttr::get(context)); |
178 | // TODO: extend `mergeSymbolsInto` to support multiple `other` modules. |
179 | for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) { |
180 | if (failed(transform::detail::mergeSymbolsInto( |
181 | mergedParsedLibraries.get(), std::move(parsedLibrary)))) |
182 | return parsedLibrary->emitError() |
183 | << "failed to merge symbols into shared library module" ; |
184 | } |
185 | } |
186 | |
187 | transformModule = std::move(mergedParsedLibraries); |
188 | return success(); |
189 | } |
190 | |
191 | LogicalResult transform::applyTransformNamedSequence( |
192 | Operation *payload, Operation *transformRoot, ModuleOp transformModule, |
193 | const TransformOptions &options) { |
194 | RaggedArray<MappedValue> bindings; |
195 | bindings.push_back(elements: ArrayRef<Operation *>{payload}); |
196 | return applyTransformNamedSequence(bindings, |
197 | cast<TransformOpInterface>(transformRoot), |
198 | transformModule, options); |
199 | } |
200 | |
201 | LogicalResult transform::applyTransformNamedSequence( |
202 | RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot, |
203 | ModuleOp transformModule, const TransformOptions &options) { |
204 | if (bindings.empty()) { |
205 | return transformRoot.emitError() |
206 | << "expected at least one binding for the root" ; |
207 | } |
208 | if (bindings.at(pos: 0).size() != 1) { |
209 | return transformRoot.emitError() |
210 | << "expected one payload to be bound to the first argument, got " |
211 | << bindings.at(pos: 0).size(); |
212 | } |
213 | auto *payloadRoot = bindings.at(pos: 0).front().dyn_cast<Operation *>(); |
214 | if (!payloadRoot) { |
215 | return transformRoot->emitError() << "expected the object bound to the " |
216 | "first argument to be an operation" ; |
217 | } |
218 | |
219 | bindings.removeFront(); |
220 | |
221 | // `transformModule` may not be modified. |
222 | if (transformModule && !transformModule->isAncestor(transformRoot)) { |
223 | OwningOpRef<Operation *> clonedTransformModule(transformModule->clone()); |
224 | if (failed(detail::mergeSymbolsInto( |
225 | target: SymbolTable::getNearestSymbolTable(from: transformRoot), |
226 | other: std::move(clonedTransformModule)))) { |
227 | return payloadRoot->emitError() << "failed to merge symbols" ; |
228 | } |
229 | } |
230 | |
231 | LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n" ); |
232 | LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n" ); |
233 | |
234 | return applyTransforms(payloadRoot, transformRoot, bindings, options, |
235 | /*enforceToplevelTransformOp=*/false); |
236 | } |
237 | |