use std::alloc::{GlobalAlloc, Layout, System};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};

use bytes::{Buf, Bytes};
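
// A bookkeeping global allocator: every allocation's pointer and size are
// recorded in a fixed-size table, and every deallocation is checked against
// that record. Running the `Bytes` tests below under this allocator catches
// any code path that frees a buffer with a `Layout` that doesn't match the
// one it was allocated with.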
#[global_allocator]
static LEDGER: Ledger = Ledger::new();

const LEDGER_LENGTH: usize = 2048;

struct Ledger {
    alloc_table: [(AtomicPtr<u8>, AtomicUsize); LEDGER_LENGTH],
}

impl Ledger {
    const fn new() -> Self {
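        // Atomics are not `Copy`, so the array can't be built from a plain
        // value expression; a `const` item makes `[ELEM; LEDGER_LENGTH]`
        // legal here.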
        const ELEM: (AtomicPtr<u8>, AtomicUsize) =
            (AtomicPtr::new(null_mut()), AtomicUsize::new(0));
        let alloc_table = [ELEM; LEDGER_LENGTH];

        Self { alloc_table }
    }

    /// Iterate over the table until an open (null) entry is found, then
    /// record the allocation's pointer and size in it.
    fn insert(&self, ptr: *mut u8, size: usize) {
        for (entry_ptr, entry_size) in self.alloc_table.iter() {
            // SeqCst is good enough here; performance doesn't matter in a
            // test allocator, we just want to be correct.
            if entry_ptr
                .compare_exchange(null_mut(), ptr, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
            {
                entry_size.store(size, Ordering::SeqCst);
                return;
            }
        }

        panic!("Ledger ran out of space.");
    }

    /// Find and invalidate the entry for `ptr`, returning the size recorded
    /// when it was allocated.
    fn remove(&self, ptr: *mut u8) -> usize {
        for (entry_ptr, entry_size) in self.alloc_table.iter() {
            // Swap in a sentinel pointer that no real allocation will ever
            // match, so the entry can never be claimed or removed twice.
            //
            // Entries are never reclaimed; LEDGER_LENGTH is sized generously
            // to compensate.
            if entry_ptr
                .compare_exchange(
                    ptr,
                    invalid_ptr(usize::MAX),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                return entry_size.load(Ordering::SeqCst);
            }
        }

        panic!("Couldn't find a matching entry for {:x?}", ptr);
    }
}

unsafe impl GlobalAlloc for Ledger {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let size = layout.size();
        let ptr = System.alloc(layout);
        self.insert(ptr, size);
        ptr
    }
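
    // The actual check: the size recorded by `alloc` must match the size of
    // the `Layout` passed to `dealloc`.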
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let orig_size = self.remove(ptr);

        if orig_size != layout.size() {
            panic!(
                "bad dealloc: alloc size was {}, dealloc size is {}",
                orig_size,
                layout.size()
            );
        } else {
            System.dealloc(ptr, layout);
        }
    }
}
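
// Each test below shrinks or shifts the view that `Bytes` has of its buffer
// and then drops it; the ledger panics if the buffer is freed with a size
// that doesn't match the original allocation.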
#[test]
fn test_bytes_advance() {
    let mut bytes = Bytes::from(vec![10, 20, 30]);
    bytes.advance(1);
    drop(bytes);
}

#[test]
fn test_bytes_truncate() {
    let mut bytes = Bytes::from(vec![10, 20, 30]);
    bytes.truncate(2);
    drop(bytes);
}

#[test]
fn test_bytes_truncate_and_advance() {
    let mut bytes = Bytes::from(vec![10, 20, 30]);
    bytes.truncate(2);
    bytes.advance(1);
    drop(bytes);
}

/// Returns a dangling pointer with the given address. This is used to store
/// integer data in pointer fields.
#[inline]
fn invalid_ptr<T>(addr: usize) -> *mut T {
    let ptr = std::ptr::null_mut::<u8>().wrapping_add(addr);
    debug_assert_eq!(ptr as usize, addr);
    ptr.cast::<T>()
}
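
// `Vec::from(Bytes)` should hand the original buffer back where possible and
// copy otherwise; the cases below cover the internal representations named
// in the comments, and the ledger checks that every resulting free matches
// its allocation.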
#[test]
fn test_bytes_into_vec() {
    let vec = vec![33u8; 1024];

    // Test case where kind == KIND_VEC
    let b1 = Bytes::from(vec.clone());
    assert_eq!(Vec::from(b1), vec);

    // Test case where kind == KIND_ARC, ref_cnt == 1
    let b1 = Bytes::from(vec.clone());
    drop(b1.clone());
    assert_eq!(Vec::from(b1), vec);

    // Test case where kind == KIND_ARC, ref_cnt == 2
    let b1 = Bytes::from(vec.clone());
    let b2 = b1.clone();
    assert_eq!(Vec::from(b1), vec);

    // Test case where vtable = SHARED_VTABLE, kind == KIND_ARC, ref_cnt == 1
    assert_eq!(Vec::from(b2), vec);

    // Test case where offset != 0
    let mut b1 = Bytes::from(vec.clone());
    let b2 = b1.split_off(20);

    assert_eq!(Vec::from(b2), vec[20..]);
    assert_eq!(Vec::from(b1), vec[..20]);
}