//===- TestUseListOrders.cpp - Test use-list order preservation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <numeric>
#include <random>
18
19using namespace mlir;
20
21namespace {
22/// This pass tests that:
23/// 1) we can shuffle use-lists correctly;
24/// 2) use-list orders are preserved after a roundtrip to bytecode.
25class TestPreserveUseListOrders
26 : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
27public:
28 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)
29
30 TestPreserveUseListOrders() = default;
31 TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
32 : PassWrapper(pass) {}
33 StringRef getArgument() const final { return "test-verify-uselistorder"; }
34 StringRef getDescription() const final {
35 return "Verify that roundtripping the IR to bytecode preserves the order "
36 "of the uselists";
37 }
38 Option<unsigned> rngSeed{*this, "rng-seed",
39 llvm::cl::desc("Specify an input random seed"),
40 llvm::cl::init(Val: 1)};
41
42 LogicalResult initialize(MLIRContext *context) override {
43 rng.seed(s: static_cast<unsigned>(rngSeed));
44 return success();
45 }
46
47 void runOnOperation() override {
48 // Clone the module so that we can plug in this pass to any other
49 // independently.
50 OwningOpRef<ModuleOp> cloneModule = getOperation().clone();
51
52 // 1. Compute the op numbering of the module.
53 computeOpNumbering(topLevelOp: *cloneModule);
54
55 // 2. Loop over all the values and shuffle the uses. While doing so, check
56 // that each shuffle is correct.
57 if (failed(shuffleUses(topLevelOp: *cloneModule)))
58 return signalPassFailure();
59
60 // 3. Do a bytecode roundtrip to version 3, which supports use-list order
61 // preservation.
62 auto roundtripModuleOr = doRoundtripToBytecode(module: *cloneModule, version: 3);
63 // If the bytecode roundtrip failed, try to roundtrip the original module
64 // to version 2, which does not support use-list. If this also fails, the
65 // original module had an issue unrelated to uselists.
66 if (failed(roundtripModuleOr)) {
67 auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
68 if (failed(testModuleOr))
69 return;
70
71 return signalPassFailure();
72 }
73
74 // 4. Recompute the op numbering on the new module. The numbering should be
75 // the same as (1), but on the new operation pointers.
76 computeOpNumbering(topLevelOp: roundtripModuleOr->get());
77
78 // 5. Loop over all the values and verify that the use-list is consistent
79 // with the post-shuffle order of step (2).
80 if (failed(verifyUseListOrders(topLevelOp: roundtripModuleOr->get())))
81 return signalPassFailure();
82 }
83
84private:
85 FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
86 uint32_t version) {
87 std::string str;
88 llvm::raw_string_ostream m(str);
89 BytecodeWriterConfig config;
90 config.setDesiredBytecodeVersion(version);
91 if (failed(result: writeBytecodeToFile(op: module, os&: m, config)))
92 return failure();
93
94 ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
95 auto newModuleOp = parseSourceString(sourceStr: StringRef(str), config: parseConfig);
96 if (!newModuleOp.get())
97 return failure();
98 return newModuleOp;
99 }
100
101 /// Compute an ordered numbering for all the operations in the IR.
102 void computeOpNumbering(Operation *topLevelOp) {
103 uint32_t operationID = 0;
104 opNumbering.clear();
105 topLevelOp->walk<mlir::WalkOrder::PreOrder>(
106 callback: [&](Operation *op) { opNumbering.try_emplace(Key: op, Args: operationID++); });
107 }
108
109 template <typename ValueT>
110 SmallVector<uint64_t> getUseIDs(ValueT val) {
111 return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) {
112 return bytecode::getUseID(use, opNumbering.at(Val: use.getOwner()));
113 }));
114 }
115
116 LogicalResult shuffleUses(Operation *topLevelOp) {
117 uint32_t valueID = 0;
118 /// Permute randomly the use-list of each value. It is guaranteed that at
119 /// least one pair of the use list is permuted.
120 auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
121 for (auto val : range) {
122 if (val.use_empty() || val.hasOneUse())
123 continue;
124
125 /// Get a valid index permutation for the uses of value.
126 SmallVector<unsigned> permutation = getRandomPermutation(value: val);
127
128 /// Store original order and verify that the shuffle was applied
129 /// correctly.
130 auto useIDs = getUseIDs(val);
131
132 /// Apply shuffle to the uselist.
133 val.shuffleUseList(indices: permutation);
134
135 /// Get the new order and verify the shuffle happened correctly.
136 auto permutedIDs = getUseIDs(val);
137 if (permutedIDs.size() != useIDs.size())
138 return failure();
139 for (size_t idx = 0; idx < permutation.size(); idx++)
140 if (useIDs[idx] != permutedIDs[permutation[idx]])
141 return failure();
142
143 referenceUseListOrder.try_emplace(
144 Key: valueID++, Args: llvm::map_range(C: val.getUses(), F: [&](auto &use) {
145 return bytecode::getUseID(use, opNumbering.at(Val: use.getOwner()));
146 }));
147 }
148 return success();
149 };
150
151 return walkOverValues(topLevelOp, callable: doShuffleForRange);
152 }
153
154 LogicalResult verifyUseListOrders(Operation *topLevelOp) {
155 uint32_t valueID = 0;
156 /// Check that the use-list for the value range matches the one stored in
157 /// the reference.
158 auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
159 for (auto val : range) {
160 if (val.use_empty() || val.hasOneUse())
161 continue;
162 auto referenceOrder = referenceUseListOrder.at(Val: valueID++);
163 for (auto [use, referenceID] :
164 llvm::zip(t: val.getUses(), u&: referenceOrder)) {
165 uint64_t uniqueID =
166 bytecode::getUseID(val&: use, ownerID: opNumbering.at(Val: use.getOwner()));
167 if (uniqueID != referenceID) {
168 use.getOwner()->emitError()
169 << "found use-list order mismatch for value: " << val;
170 return failure();
171 }
172 }
173 }
174 return success();
175 };
176
177 return walkOverValues(topLevelOp, callable: doValidationForRange);
178 }
179
180 /// Walk over blocks and operations and execute a callable over the ranges of
181 /// operands/results respectively.
182 template <typename FuncT>
183 LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
184 auto blockWalk = topLevelOp->walk([&](Block *block) {
185 if (failed(callable(block->getArguments())))
186 return WalkResult::interrupt();
187 return WalkResult::advance();
188 });
189
190 if (blockWalk.wasInterrupted())
191 return failure();
192
193 auto resultsWalk = topLevelOp->walk([&](Operation *op) {
194 if (failed(callable(op->getResults())))
195 return WalkResult::interrupt();
196 return WalkResult::advance();
197 });
198
199 return failure(resultsWalk.wasInterrupted());
200 }
201
202 /// Creates a random permutation of the uselist order chain of the provided
203 /// value.
204 SmallVector<unsigned> getRandomPermutation(Value value) {
205 size_t numUses = std::distance(first: value.use_begin(), last: value.use_end());
206 SmallVector<unsigned> permutation(numUses);
207 unsigned zero = 0;
208 std::iota(first: permutation.begin(), last: permutation.end(), value: zero);
209 std::shuffle(first: permutation.begin(), last: permutation.end(), g&: rng);
210 return permutation;
211 }
212
213 /// Map each value to its use-list order encoded with unique use IDs.
214 DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;
215
216 /// Map each operation to its global ID.
217 DenseMap<Operation *, uint32_t> opNumbering;
218
219 std::default_random_engine rng;
220};
221} // namespace
222
223namespace mlir {
224void registerTestPreserveUseListOrders() {
225 PassRegistration<TestPreserveUseListOrders>();
226}
227} // namespace mlir
228

source code of mlir/test/lib/IR/TestUseListOrders.cpp