//===- TestUseListOrders.cpp - Test use-list order preservation -----------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <numeric>
#include <random>
#include <utility>
| 18 | |
| 19 | using namespace mlir; |
| 20 | |
| 21 | namespace { |
| 22 | /// This pass tests that: |
| 23 | /// 1) we can shuffle use-lists correctly; |
| 24 | /// 2) use-list orders are preserved after a roundtrip to bytecode. |
| 25 | class TestPreserveUseListOrders |
| 26 | : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> { |
| 27 | public: |
| 28 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders) |
| 29 | |
| 30 | TestPreserveUseListOrders() = default; |
| 31 | TestPreserveUseListOrders(const TestPreserveUseListOrders &pass) |
| 32 | : PassWrapper(pass) {} |
| 33 | StringRef getArgument() const final { return "test-verify-uselistorder" ; } |
| 34 | StringRef getDescription() const final { |
| 35 | return "Verify that roundtripping the IR to bytecode preserves the order " |
| 36 | "of the uselists" ; |
| 37 | } |
| 38 | Option<unsigned> rngSeed{*this, "rng-seed" , |
| 39 | llvm::cl::desc("Specify an input random seed" ), |
| 40 | llvm::cl::init(Val: 1)}; |
| 41 | |
| 42 | LogicalResult initialize(MLIRContext *context) override { |
| 43 | rng.seed(s: static_cast<unsigned>(rngSeed)); |
| 44 | return success(); |
| 45 | } |
| 46 | |
| 47 | void runOnOperation() override { |
| 48 | // Clone the module so that we can plug in this pass to any other |
| 49 | // independently. |
| 50 | OwningOpRef<ModuleOp> cloneModule = getOperation().clone(); |
| 51 | |
| 52 | // 1. Compute the op numbering of the module. |
| 53 | computeOpNumbering(topLevelOp: *cloneModule); |
| 54 | |
| 55 | // 2. Loop over all the values and shuffle the uses. While doing so, check |
| 56 | // that each shuffle is correct. |
| 57 | if (failed(shuffleUses(topLevelOp: *cloneModule))) |
| 58 | return signalPassFailure(); |
| 59 | |
| 60 | // 3. Do a bytecode roundtrip to version 3, which supports use-list order |
| 61 | // preservation. |
| 62 | auto roundtripModuleOr = doRoundtripToBytecode(module: *cloneModule, version: 3); |
| 63 | // If the bytecode roundtrip failed, try to roundtrip the original module |
| 64 | // to version 2, which does not support use-list. If this also fails, the |
| 65 | // original module had an issue unrelated to uselists. |
| 66 | if (failed(roundtripModuleOr)) { |
| 67 | auto testModuleOr = doRoundtripToBytecode(getOperation(), 2); |
| 68 | if (failed(testModuleOr)) |
| 69 | return; |
| 70 | |
| 71 | return signalPassFailure(); |
| 72 | } |
| 73 | |
| 74 | // 4. Recompute the op numbering on the new module. The numbering should be |
| 75 | // the same as (1), but on the new operation pointers. |
| 76 | computeOpNumbering(topLevelOp: roundtripModuleOr->get()); |
| 77 | |
| 78 | // 5. Loop over all the values and verify that the use-list is consistent |
| 79 | // with the post-shuffle order of step (2). |
| 80 | if (failed(verifyUseListOrders(topLevelOp: roundtripModuleOr->get()))) |
| 81 | return signalPassFailure(); |
| 82 | } |
| 83 | |
| 84 | private: |
| 85 | FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module, |
| 86 | uint32_t version) { |
| 87 | std::string str; |
| 88 | llvm::raw_string_ostream m(str); |
| 89 | BytecodeWriterConfig config; |
| 90 | config.setDesiredBytecodeVersion(version); |
| 91 | if (failed(Result: writeBytecodeToFile(op: module, os&: m, config))) |
| 92 | return failure(); |
| 93 | |
| 94 | ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true); |
| 95 | auto newModuleOp = parseSourceString(sourceStr: StringRef(str), config: parseConfig); |
| 96 | if (!newModuleOp.get()) |
| 97 | return failure(); |
| 98 | return newModuleOp; |
| 99 | } |
| 100 | |
| 101 | /// Compute an ordered numbering for all the operations in the IR. |
| 102 | void computeOpNumbering(Operation *topLevelOp) { |
| 103 | uint32_t operationID = 0; |
| 104 | opNumbering.clear(); |
| 105 | topLevelOp->walk<mlir::WalkOrder::PreOrder>( |
| 106 | callback: [&](Operation *op) { opNumbering.try_emplace(Key: op, Args: operationID++); }); |
| 107 | } |
| 108 | |
| 109 | template <typename ValueT> |
| 110 | SmallVector<uint64_t> getUseIDs(ValueT val) { |
| 111 | return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) { |
| 112 | return bytecode::getUseID(use, opNumbering.at(Val: use.getOwner())); |
| 113 | })); |
| 114 | } |
| 115 | |
| 116 | LogicalResult shuffleUses(Operation *topLevelOp) { |
| 117 | uint32_t valueID = 0; |
| 118 | /// Permute randomly the use-list of each value. It is guaranteed that at |
| 119 | /// least one pair of the use list is permuted. |
| 120 | auto doShuffleForRange = [&](ValueRange range) -> LogicalResult { |
| 121 | for (auto val : range) { |
| 122 | if (val.use_empty() || val.hasOneUse()) |
| 123 | continue; |
| 124 | |
| 125 | /// Get a valid index permutation for the uses of value. |
| 126 | SmallVector<unsigned> permutation = getRandomPermutation(value: val); |
| 127 | |
| 128 | /// Store original order and verify that the shuffle was applied |
| 129 | /// correctly. |
| 130 | auto useIDs = getUseIDs(val); |
| 131 | |
| 132 | /// Apply shuffle to the uselist. |
| 133 | val.shuffleUseList(indices: permutation); |
| 134 | |
| 135 | /// Get the new order and verify the shuffle happened correctly. |
| 136 | auto permutedIDs = getUseIDs(val); |
| 137 | if (permutedIDs.size() != useIDs.size()) |
| 138 | return failure(); |
| 139 | for (size_t idx = 0; idx < permutation.size(); idx++) |
| 140 | if (useIDs[idx] != permutedIDs[permutation[idx]]) |
| 141 | return failure(); |
| 142 | |
| 143 | referenceUseListOrder.try_emplace( |
| 144 | Key: valueID++, Args: llvm::map_range(C: val.getUses(), F: [&](auto &use) { |
| 145 | return bytecode::getUseID(use, opNumbering.at(Val: use.getOwner())); |
| 146 | })); |
| 147 | } |
| 148 | return success(); |
| 149 | }; |
| 150 | |
| 151 | return walkOverValues(topLevelOp, callable: doShuffleForRange); |
| 152 | } |
| 153 | |
| 154 | LogicalResult verifyUseListOrders(Operation *topLevelOp) { |
| 155 | uint32_t valueID = 0; |
| 156 | /// Check that the use-list for the value range matches the one stored in |
| 157 | /// the reference. |
| 158 | auto doValidationForRange = [&](ValueRange range) -> LogicalResult { |
| 159 | for (auto val : range) { |
| 160 | if (val.use_empty() || val.hasOneUse()) |
| 161 | continue; |
| 162 | auto referenceOrder = referenceUseListOrder.at(Val: valueID++); |
| 163 | for (auto [use, referenceID] : |
| 164 | llvm::zip(t: val.getUses(), u&: referenceOrder)) { |
| 165 | uint64_t uniqueID = |
| 166 | bytecode::getUseID(val&: use, ownerID: opNumbering.at(Val: use.getOwner())); |
| 167 | if (uniqueID != referenceID) { |
| 168 | use.getOwner()->emitError() |
| 169 | << "found use-list order mismatch for value: " << val; |
| 170 | return failure(); |
| 171 | } |
| 172 | } |
| 173 | } |
| 174 | return success(); |
| 175 | }; |
| 176 | |
| 177 | return walkOverValues(topLevelOp, callable: doValidationForRange); |
| 178 | } |
| 179 | |
| 180 | /// Walk over blocks and operations and execute a callable over the ranges of |
| 181 | /// operands/results respectively. |
| 182 | template <typename FuncT> |
| 183 | LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) { |
| 184 | auto blockWalk = topLevelOp->walk([&](Block *block) { |
| 185 | if (failed(callable(block->getArguments()))) |
| 186 | return WalkResult::interrupt(); |
| 187 | return WalkResult::advance(); |
| 188 | }); |
| 189 | |
| 190 | if (blockWalk.wasInterrupted()) |
| 191 | return failure(); |
| 192 | |
| 193 | auto resultsWalk = topLevelOp->walk([&](Operation *op) { |
| 194 | if (failed(callable(op->getResults()))) |
| 195 | return WalkResult::interrupt(); |
| 196 | return WalkResult::advance(); |
| 197 | }); |
| 198 | |
| 199 | return failure(resultsWalk.wasInterrupted()); |
| 200 | } |
| 201 | |
| 202 | /// Creates a random permutation of the uselist order chain of the provided |
| 203 | /// value. |
| 204 | SmallVector<unsigned> getRandomPermutation(Value value) { |
| 205 | size_t numUses = std::distance(first: value.use_begin(), last: value.use_end()); |
| 206 | SmallVector<unsigned> permutation(numUses); |
| 207 | unsigned zero = 0; |
| 208 | std::iota(first: permutation.begin(), last: permutation.end(), value: zero); |
| 209 | std::shuffle(first: permutation.begin(), last: permutation.end(), g&: rng); |
| 210 | return permutation; |
| 211 | } |
| 212 | |
| 213 | /// Map each value to its use-list order encoded with unique use IDs. |
| 214 | DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder; |
| 215 | |
| 216 | /// Map each operation to its global ID. |
| 217 | DenseMap<Operation *, uint32_t> opNumbering; |
| 218 | |
| 219 | std::default_random_engine rng; |
| 220 | }; |
| 221 | } // namespace |
| 222 | |
| 223 | namespace mlir { |
| 224 | void registerTestPreserveUseListOrders() { |
| 225 | PassRegistration<TestPreserveUseListOrders>(); |
| 226 | } |
| 227 | } // namespace mlir |
| 228 | |