1 | use crate::{ |
2 | errors::{IntoArrayError, NotEqualError}, |
3 | InOut, |
4 | }; |
5 | use core::{marker::PhantomData, slice}; |
6 | use generic_array::{ArrayLength, GenericArray}; |
7 | |
/// Custom slice type which references one immutable (input) slice and one
/// mutable (output) slice of equal length.
///
/// Input and output slices are either the same or do not overlap.
pub struct InOutBuf<'inp, 'out, T> {
    /// Pointer to the first element of the input slice.
    pub(crate) in_ptr: *const T,
    /// Pointer to the first element of the output slice.
    pub(crate) out_ptr: *mut T,
    /// Number of elements in each of the two slices.
    pub(crate) len: usize,
    /// Ties the buffer to a shared borrow for `'inp` and an exclusive borrow
    /// for `'out`, matching the references it was built from.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
17 | |
18 | impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> { |
19 | #[inline (always)] |
20 | fn from(buf: &'a mut [T]) -> Self { |
21 | let p: *mut T = buf.as_mut_ptr(); |
22 | Self { |
23 | in_ptr: p, |
24 | out_ptr: p, |
25 | len: buf.len(), |
26 | _pd: PhantomData, |
27 | } |
28 | } |
29 | } |
30 | |
31 | impl<'a, T> InOutBuf<'a, 'a, T> { |
32 | /// Create `InOutBuf` from a single mutable reference. |
33 | #[inline (always)] |
34 | pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> { |
35 | let p: *mut T = val as *mut T; |
36 | Self { |
37 | in_ptr: p, |
38 | out_ptr: p, |
39 | len: 1, |
40 | _pd: PhantomData, |
41 | } |
42 | } |
43 | } |
44 | |
45 | impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> { |
46 | type Item = InOut<'inp, 'out, T>; |
47 | type IntoIter = InOutBufIter<'inp, 'out, T>; |
48 | |
49 | #[inline (always)] |
50 | fn into_iter(self) -> Self::IntoIter { |
51 | InOutBufIter { buf: self, pos: 0 } |
52 | } |
53 | } |
54 | |
55 | impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> { |
56 | /// Create `InOutBuf` from a pair of immutable and mutable references. |
57 | #[inline (always)] |
58 | pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self { |
59 | Self { |
60 | in_ptr: in_val as *const T, |
61 | out_ptr: out_val as *mut T, |
62 | len: 1, |
63 | _pd: PhantomData, |
64 | } |
65 | } |
66 | |
67 | /// Create `InOutBuf` from immutable and mutable slices. |
68 | /// |
69 | /// Returns an error if length of slices is not equal to each other. |
70 | #[inline (always)] |
71 | pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> { |
72 | if in_buf.len() != out_buf.len() { |
73 | Err(NotEqualError) |
74 | } else { |
75 | Ok(Self { |
76 | in_ptr: in_buf.as_ptr(), |
77 | out_ptr: out_buf.as_mut_ptr(), |
78 | len: in_buf.len(), |
79 | _pd: Default::default(), |
80 | }) |
81 | } |
82 | } |
83 | |
84 | /// Get length of the inner buffers. |
85 | #[inline (always)] |
86 | pub fn len(&self) -> usize { |
87 | self.len |
88 | } |
89 | |
90 | /// Returns `true` if the buffer has a length of 0. |
91 | #[inline (always)] |
92 | pub fn is_empty(&self) -> bool { |
93 | self.len == 0 |
94 | } |
95 | |
96 | /// Returns `InOut` for given position. |
97 | /// |
98 | /// # Panics |
99 | /// If `pos` greater or equal to buffer length. |
100 | #[inline (always)] |
101 | pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> { |
102 | assert!(pos < self.len); |
103 | unsafe { |
104 | InOut { |
105 | in_ptr: self.in_ptr.add(pos), |
106 | out_ptr: self.out_ptr.add(pos), |
107 | _pd: PhantomData, |
108 | } |
109 | } |
110 | } |
111 | |
112 | /// Get input slice. |
113 | #[inline (always)] |
114 | pub fn get_in<'a>(&'a self) -> &'a [T] { |
115 | unsafe { slice::from_raw_parts(self.in_ptr, self.len) } |
116 | } |
117 | |
118 | /// Get output slice. |
119 | #[inline (always)] |
120 | pub fn get_out<'a>(&'a mut self) -> &'a mut [T] { |
121 | unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) } |
122 | } |
123 | |
124 | /// Consume self and return output slice with lifetime `'a`. |
125 | #[inline (always)] |
126 | pub fn into_out(self) -> &'out mut [T] { |
127 | unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) } |
128 | } |
129 | |
130 | /// Get raw input and output pointers. |
131 | #[inline (always)] |
132 | pub fn into_raw(self) -> (*const T, *mut T) { |
133 | (self.in_ptr, self.out_ptr) |
134 | } |
135 | |
136 | /// Reborrow `self`. |
137 | #[inline (always)] |
138 | pub fn reborrow<'a>(&'a mut self) -> InOutBuf<'a, 'a, T> { |
139 | Self { |
140 | in_ptr: self.in_ptr, |
141 | out_ptr: self.out_ptr, |
142 | len: self.len, |
143 | _pd: PhantomData, |
144 | } |
145 | } |
146 | |
147 | /// Create [`InOutBuf`] from raw input and output pointers. |
148 | /// |
149 | /// # Safety |
150 | /// Behavior is undefined if any of the following conditions are violated: |
151 | /// - `in_ptr` must point to a properly initialized value of type `T` and |
152 | /// must be valid for reads for `len * mem::size_of::<T>()` many bytes. |
153 | /// - `out_ptr` must point to a properly initialized value of type `T` and |
154 | /// must be valid for both reads and writes for `len * mem::size_of::<T>()` |
155 | /// many bytes. |
156 | /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping. |
157 | /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by |
158 | /// them must not be accessed through any other pointer (not derived from |
159 | /// the return value) for the duration of lifetime 'a. Both read and write |
160 | /// accesses are forbidden. |
161 | /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by |
162 | /// `out_ptr` must not be accessed through any other pointer (not derived from |
163 | /// the return value) for the duration of lifetime 'a. Both read and write |
164 | /// accesses are forbidden. The memory referenced by `in_ptr` must not be |
165 | /// mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`. |
166 | /// - The total size `len * mem::size_of::<T>()` must be no larger than `isize::MAX`. |
167 | #[inline (always)] |
168 | pub unsafe fn from_raw( |
169 | in_ptr: *const T, |
170 | out_ptr: *mut T, |
171 | len: usize, |
172 | ) -> InOutBuf<'inp, 'out, T> { |
173 | Self { |
174 | in_ptr, |
175 | out_ptr, |
176 | len, |
177 | _pd: PhantomData, |
178 | } |
179 | } |
180 | |
181 | /// Divides one buffer into two at `mid` index. |
182 | /// |
183 | /// The first will contain all indices from `[0, mid)` (excluding |
184 | /// the index `mid` itself) and the second will contain all |
185 | /// indices from `[mid, len)` (excluding the index `len` itself). |
186 | /// |
187 | /// # Panics |
188 | /// |
189 | /// Panics if `mid > len`. |
190 | #[inline (always)] |
191 | pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) { |
192 | assert!(mid <= self.len); |
193 | let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) }; |
194 | ( |
195 | InOutBuf { |
196 | in_ptr: self.in_ptr, |
197 | out_ptr: self.out_ptr, |
198 | len: mid, |
199 | _pd: PhantomData, |
200 | }, |
201 | InOutBuf { |
202 | in_ptr: tail_in_ptr, |
203 | out_ptr: tail_out_ptr, |
204 | len: self.len() - mid, |
205 | _pd: PhantomData, |
206 | }, |
207 | ) |
208 | } |
209 | |
210 | /// Partition buffer into 2 parts: buffer of arrays and tail. |
211 | #[inline (always)] |
212 | pub fn into_chunks<N: ArrayLength<T>>( |
213 | self, |
214 | ) -> ( |
215 | InOutBuf<'inp, 'out, GenericArray<T, N>>, |
216 | InOutBuf<'inp, 'out, T>, |
217 | ) { |
218 | let chunks = self.len() / N::USIZE; |
219 | let tail_pos = N::USIZE * chunks; |
220 | let tail_len = self.len() - tail_pos; |
221 | unsafe { |
222 | let chunks = InOutBuf { |
223 | in_ptr: self.in_ptr as *const GenericArray<T, N>, |
224 | out_ptr: self.out_ptr as *mut GenericArray<T, N>, |
225 | len: chunks, |
226 | _pd: PhantomData, |
227 | }; |
228 | let tail = InOutBuf { |
229 | in_ptr: self.in_ptr.add(tail_pos), |
230 | out_ptr: self.out_ptr.add(tail_pos), |
231 | len: tail_len, |
232 | _pd: PhantomData, |
233 | }; |
234 | (chunks, tail) |
235 | } |
236 | } |
237 | } |
238 | |
239 | impl<'inp, 'out> InOutBuf<'inp, 'out, u8> { |
240 | /// XORs `data` with values behind the input slice and write |
241 | /// result to the output slice. |
242 | /// |
243 | /// # Panics |
244 | /// If `data` length is not equal to the buffer length. |
245 | #[inline (always)] |
246 | #[allow (clippy::needless_range_loop)] |
247 | pub fn xor_in2out(&mut self, data: &[u8]) { |
248 | assert_eq!(self.len(), data.len()); |
249 | unsafe { |
250 | for i: usize in 0..data.len() { |
251 | let in_ptr: *const u8 = self.in_ptr.add(count:i); |
252 | let out_ptr: *mut u8 = self.out_ptr.add(count:i); |
253 | *out_ptr = *in_ptr ^ data[i]; |
254 | } |
255 | } |
256 | } |
257 | } |
258 | |
259 | impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, GenericArray<T, N>>> for InOutBuf<'inp, 'out, T> |
260 | where |
261 | N: ArrayLength<T>, |
262 | { |
263 | type Error = IntoArrayError; |
264 | |
265 | #[inline (always)] |
266 | fn try_into(self) -> Result<InOut<'inp, 'out, GenericArray<T, N>>, Self::Error> { |
267 | if self.len() == N::USIZE { |
268 | Ok(InOut { |
269 | in_ptr: self.in_ptr as *const _, |
270 | out_ptr: self.out_ptr as *mut _, |
271 | _pd: PhantomData, |
272 | }) |
273 | } else { |
274 | Err(IntoArrayError) |
275 | } |
276 | } |
277 | } |
278 | |
279 | /// Iterator over [`InOutBuf`]. |
280 | pub struct InOutBufIter<'inp, 'out, T> { |
281 | buf: InOutBuf<'inp, 'out, T>, |
282 | pos: usize, |
283 | } |
284 | |
285 | impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> { |
286 | type Item = InOut<'inp, 'out, T>; |
287 | |
288 | #[inline (always)] |
289 | fn next(&mut self) -> Option<Self::Item> { |
290 | if self.buf.len() == self.pos { |
291 | return None; |
292 | } |
293 | let res: InOut<'_, '_, T> = unsafe { |
294 | InOut { |
295 | in_ptr: self.buf.in_ptr.add(self.pos), |
296 | out_ptr: self.buf.out_ptr.add(self.pos), |
297 | _pd: PhantomData, |
298 | } |
299 | }; |
300 | self.pos += 1; |
301 | Some(res) |
302 | } |
303 | } |
304 | |