| 1 | //===- TransformInterpreterUtils.cpp --------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Lightweight transform dialect interpreter utilities. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" |
| 14 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 15 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
| 16 | #include "mlir/Dialect/Transform/IR/Utils.h" |
| 17 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 18 | #include "mlir/IR/BuiltinOps.h" |
| 19 | #include "mlir/IR/Verifier.h" |
| 20 | #include "mlir/IR/Visitors.h" |
| 21 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 22 | #include "mlir/Parser/Parser.h" |
| 23 | #include "mlir/Support/FileUtilities.h" |
| 24 | #include "llvm/ADT/StringRef.h" |
| 25 | #include "llvm/Support/Casting.h" |
| 26 | #include "llvm/Support/Debug.h" |
| 27 | #include "llvm/Support/FileSystem.h" |
| 28 | #include "llvm/Support/SourceMgr.h" |
| 29 | #include "llvm/Support/raw_ostream.h" |
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | #define DEBUG_TYPE "transform-dialect-interpreter-utils" |
| 34 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
| 35 | |
| 36 | /// Expands the given list of `paths` to a list of `.mlir` files. |
| 37 | /// |
| 38 | /// Each entry in `paths` may either be a regular file, in which case it ends up |
| 39 | /// in the result list, or a directory, in which case all (regular) `.mlir` |
| 40 | /// files in that directory are added. Any other file types lead to a failure. |
| 41 | LogicalResult transform::detail::expandPathsToMLIRFiles( |
| 42 | ArrayRef<std::string> paths, MLIRContext *context, |
| 43 | SmallVectorImpl<std::string> &fileNames) { |
| 44 | for (const std::string &path : paths) { |
| 45 | auto loc = FileLineColLoc::get(context, fileName: path, line: 0, column: 0); |
| 46 | |
| 47 | if (llvm::sys::fs::is_regular_file(Path: path)) { |
| 48 | LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n" ); |
| 49 | fileNames.push_back(Elt: path); |
| 50 | continue; |
| 51 | } |
| 52 | |
| 53 | if (!llvm::sys::fs::is_directory(Path: path)) { |
| 54 | return emitError(loc) |
| 55 | << "'" << path << "' is neither a file nor a directory" ; |
| 56 | } |
| 57 | |
| 58 | LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n" ); |
| 59 | |
| 60 | std::error_code ec; |
| 61 | for (llvm::sys::fs::directory_iterator it(path, ec), itEnd; |
| 62 | it != itEnd && !ec; it.increment(ec)) { |
| 63 | const std::string &fileName = it->path(); |
| 64 | |
| 65 | if (it->type() != llvm::sys::fs::file_type::regular_file && |
| 66 | it->type() != llvm::sys::fs::file_type::symlink_file) { |
| 67 | LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName |
| 68 | << "'\n" ); |
| 69 | continue; |
| 70 | } |
| 71 | |
| 72 | if (!StringRef(fileName).ends_with(Suffix: ".mlir" )) { |
| 73 | LLVM_DEBUG(DBGS() << " Skipping '" << fileName |
| 74 | << "' because it does not end with '.mlir'\n" ); |
| 75 | continue; |
| 76 | } |
| 77 | |
| 78 | LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n" ); |
| 79 | fileNames.push_back(Elt: fileName); |
| 80 | } |
| 81 | |
| 82 | if (ec) |
| 83 | return emitError(loc) << "error while opening files in '" << path |
| 84 | << "': " << ec.message(); |
| 85 | } |
| 86 | |
| 87 | return success(); |
| 88 | } |
| 89 | |
| 90 | LogicalResult transform::detail::parseTransformModuleFromFile( |
| 91 | MLIRContext *context, llvm::StringRef transformFileName, |
| 92 | OwningOpRef<ModuleOp> &transformModule) { |
| 93 | if (transformFileName.empty()) { |
| 94 | LLVM_DEBUG( |
| 95 | DBGS() << "no transform file name specified, assuming the transform " |
| 96 | "module is embedded in the IR next to the top-level\n" ); |
| 97 | return success(); |
| 98 | } |
| 99 | // Parse transformFileName content into a ModuleOp. |
| 100 | std::string errorMessage; |
| 101 | auto memoryBuffer = mlir::openInputFile(inputFilename: transformFileName, errorMessage: &errorMessage); |
| 102 | if (!memoryBuffer) { |
| 103 | return emitError(FileLineColLoc::get( |
| 104 | StringAttr::get(context, transformFileName), 0, 0)) |
| 105 | << "failed to open transform file: " << errorMessage; |
| 106 | } |
| 107 | // Tell sourceMgr about this buffer, the parser will pick it up. |
| 108 | llvm::SourceMgr sourceMgr; |
| 109 | sourceMgr.AddNewSourceBuffer(F: std::move(memoryBuffer), IncludeLoc: llvm::SMLoc()); |
| 110 | transformModule = |
| 111 | OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, config: context)); |
| 112 | if (!transformModule) { |
| 113 | // Failed to parse the transform module. |
| 114 | // Don't need to emit an error here as the parsing should have already done |
| 115 | // that. |
| 116 | return failure(); |
| 117 | } |
| 118 | return mlir::verify(op: *transformModule); |
| 119 | } |
| 120 | |
| 121 | ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) { |
| 122 | return context->getOrLoadDialect<transform::TransformDialect>() |
| 123 | ->getLibraryModule(); |
| 124 | } |
| 125 | |
| 126 | transform::TransformOpInterface |
| 127 | transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module, |
| 128 | StringRef entryPoint) { |
| 129 | SmallVector<Operation *, 2> l{root}; |
| 130 | if (module) |
| 131 | l.push_back(Elt: module); |
| 132 | for (Operation *op : l) { |
| 133 | transform::TransformOpInterface transform = nullptr; |
| 134 | op->walk<WalkOrder::PreOrder>( |
| 135 | callback: [&](transform::NamedSequenceOp namedSequenceOp) { |
| 136 | if (namedSequenceOp.getSymName() == entryPoint) { |
| 137 | transform = cast<transform::TransformOpInterface>( |
| 138 | namedSequenceOp.getOperation()); |
| 139 | return WalkResult::interrupt(); |
| 140 | } |
| 141 | return WalkResult::advance(); |
| 142 | }); |
| 143 | if (transform) |
| 144 | return transform; |
| 145 | } |
| 146 | auto diag = root->emitError() |
| 147 | << "could not find a nested named sequence with name: " |
| 148 | << entryPoint; |
| 149 | return nullptr; |
| 150 | } |
| 151 | |
| 152 | LogicalResult transform::detail::assembleTransformLibraryFromPaths( |
| 153 | MLIRContext *context, ArrayRef<std::string> transformLibraryPaths, |
| 154 | OwningOpRef<ModuleOp> &transformModule) { |
| 155 | // Assemble list of library files. |
| 156 | SmallVector<std::string> libraryFileNames; |
| 157 | if (failed(Result: detail::expandPathsToMLIRFiles(paths: transformLibraryPaths, context, |
| 158 | fileNames&: libraryFileNames))) |
| 159 | return failure(); |
| 160 | |
| 161 | // Parse modules from library files. |
| 162 | SmallVector<OwningOpRef<ModuleOp>> parsedLibraries; |
| 163 | for (const std::string &libraryFileName : libraryFileNames) { |
| 164 | OwningOpRef<ModuleOp> parsedLibrary; |
| 165 | if (failed(Result: transform::detail::parseTransformModuleFromFile( |
| 166 | context, transformFileName: libraryFileName, transformModule&: parsedLibrary))) |
| 167 | return failure(); |
| 168 | parsedLibraries.push_back(Elt: std::move(parsedLibrary)); |
| 169 | } |
| 170 | |
| 171 | // Merge parsed libraries into one module. |
| 172 | auto loc = FileLineColLoc::get(context, fileName: "<shared-library-module>" , line: 0, column: 0); |
| 173 | OwningOpRef<ModuleOp> mergedParsedLibraries = |
| 174 | ModuleOp::create(loc, "__transform" ); |
| 175 | { |
| 176 | mergedParsedLibraries.get()->setAttr("transform.with_named_sequence" , |
| 177 | UnitAttr::get(context)); |
| 178 | // TODO: extend `mergeSymbolsInto` to support multiple `other` modules. |
| 179 | for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) { |
| 180 | if (failed(transform::detail::mergeSymbolsInto( |
| 181 | mergedParsedLibraries.get(), std::move(parsedLibrary)))) |
| 182 | return parsedLibrary->emitError() |
| 183 | << "failed to merge symbols into shared library module" ; |
| 184 | } |
| 185 | } |
| 186 | |
| 187 | transformModule = std::move(mergedParsedLibraries); |
| 188 | return success(); |
| 189 | } |
| 190 | |
| 191 | LogicalResult transform::applyTransformNamedSequence( |
| 192 | Operation *payload, Operation *transformRoot, ModuleOp transformModule, |
| 193 | const TransformOptions &options) { |
| 194 | RaggedArray<MappedValue> bindings; |
| 195 | bindings.push_back(elements: ArrayRef<Operation *>{payload}); |
| 196 | return applyTransformNamedSequence(bindings, |
| 197 | cast<TransformOpInterface>(transformRoot), |
| 198 | transformModule, options); |
| 199 | } |
| 200 | |
| 201 | LogicalResult transform::applyTransformNamedSequence( |
| 202 | RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot, |
| 203 | ModuleOp transformModule, const TransformOptions &options) { |
| 204 | if (bindings.empty()) { |
| 205 | return transformRoot.emitError() |
| 206 | << "expected at least one binding for the root" ; |
| 207 | } |
| 208 | if (bindings.at(pos: 0).size() != 1) { |
| 209 | return transformRoot.emitError() |
| 210 | << "expected one payload to be bound to the first argument, got " |
| 211 | << bindings.at(pos: 0).size(); |
| 212 | } |
| 213 | auto *payloadRoot = dyn_cast<Operation *>(Val&: bindings.at(pos: 0).front()); |
| 214 | if (!payloadRoot) { |
| 215 | return transformRoot->emitError() << "expected the object bound to the " |
| 216 | "first argument to be an operation" ; |
| 217 | } |
| 218 | |
| 219 | bindings.removeFront(); |
| 220 | |
| 221 | // `transformModule` may not be modified. |
| 222 | if (transformModule && !transformModule->isAncestor(transformRoot)) { |
| 223 | OwningOpRef<Operation *> clonedTransformModule(transformModule->clone()); |
| 224 | if (failed(detail::mergeSymbolsInto( |
| 225 | target: SymbolTable::getNearestSymbolTable(from: transformRoot), |
| 226 | other: std::move(clonedTransformModule)))) { |
| 227 | return payloadRoot->emitError() << "failed to merge symbols" ; |
| 228 | } |
| 229 | } |
| 230 | |
| 231 | LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n" ); |
| 232 | LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n" ); |
| 233 | |
| 234 | return applyTransforms(payloadRoot, transformRoot, bindings, options, |
| 235 | /*enforceToplevelTransformOp=*/false); |
| 236 | } |
| 237 | |