1//===- TransformInterpreterUtils.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lightweight transform dialect interpreter utilities.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
28
29using namespace mlir;
30
31#define DEBUG_TYPE "transform-dialect-interpreter-utils"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
33
34/// Expands the given list of `paths` to a list of `.mlir` files.
35///
36/// Each entry in `paths` may either be a regular file, in which case it ends up
37/// in the result list, or a directory, in which case all (regular) `.mlir`
38/// files in that directory are added. Any other file types lead to a failure.
39LogicalResult transform::detail::expandPathsToMLIRFiles(
40 ArrayRef<std::string> paths, MLIRContext *context,
41 SmallVectorImpl<std::string> &fileNames) {
42 for (const std::string &path : paths) {
43 auto loc = FileLineColLoc::get(context, fileName: path, line: 0, column: 0);
44
45 if (llvm::sys::fs::is_regular_file(Path: path)) {
46 LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
47 fileNames.push_back(Elt: path);
48 continue;
49 }
50
51 if (!llvm::sys::fs::is_directory(Path: path)) {
52 return emitError(loc)
53 << "'" << path << "' is neither a file nor a directory";
54 }
55
56 LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
57
58 std::error_code ec;
59 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
60 it != itEnd && !ec; it.increment(ec)) {
61 const std::string &fileName = it->path();
62
63 if (it->type() != llvm::sys::fs::file_type::regular_file &&
64 it->type() != llvm::sys::fs::file_type::symlink_file) {
65 LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
66 << "'\n");
67 continue;
68 }
69
70 if (!StringRef(fileName).ends_with(Suffix: ".mlir")) {
71 LLVM_DEBUG(DBGS() << " Skipping '" << fileName
72 << "' because it does not end with '.mlir'\n");
73 continue;
74 }
75
76 LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
77 fileNames.push_back(Elt: fileName);
78 }
79
80 if (ec)
81 return emitError(loc) << "error while opening files in '" << path
82 << "': " << ec.message();
83 }
84
85 return success();
86}
87
88LogicalResult transform::detail::parseTransformModuleFromFile(
89 MLIRContext *context, llvm::StringRef transformFileName,
90 OwningOpRef<ModuleOp> &transformModule) {
91 if (transformFileName.empty()) {
92 LLVM_DEBUG(
93 DBGS() << "no transform file name specified, assuming the transform "
94 "module is embedded in the IR next to the top-level\n");
95 return success();
96 }
97 // Parse transformFileName content into a ModuleOp.
98 std::string errorMessage;
99 auto memoryBuffer = mlir::openInputFile(inputFilename: transformFileName, errorMessage: &errorMessage);
100 if (!memoryBuffer) {
101 return emitError(loc: FileLineColLoc::get(
102 filename: StringAttr::get(context, bytes: transformFileName), line: 0, column: 0))
103 << "failed to open transform file: " << errorMessage;
104 }
105 // Tell sourceMgr about this buffer, the parser will pick it up.
106 llvm::SourceMgr sourceMgr;
107 sourceMgr.AddNewSourceBuffer(F: std::move(memoryBuffer), IncludeLoc: llvm::SMLoc());
108 transformModule =
109 OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, config: context));
110 if (!transformModule) {
111 // Failed to parse the transform module.
112 // Don't need to emit an error here as the parsing should have already done
113 // that.
114 return failure();
115 }
116 return mlir::verify(op: *transformModule);
117}
118
119ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
120 return context->getOrLoadDialect<transform::TransformDialect>()
121 ->getLibraryModule();
122}
123
124transform::TransformOpInterface
125transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
126 StringRef entryPoint) {
127 SmallVector<Operation *, 2> l{root};
128 if (module)
129 l.push_back(Elt: module);
130 for (Operation *op : l) {
131 transform::TransformOpInterface transform = nullptr;
132 op->walk<WalkOrder::PreOrder>(
133 callback: [&](transform::NamedSequenceOp namedSequenceOp) {
134 if (namedSequenceOp.getSymName() == entryPoint) {
135 transform = cast<transform::TransformOpInterface>(
136 Val: namedSequenceOp.getOperation());
137 return WalkResult::interrupt();
138 }
139 return WalkResult::advance();
140 });
141 if (transform)
142 return transform;
143 }
144 auto diag = root->emitError()
145 << "could not find a nested named sequence with name: "
146 << entryPoint;
147 return nullptr;
148}
149
150LogicalResult transform::detail::assembleTransformLibraryFromPaths(
151 MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
152 OwningOpRef<ModuleOp> &transformModule) {
153 // Assemble list of library files.
154 SmallVector<std::string> libraryFileNames;
155 if (failed(Result: detail::expandPathsToMLIRFiles(paths: transformLibraryPaths, context,
156 fileNames&: libraryFileNames)))
157 return failure();
158
159 // Parse modules from library files.
160 SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
161 for (const std::string &libraryFileName : libraryFileNames) {
162 OwningOpRef<ModuleOp> parsedLibrary;
163 if (failed(Result: transform::detail::parseTransformModuleFromFile(
164 context, transformFileName: libraryFileName, transformModule&: parsedLibrary)))
165 return failure();
166 parsedLibraries.push_back(Elt: std::move(parsedLibrary));
167 }
168
169 // Merge parsed libraries into one module.
170 auto loc = FileLineColLoc::get(context, fileName: "<shared-library-module>", line: 0, column: 0);
171 OwningOpRef<ModuleOp> mergedParsedLibraries =
172 ModuleOp::create(loc, name: "__transform");
173 {
174 mergedParsedLibraries.get()->setAttr(name: "transform.with_named_sequence",
175 value: UnitAttr::get(context));
176 // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
177 for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
178 if (failed(Result: transform::detail::mergeSymbolsInto(
179 target: mergedParsedLibraries.get(), other: std::move(parsedLibrary))))
180 return parsedLibrary->emitError()
181 << "failed to merge symbols into shared library module";
182 }
183 }
184
185 transformModule = std::move(mergedParsedLibraries);
186 return success();
187}
188
189LogicalResult transform::applyTransformNamedSequence(
190 Operation *payload, Operation *transformRoot, ModuleOp transformModule,
191 const TransformOptions &options) {
192 RaggedArray<MappedValue> bindings;
193 bindings.push_back(elements: ArrayRef<Operation *>{payload});
194 return applyTransformNamedSequence(bindings,
195 transformRoot: cast<TransformOpInterface>(Val: transformRoot),
196 transformModule, options);
197}
198
199LogicalResult transform::applyTransformNamedSequence(
200 RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
201 ModuleOp transformModule, const TransformOptions &options) {
202 if (bindings.empty()) {
203 return transformRoot.emitError()
204 << "expected at least one binding for the root";
205 }
206 if (bindings.at(pos: 0).size() != 1) {
207 return transformRoot.emitError()
208 << "expected one payload to be bound to the first argument, got "
209 << bindings.at(pos: 0).size();
210 }
211 auto *payloadRoot = dyn_cast<Operation *>(Val&: bindings.at(pos: 0).front());
212 if (!payloadRoot) {
213 return transformRoot->emitError() << "expected the object bound to the "
214 "first argument to be an operation";
215 }
216
217 bindings.removeFront();
218
219 // `transformModule` may not be modified.
220 if (transformModule && !transformModule->isAncestor(other: transformRoot)) {
221 OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
222 if (failed(Result: detail::mergeSymbolsInto(
223 target: SymbolTable::getNearestSymbolTable(from: transformRoot),
224 other: std::move(clonedTransformModule)))) {
225 return payloadRoot->emitError() << "failed to merge symbols";
226 }
227 }
228
229 LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
230 LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
231
232 return applyTransforms(payloadRoot, transform: transformRoot, extraMapping: bindings, options,
233 /*enforceToplevelTransformOp=*/false);
234}
235

source code of mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp