1 | //===-- hashtable_fuzz.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// |
9 | /// Fuzzing test for llvm-libc hashtable implementations. |
10 | /// |
11 | //===----------------------------------------------------------------------===// |
12 | #include "hdr/types/ENTRY.h" |
13 | #include "src/__support/CPP/bit.h" |
14 | #include "src/__support/CPP/string_view.h" |
15 | #include "src/__support/HashTable/table.h" |
16 | #include "src/__support/macros/config.h" |
17 | |
18 | namespace LIBC_NAMESPACE_DECL { |
19 | |
// A fuzzing payload starts with a fixed-size header (host byte order):
// - uint16_t: initial capacity for table A
// - uint64_t: seed for table A
// - uint16_t: initial capacity for table B
// - uint64_t: seed for table B
// followed by a sequence of actions, each selected by its first byte:
// - CrossCheck: a single byte whose value is 4 (mod 5)
// - Find: a byte whose value is 3 (mod 5), followed by a null-terminated string
// - Insert: a byte whose value is 0, 1, or 2 (mod 5), followed by a
//   null-terminated string
// Size in bytes of the fixed payload header: a uint16_t capacity followed by
// a uint64_t seed, once for each of the two tables.
static constexpr size_t INITIAL_HEADER_SIZE =
    2 * (sizeof(uint16_t) + sizeof(uint64_t));
32 | extern "C" size_t LLVMFuzzerMutate(uint8_t *data, size_t size, size_t max_size); |
33 | extern "C" size_t LLVMFuzzerCustomMutator(uint8_t *data, size_t size, |
34 | size_t max_size, unsigned int seed) { |
35 | size = LLVMFuzzerMutate(data, size, max_size); |
36 | // not enough to read the initial capacities and seeds |
37 | if (size < INITIAL_HEADER_SIZE) |
38 | return 0; |
39 | |
40 | // skip the initial capacities and seeds |
41 | size_t i = INITIAL_HEADER_SIZE; |
42 | while (i < size) { |
43 | // cross check |
44 | if (static_cast<uint8_t>(data[i]) % 5 == 4) { |
45 | // skip the cross check byte |
46 | ++i; |
47 | continue; |
48 | } |
49 | |
50 | // find or insert |
51 | // check if there is enough space for the action byte and the |
52 | // null-terminator |
53 | if (i + 2 >= max_size) |
54 | return i; |
55 | // skip the action byte |
56 | ++i; |
57 | // skip the null-terminated string |
58 | while (i < max_size && data[i] != 0) |
59 | ++i; |
60 | // in the case the string is not null-terminated, null-terminate it |
61 | if (i == max_size && data[i - 1] != 0) { |
62 | data[i - 1] = 0; |
63 | return max_size; |
64 | } |
65 | |
66 | // move to the next action |
67 | ++i; |
68 | } |
69 | // return the new size |
70 | return i; |
71 | } |
72 | |
73 | // a tagged union |
74 | struct Action { |
75 | enum class Tag { Find, Insert, CrossCheck } tag; |
76 | cpp::string_view key; |
77 | }; |
78 | |
79 | static struct { |
80 | size_t remaining; |
81 | const char *buffer; |
82 | |
83 | template <typename T> T next() { |
84 | static_assert(cpp::is_integral<T>::value, "T must be an integral type" ); |
85 | |
86 | char data[sizeof(T)]; |
87 | |
88 | for (size_t i = 0; i < sizeof(T); i++) |
89 | data[i] = buffer[i]; |
90 | buffer += sizeof(T); |
91 | remaining -= sizeof(T); |
92 | return cpp::bit_cast<T>(data); |
93 | } |
94 | |
95 | cpp::string_view next_string() { |
96 | cpp::string_view result(buffer); |
97 | buffer = result.end() + 1; |
98 | remaining -= result.size() + 1; |
99 | return result; |
100 | } |
101 | |
102 | Action next_action() { |
103 | uint8_t byte = next<uint8_t>(); |
104 | switch (byte % 5) { |
105 | case 4: |
106 | return {Action::Tag::CrossCheck, {}}; |
107 | case 3: |
108 | return {Action::Tag::Find, next_string()}; |
109 | default: |
110 | return {Action::Tag::Insert, next_string()}; |
111 | } |
112 | } |
113 | } global_status; |
114 | |
115 | class HashTable { |
116 | internal::HashTable *table; |
117 | |
118 | public: |
119 | HashTable(uint64_t size, uint64_t seed) |
120 | : table(internal::HashTable::allocate(size, seed)) {} |
121 | HashTable(internal::HashTable *table) : table(table) {} |
122 | ~HashTable() { internal::HashTable::deallocate(table); } |
123 | HashTable(HashTable &&other) : table(other.table) { other.table = nullptr; } |
124 | bool is_valid() const { return table != nullptr; } |
125 | ENTRY *find(const char *key) { return table->find(key); } |
126 | ENTRY *insert(const ENTRY &entry) { |
127 | return internal::HashTable::insert(this->table, entry); |
128 | } |
129 | using iterator = internal::HashTable::iterator; |
130 | iterator begin() const { return table->begin(); } |
131 | iterator end() const { return table->end(); } |
132 | }; |
133 | |
134 | HashTable next_hashtable() { |
135 | size_t size = global_status.next<uint16_t>(); |
136 | uint64_t seed = global_status.next<uint64_t>(); |
137 | return HashTable(size, seed); |
138 | } |
139 | |
140 | extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { |
141 | global_status.buffer = reinterpret_cast<const char *>(data); |
142 | global_status.remaining = size; |
143 | if (global_status.remaining < INITIAL_HEADER_SIZE) |
144 | return 0; |
145 | |
146 | HashTable table_a = next_hashtable(); |
147 | HashTable table_b = next_hashtable(); |
148 | for (;;) { |
149 | if (global_status.remaining == 0) |
150 | break; |
151 | Action action = global_status.next_action(); |
152 | switch (action.tag) { |
153 | case Action::Tag::Find: { |
154 | if (static_cast<bool>(table_a.find(action.key.data())) != |
155 | static_cast<bool>(table_b.find(action.key.data()))) |
156 | __builtin_trap(); |
157 | break; |
158 | } |
159 | case Action::Tag::Insert: { |
160 | char *ptr = const_cast<char *>(action.key.data()); |
161 | ENTRY *a = table_a.insert(ENTRY{ptr, ptr}); |
162 | ENTRY *b = table_b.insert(ENTRY{ptr, ptr}); |
163 | if (a->data != b->data) |
164 | __builtin_trap(); |
165 | break; |
166 | } |
167 | case Action::Tag::CrossCheck: { |
168 | for (ENTRY a : table_a) |
169 | if (const ENTRY *b = table_b.find(a.key); a.data != b->data) |
170 | __builtin_trap(); |
171 | |
172 | for (ENTRY b : table_b) |
173 | if (const ENTRY *a = table_a.find(b.key); a->data != b.data) |
174 | __builtin_trap(); |
175 | |
176 | break; |
177 | } |
178 | } |
179 | } |
180 | return 0; |
181 | } |
182 | |
183 | } // namespace LIBC_NAMESPACE_DECL |
184 | |