1 | //===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
10 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
11 | #include "mlir/Dialect/Transform/Transforms/Passes.h" |
12 | #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace mlir { |
17 | namespace transform { |
18 | #define GEN_PASS_DEF_INTERPRETERPASS |
19 | #include "mlir/Dialect/Transform/Transforms/Passes.h.inc" |
20 | } // namespace transform |
21 | } // namespace mlir |
22 | |
23 | /// Returns the payload operation to be used as payload root: |
24 | /// - the operation nested under `passRoot` that has the given tag attribute, |
25 | /// must be unique; |
26 | /// - the `passRoot` itself if the tag is empty. |
27 | static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) { |
28 | // Fast return. |
29 | if (tag.empty()) |
30 | return passRoot; |
31 | |
32 | // Walk to do a lookup. |
33 | Operation *target = nullptr; |
34 | auto tagAttrName = StringAttr::get( |
35 | passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName); |
36 | WalkResult walkResult = passRoot->walk([&](Operation *op) { |
37 | auto attr = op->getAttrOfType<StringAttr>(tagAttrName); |
38 | if (!attr || attr.getValue() != tag) |
39 | return WalkResult::advance(); |
40 | |
41 | if (!target) { |
42 | target = op; |
43 | return WalkResult::advance(); |
44 | } |
45 | |
46 | InFlightDiagnostic diag = op->emitError() |
47 | << "repeated operation with the target tag '" |
48 | << tag << "'" ; |
49 | diag.attachNote(target->getLoc()) << "previously seen operation" ; |
50 | return WalkResult::interrupt(); |
51 | }); |
52 | |
53 | if (!target) { |
54 | passRoot->emitError() |
55 | << "could not find the operation with transform.target_tag=\"" << tag |
56 | << "\" attribute" ; |
57 | return nullptr; |
58 | } |
59 | |
60 | return walkResult.wasInterrupted() ? nullptr : target; |
61 | } |
62 | |
63 | namespace { |
64 | class InterpreterPass |
65 | : public transform::impl::InterpreterPassBase<InterpreterPass> { |
66 | // Parses the pass arguments to bind trailing arguments of the entry point. |
67 | std::optional<RaggedArray<transform::MappedValue>> |
68 | parseArguments(Operation *payloadRoot) { |
69 | MLIRContext *context = payloadRoot->getContext(); |
70 | |
71 | SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings; |
72 | trailingBindings.resize(debugBindTrailingArgs.size()); |
73 | |
74 | // Construct lists of op names to match. |
75 | SmallVector<std::optional<OperationName>> debugBindNames; |
76 | debugBindNames.reserve(N: debugBindTrailingArgs.size()); |
77 | for (auto &&[position, nameString] : |
78 | llvm::enumerate(debugBindTrailingArgs)) { |
79 | StringRef name = nameString; |
80 | |
81 | // Parse the integer literals. |
82 | if (name.starts_with("#" )) { |
83 | debugBindNames.push_back(std::nullopt); |
84 | StringRef lhs = "" ; |
85 | StringRef rhs = name.drop_front(); |
86 | do { |
87 | std::tie(lhs, rhs) = rhs.split(';'); |
88 | int64_t value; |
89 | if (lhs.getAsInteger(10, value)) { |
90 | emitError(UnknownLoc::get(context)) |
91 | << "couldn't parse integer pass argument " << name; |
92 | return std::nullopt; |
93 | } |
94 | trailingBindings[position].push_back( |
95 | Builder(context).getI64IntegerAttr(value)); |
96 | } while (!rhs.empty()); |
97 | } else if (name.starts_with("^" )) { |
98 | debugBindNames.emplace_back(OperationName(name.drop_front(), context)); |
99 | } else { |
100 | debugBindNames.emplace_back(OperationName(name, context)); |
101 | } |
102 | } |
103 | |
104 | // Collect operations or results for extra bindings. |
105 | payloadRoot->walk(callback: [&](Operation *payload) { |
106 | for (auto &&[position, name] : llvm::enumerate(First&: debugBindNames)) { |
107 | if (!name || payload->getName() != *name) |
108 | continue; |
109 | |
110 | if (StringRef(*std::next(debugBindTrailingArgs.begin(), position)) |
111 | .starts_with("^" )) { |
112 | llvm::append_range(C&: trailingBindings[position], R: payload->getResults()); |
113 | } else { |
114 | trailingBindings[position].push_back(Elt: payload); |
115 | } |
116 | } |
117 | }); |
118 | |
119 | RaggedArray<transform::MappedValue> bindings; |
120 | bindings.push_back(elements: ArrayRef<Operation *>{payloadRoot}); |
121 | for (SmallVector<transform::MappedValue> &trailing : trailingBindings) |
122 | bindings.push_back(elements: std::move(trailing)); |
123 | return bindings; |
124 | } |
125 | |
126 | public: |
127 | using Base::Base; |
128 | |
129 | void runOnOperation() override { |
130 | MLIRContext *context = &getContext(); |
131 | ModuleOp transformModule = |
132 | transform::detail::getPreloadedTransformModule(context); |
133 | Operation *payloadRoot = |
134 | findPayloadRoot(getOperation(), debugPayloadRootTag); |
135 | if (!payloadRoot) |
136 | return signalPassFailure(); |
137 | |
138 | Operation *transformEntryPoint = transform::detail::findTransformEntryPoint( |
139 | getOperation(), transformModule, entryPoint); |
140 | if (!transformEntryPoint) |
141 | return signalPassFailure(); |
142 | |
143 | std::optional<RaggedArray<transform::MappedValue>> bindings = |
144 | parseArguments(payloadRoot); |
145 | if (!bindings) |
146 | return signalPassFailure(); |
147 | if (failed(transform::applyTransformNamedSequence( |
148 | *bindings, |
149 | cast<transform::TransformOpInterface>(transformEntryPoint), |
150 | transformModule, |
151 | options.enableExpensiveChecks(!disableExpensiveChecks)))) { |
152 | return signalPassFailure(); |
153 | } |
154 | } |
155 | |
156 | private: |
157 | /// Transform interpreter options. |
158 | transform::TransformOptions options; |
159 | }; |
160 | } // namespace |
161 | |