1//===- TransformInterpreterUtils.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lightweight transform dialect interpreter utilities.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
30
31using namespace mlir;
32
33#define DEBUG_TYPE "transform-dialect-interpreter-utils"
34#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35
36/// Expands the given list of `paths` to a list of `.mlir` files.
37///
38/// Each entry in `paths` may either be a regular file, in which case it ends up
39/// in the result list, or a directory, in which case all (regular) `.mlir`
40/// files in that directory are added. Any other file types lead to a failure.
41LogicalResult transform::detail::expandPathsToMLIRFiles(
42 ArrayRef<std::string> paths, MLIRContext *context,
43 SmallVectorImpl<std::string> &fileNames) {
44 for (const std::string &path : paths) {
45 auto loc = FileLineColLoc::get(context, path, 0, 0);
46
47 if (llvm::sys::fs::is_regular_file(Path: path)) {
48 LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
49 fileNames.push_back(Elt: path);
50 continue;
51 }
52
53 if (!llvm::sys::fs::is_directory(Path: path)) {
54 return emitError(loc)
55 << "'" << path << "' is neither a file nor a directory";
56 }
57
58 LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
59
60 std::error_code ec;
61 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
62 it != itEnd && !ec; it.increment(ec)) {
63 const std::string &fileName = it->path();
64
65 if (it->type() != llvm::sys::fs::file_type::regular_file &&
66 it->type() != llvm::sys::fs::file_type::symlink_file) {
67 LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
68 << "'\n");
69 continue;
70 }
71
72 if (!StringRef(fileName).ends_with(Suffix: ".mlir")) {
73 LLVM_DEBUG(DBGS() << " Skipping '" << fileName
74 << "' because it does not end with '.mlir'\n");
75 continue;
76 }
77
78 LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
79 fileNames.push_back(Elt: fileName);
80 }
81
82 if (ec)
83 return emitError(loc) << "error while opening files in '" << path
84 << "': " << ec.message();
85 }
86
87 return success();
88}
89
90LogicalResult transform::detail::parseTransformModuleFromFile(
91 MLIRContext *context, llvm::StringRef transformFileName,
92 OwningOpRef<ModuleOp> &transformModule) {
93 if (transformFileName.empty()) {
94 LLVM_DEBUG(
95 DBGS() << "no transform file name specified, assuming the transform "
96 "module is embedded in the IR next to the top-level\n");
97 return success();
98 }
99 // Parse transformFileName content into a ModuleOp.
100 std::string errorMessage;
101 auto memoryBuffer = mlir::openInputFile(inputFilename: transformFileName, errorMessage: &errorMessage);
102 if (!memoryBuffer) {
103 return emitError(FileLineColLoc::get(
104 StringAttr::get(context, transformFileName), 0, 0))
105 << "failed to open transform file: " << errorMessage;
106 }
107 // Tell sourceMgr about this buffer, the parser will pick it up.
108 llvm::SourceMgr sourceMgr;
109 sourceMgr.AddNewSourceBuffer(F: std::move(memoryBuffer), IncludeLoc: llvm::SMLoc());
110 transformModule =
111 OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, config: context));
112 if (!transformModule) {
113 // Failed to parse the transform module.
114 // Don't need to emit an error here as the parsing should have already done
115 // that.
116 return failure();
117 }
118 return mlir::verify(op: *transformModule);
119}
120
121ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
122 return context->getOrLoadDialect<transform::TransformDialect>()
123 ->getLibraryModule();
124}
125
126transform::TransformOpInterface
127transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
128 StringRef entryPoint) {
129 SmallVector<Operation *, 2> l{root};
130 if (module)
131 l.push_back(Elt: module);
132 for (Operation *op : l) {
133 transform::TransformOpInterface transform = nullptr;
134 op->walk<WalkOrder::PreOrder>(
135 callback: [&](transform::NamedSequenceOp namedSequenceOp) {
136 if (namedSequenceOp.getSymName() == entryPoint) {
137 transform = cast<transform::TransformOpInterface>(
138 namedSequenceOp.getOperation());
139 return WalkResult::interrupt();
140 }
141 return WalkResult::advance();
142 });
143 if (transform)
144 return transform;
145 }
146 auto diag = root->emitError()
147 << "could not find a nested named sequence with name: "
148 << entryPoint;
149 return nullptr;
150}
151
152LogicalResult transform::detail::assembleTransformLibraryFromPaths(
153 MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
154 OwningOpRef<ModuleOp> &transformModule) {
155 // Assemble list of library files.
156 SmallVector<std::string> libraryFileNames;
157 if (failed(result: detail::expandPathsToMLIRFiles(paths: transformLibraryPaths, context,
158 fileNames&: libraryFileNames)))
159 return failure();
160
161 // Parse modules from library files.
162 SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
163 for (const std::string &libraryFileName : libraryFileNames) {
164 OwningOpRef<ModuleOp> parsedLibrary;
165 if (failed(result: transform::detail::parseTransformModuleFromFile(
166 context, transformFileName: libraryFileName, transformModule&: parsedLibrary)))
167 return failure();
168 parsedLibraries.push_back(Elt: std::move(parsedLibrary));
169 }
170
171 // Merge parsed libraries into one module.
172 auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
173 OwningOpRef<ModuleOp> mergedParsedLibraries =
174 ModuleOp::create(loc, "__transform");
175 {
176 mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
177 UnitAttr::get(context));
178 // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
179 for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
180 if (failed(transform::detail::mergeSymbolsInto(
181 mergedParsedLibraries.get(), std::move(parsedLibrary))))
182 return parsedLibrary->emitError()
183 << "failed to merge symbols into shared library module";
184 }
185 }
186
187 transformModule = std::move(mergedParsedLibraries);
188 return success();
189}
190
191LogicalResult transform::applyTransformNamedSequence(
192 Operation *payload, Operation *transformRoot, ModuleOp transformModule,
193 const TransformOptions &options) {
194 RaggedArray<MappedValue> bindings;
195 bindings.push_back(elements: ArrayRef<Operation *>{payload});
196 return applyTransformNamedSequence(bindings,
197 cast<TransformOpInterface>(transformRoot),
198 transformModule, options);
199}
200
201LogicalResult transform::applyTransformNamedSequence(
202 RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
203 ModuleOp transformModule, const TransformOptions &options) {
204 if (bindings.empty()) {
205 return transformRoot.emitError()
206 << "expected at least one binding for the root";
207 }
208 if (bindings.at(pos: 0).size() != 1) {
209 return transformRoot.emitError()
210 << "expected one payload to be bound to the first argument, got "
211 << bindings.at(pos: 0).size();
212 }
213 auto *payloadRoot = bindings.at(pos: 0).front().dyn_cast<Operation *>();
214 if (!payloadRoot) {
215 return transformRoot->emitError() << "expected the object bound to the "
216 "first argument to be an operation";
217 }
218
219 bindings.removeFront();
220
221 // `transformModule` may not be modified.
222 if (transformModule && !transformModule->isAncestor(transformRoot)) {
223 OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
224 if (failed(detail::mergeSymbolsInto(
225 target: SymbolTable::getNearestSymbolTable(from: transformRoot),
226 other: std::move(clonedTransformModule)))) {
227 return payloadRoot->emitError() << "failed to merge symbols";
228 }
229 }
230
231 LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
232 LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
233
234 return applyTransforms(payloadRoot, transformRoot, bindings, options,
235 /*enforceToplevelTransformOp=*/false);
236}
237

source code of mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp