1//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
11#include "mlir/Dialect/Transform/Transforms/Passes.h"
12#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace transform {
18#define GEN_PASS_DEF_INTERPRETERPASS
19#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
20} // namespace transform
21} // namespace mlir
22
23/// Returns the payload operation to be used as payload root:
24/// - the operation nested under `passRoot` that has the given tag attribute,
25/// must be unique;
26/// - the `passRoot` itself if the tag is empty.
27static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
28 // Fast return.
29 if (tag.empty())
30 return passRoot;
31
32 // Walk to do a lookup.
33 Operation *target = nullptr;
34 auto tagAttrName = StringAttr::get(
35 passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
36 WalkResult walkResult = passRoot->walk([&](Operation *op) {
37 auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
38 if (!attr || attr.getValue() != tag)
39 return WalkResult::advance();
40
41 if (!target) {
42 target = op;
43 return WalkResult::advance();
44 }
45
46 InFlightDiagnostic diag = op->emitError()
47 << "repeated operation with the target tag '"
48 << tag << "'";
49 diag.attachNote(target->getLoc()) << "previously seen operation";
50 return WalkResult::interrupt();
51 });
52
53 if (!target) {
54 passRoot->emitError()
55 << "could not find the operation with transform.target_tag=\"" << tag
56 << "\" attribute";
57 return nullptr;
58 }
59
60 return walkResult.wasInterrupted() ? nullptr : target;
61}
62
63namespace {
64class InterpreterPass
65 : public transform::impl::InterpreterPassBase<InterpreterPass> {
66 // Parses the pass arguments to bind trailing arguments of the entry point.
67 std::optional<RaggedArray<transform::MappedValue>>
68 parseArguments(Operation *payloadRoot) {
69 MLIRContext *context = payloadRoot->getContext();
70
71 SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
72 trailingBindings.resize(debugBindTrailingArgs.size());
73
74 // Construct lists of op names to match.
75 SmallVector<std::optional<OperationName>> debugBindNames;
76 debugBindNames.reserve(N: debugBindTrailingArgs.size());
77 for (auto &&[position, nameString] :
78 llvm::enumerate(debugBindTrailingArgs)) {
79 StringRef name = nameString;
80
81 // Parse the integer literals.
82 if (name.starts_with("#")) {
83 debugBindNames.push_back(std::nullopt);
84 StringRef lhs = "";
85 StringRef rhs = name.drop_front();
86 do {
87 std::tie(lhs, rhs) = rhs.split(';');
88 int64_t value;
89 if (lhs.getAsInteger(10, value)) {
90 emitError(UnknownLoc::get(context))
91 << "couldn't parse integer pass argument " << name;
92 return std::nullopt;
93 }
94 trailingBindings[position].push_back(
95 Builder(context).getI64IntegerAttr(value));
96 } while (!rhs.empty());
97 } else if (name.starts_with("^")) {
98 debugBindNames.emplace_back(OperationName(name.drop_front(), context));
99 } else {
100 debugBindNames.emplace_back(OperationName(name, context));
101 }
102 }
103
104 // Collect operations or results for extra bindings.
105 payloadRoot->walk(callback: [&](Operation *payload) {
106 for (auto &&[position, name] : llvm::enumerate(First&: debugBindNames)) {
107 if (!name || payload->getName() != *name)
108 continue;
109
110 if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
111 .starts_with("^")) {
112 llvm::append_range(C&: trailingBindings[position], R: payload->getResults());
113 } else {
114 trailingBindings[position].push_back(Elt: payload);
115 }
116 }
117 });
118
119 RaggedArray<transform::MappedValue> bindings;
120 bindings.push_back(elements: ArrayRef<Operation *>{payloadRoot});
121 for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
122 bindings.push_back(elements: std::move(trailing));
123 return bindings;
124 }
125
126public:
127 using Base::Base;
128
129 void runOnOperation() override {
130 MLIRContext *context = &getContext();
131 ModuleOp transformModule =
132 transform::detail::getPreloadedTransformModule(context);
133 Operation *payloadRoot =
134 findPayloadRoot(getOperation(), debugPayloadRootTag);
135 if (!payloadRoot)
136 return signalPassFailure();
137
138 Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
139 getOperation(), transformModule, entryPoint);
140 if (!transformEntryPoint)
141 return signalPassFailure();
142
143 std::optional<RaggedArray<transform::MappedValue>> bindings =
144 parseArguments(payloadRoot);
145 if (!bindings)
146 return signalPassFailure();
147 if (failed(transform::applyTransformNamedSequence(
148 *bindings,
149 cast<transform::TransformOpInterface>(transformEntryPoint),
150 transformModule,
151 options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152 return signalPassFailure();
153 }
154 }
155
156private:
157 /// Transform interpreter options.
158 transform::TransformOptions options;
159};
160} // namespace
161

source code of mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp