//===- Test pass verifying use-list order preservation in bytecode --------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <random>
#include <utility>
18 | |
19 | using namespace mlir; |
20 | |
21 | namespace { |
22 | /// This pass tests that: |
23 | /// 1) we can shuffle use-lists correctly; |
24 | /// 2) use-list orders are preserved after a roundtrip to bytecode. |
25 | class TestPreserveUseListOrders |
26 | : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> { |
27 | public: |
28 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders) |
29 | |
30 | TestPreserveUseListOrders() = default; |
31 | TestPreserveUseListOrders(const TestPreserveUseListOrders &pass) |
32 | : PassWrapper(pass) {} |
33 | StringRef getArgument() const final { return "test-verify-uselistorder" ; } |
34 | StringRef getDescription() const final { |
35 | return "Verify that roundtripping the IR to bytecode preserves the order " |
36 | "of the uselists" ; |
37 | } |
38 | Option<unsigned> rngSeed{*this, "rng-seed" , |
39 | llvm::cl::desc("Specify an input random seed" ), |
40 | llvm::cl::init(Val: 1)}; |
41 | |
42 | LogicalResult initialize(MLIRContext *context) override { |
43 | rng.seed(s: static_cast<unsigned>(rngSeed)); |
44 | return success(); |
45 | } |
46 | |
47 | void runOnOperation() override { |
48 | // Clone the module so that we can plug in this pass to any other |
49 | // independently. |
50 | OwningOpRef<ModuleOp> cloneModule = getOperation().clone(); |
51 | |
52 | // 1. Compute the op numbering of the module. |
53 | computeOpNumbering(topLevelOp: *cloneModule); |
54 | |
55 | // 2. Loop over all the values and shuffle the uses. While doing so, check |
56 | // that each shuffle is correct. |
57 | if (failed(shuffleUses(topLevelOp: *cloneModule))) |
58 | return signalPassFailure(); |
59 | |
60 | // 3. Do a bytecode roundtrip to version 3, which supports use-list order |
61 | // preservation. |
62 | auto roundtripModuleOr = doRoundtripToBytecode(module: *cloneModule, version: 3); |
63 | // If the bytecode roundtrip failed, try to roundtrip the original module |
64 | // to version 2, which does not support use-list. If this also fails, the |
65 | // original module had an issue unrelated to uselists. |
66 | if (failed(roundtripModuleOr)) { |
67 | auto testModuleOr = doRoundtripToBytecode(getOperation(), 2); |
68 | if (failed(testModuleOr)) |
69 | return; |
70 | |
71 | return signalPassFailure(); |
72 | } |
73 | |
74 | // 4. Recompute the op numbering on the new module. The numbering should be |
75 | // the same as (1), but on the new operation pointers. |
76 | computeOpNumbering(topLevelOp: roundtripModuleOr->get()); |
77 | |
78 | // 5. Loop over all the values and verify that the use-list is consistent |
79 | // with the post-shuffle order of step (2). |
80 | if (failed(verifyUseListOrders(topLevelOp: roundtripModuleOr->get()))) |
81 | return signalPassFailure(); |
82 | } |
83 | |
84 | private: |
85 | FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module, |
86 | uint32_t version) { |
87 | std::string str; |
88 | llvm::raw_string_ostream m(str); |
89 | BytecodeWriterConfig config; |
90 | config.setDesiredBytecodeVersion(version); |
91 | if (failed(result: writeBytecodeToFile(op: module, os&: m, config))) |
92 | return failure(); |
93 | |
94 | ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true); |
95 | auto newModuleOp = parseSourceString(sourceStr: StringRef(str), config: parseConfig); |
96 | if (!newModuleOp.get()) |
97 | return failure(); |
98 | return newModuleOp; |
99 | } |
100 | |
101 | /// Compute an ordered numbering for all the operations in the IR. |
102 | void computeOpNumbering(Operation *topLevelOp) { |
103 | uint32_t operationID = 0; |
104 | opNumbering.clear(); |
105 | topLevelOp->walk<mlir::WalkOrder::PreOrder>( |
106 | callback: [&](Operation *op) { opNumbering.try_emplace(Key: op, Args: operationID++); }); |
107 | } |
108 | |
109 | template <typename ValueT> |
110 | SmallVector<uint64_t> getUseIDs(ValueT val) { |
111 | return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) { |
112 | return bytecode::getUseID(use, opNumbering.at(Val: use.getOwner())); |
113 | })); |
114 | } |
115 | |
116 | LogicalResult shuffleUses(Operation *topLevelOp) { |
117 | uint32_t valueID = 0; |
118 | /// Permute randomly the use-list of each value. It is guaranteed that at |
119 | /// least one pair of the use list is permuted. |
120 | auto doShuffleForRange = [&](ValueRange range) -> LogicalResult { |
121 | for (auto val : range) { |
122 | if (val.use_empty() || val.hasOneUse()) |
123 | continue; |
124 | |
125 | /// Get a valid index permutation for the uses of value. |
126 | SmallVector<unsigned> permutation = getRandomPermutation(value: val); |
127 | |
128 | /// Store original order and verify that the shuffle was applied |
129 | /// correctly. |
130 | auto useIDs = getUseIDs(val); |
131 | |
132 | /// Apply shuffle to the uselist. |
133 | val.shuffleUseList(indices: permutation); |
134 | |
135 | /// Get the new order and verify the shuffle happened correctly. |
136 | auto permutedIDs = getUseIDs(val); |
137 | if (permutedIDs.size() != useIDs.size()) |
138 | return failure(); |
139 | for (size_t idx = 0; idx < permutation.size(); idx++) |
140 | if (useIDs[idx] != permutedIDs[permutation[idx]]) |
141 | return failure(); |
142 | |
143 | referenceUseListOrder.try_emplace( |
144 | Key: valueID++, Args: llvm::map_range(C: val.getUses(), F: [&](auto &use) { |
145 | return bytecode::getUseID(use, opNumbering.at(Val: use.getOwner())); |
146 | })); |
147 | } |
148 | return success(); |
149 | }; |
150 | |
151 | return walkOverValues(topLevelOp, callable: doShuffleForRange); |
152 | } |
153 | |
154 | LogicalResult verifyUseListOrders(Operation *topLevelOp) { |
155 | uint32_t valueID = 0; |
156 | /// Check that the use-list for the value range matches the one stored in |
157 | /// the reference. |
158 | auto doValidationForRange = [&](ValueRange range) -> LogicalResult { |
159 | for (auto val : range) { |
160 | if (val.use_empty() || val.hasOneUse()) |
161 | continue; |
162 | auto referenceOrder = referenceUseListOrder.at(Val: valueID++); |
163 | for (auto [use, referenceID] : |
164 | llvm::zip(t: val.getUses(), u&: referenceOrder)) { |
165 | uint64_t uniqueID = |
166 | bytecode::getUseID(val&: use, ownerID: opNumbering.at(Val: use.getOwner())); |
167 | if (uniqueID != referenceID) { |
168 | use.getOwner()->emitError() |
169 | << "found use-list order mismatch for value: " << val; |
170 | return failure(); |
171 | } |
172 | } |
173 | } |
174 | return success(); |
175 | }; |
176 | |
177 | return walkOverValues(topLevelOp, callable: doValidationForRange); |
178 | } |
179 | |
180 | /// Walk over blocks and operations and execute a callable over the ranges of |
181 | /// operands/results respectively. |
182 | template <typename FuncT> |
183 | LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) { |
184 | auto blockWalk = topLevelOp->walk([&](Block *block) { |
185 | if (failed(callable(block->getArguments()))) |
186 | return WalkResult::interrupt(); |
187 | return WalkResult::advance(); |
188 | }); |
189 | |
190 | if (blockWalk.wasInterrupted()) |
191 | return failure(); |
192 | |
193 | auto resultsWalk = topLevelOp->walk([&](Operation *op) { |
194 | if (failed(callable(op->getResults()))) |
195 | return WalkResult::interrupt(); |
196 | return WalkResult::advance(); |
197 | }); |
198 | |
199 | return failure(resultsWalk.wasInterrupted()); |
200 | } |
201 | |
202 | /// Creates a random permutation of the uselist order chain of the provided |
203 | /// value. |
204 | SmallVector<unsigned> getRandomPermutation(Value value) { |
205 | size_t numUses = std::distance(first: value.use_begin(), last: value.use_end()); |
206 | SmallVector<unsigned> permutation(numUses); |
207 | unsigned zero = 0; |
208 | std::iota(first: permutation.begin(), last: permutation.end(), value: zero); |
209 | std::shuffle(first: permutation.begin(), last: permutation.end(), g&: rng); |
210 | return permutation; |
211 | } |
212 | |
213 | /// Map each value to its use-list order encoded with unique use IDs. |
214 | DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder; |
215 | |
216 | /// Map each operation to its global ID. |
217 | DenseMap<Operation *, uint32_t> opNumbering; |
218 | |
219 | std::default_random_engine rng; |
220 | }; |
221 | } // namespace |
222 | |
223 | namespace mlir { |
224 | void registerTestPreserveUseListOrders() { |
225 | PassRegistration<TestPreserveUseListOrders>(); |
226 | } |
227 | } // namespace mlir |
228 | |