1use crate::{
2 errors::{IntoArrayError, NotEqualError},
3 InOut,
4};
5use core::{marker::PhantomData, slice};
6use generic_array::{ArrayLength, GenericArray};
7
/// Paired view over one immutable (input) slice and one mutable (output)
/// slice of the same length. The two slices either refer to exactly the
/// same memory or do not overlap at all.
pub struct InOutBuf<'inp, 'out, T> {
    /// Pointer to the first element of the input slice.
    pub(crate) in_ptr: *const T,
    /// Pointer to the first element of the output slice.
    pub(crate) out_ptr: *mut T,
    /// Number of `T` elements in each of the two slices.
    pub(crate) len: usize,
    /// Ties the raw pointers to the `'inp` shared borrow and the `'out`
    /// exclusive borrow they were created from.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
17
18impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
19 #[inline(always)]
20 fn from(buf: &'a mut [T]) -> Self {
21 let p: *mut T = buf.as_mut_ptr();
22 Self {
23 in_ptr: p,
24 out_ptr: p,
25 len: buf.len(),
26 _pd: PhantomData,
27 }
28 }
29}
30
31impl<'a, T> InOutBuf<'a, 'a, T> {
32 /// Create `InOutBuf` from a single mutable reference.
33 #[inline(always)]
34 pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
35 let p: *mut T = val as *mut T;
36 Self {
37 in_ptr: p,
38 out_ptr: p,
39 len: 1,
40 _pd: PhantomData,
41 }
42 }
43}
44
45impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
46 type Item = InOut<'inp, 'out, T>;
47 type IntoIter = InOutBufIter<'inp, 'out, T>;
48
49 #[inline(always)]
50 fn into_iter(self) -> Self::IntoIter {
51 InOutBufIter { buf: self, pos: 0 }
52 }
53}
54
55impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
56 /// Create `InOutBuf` from a pair of immutable and mutable references.
57 #[inline(always)]
58 pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
59 Self {
60 in_ptr: in_val as *const T,
61 out_ptr: out_val as *mut T,
62 len: 1,
63 _pd: PhantomData,
64 }
65 }
66
67 /// Create `InOutBuf` from immutable and mutable slices.
68 ///
69 /// Returns an error if length of slices is not equal to each other.
70 #[inline(always)]
71 pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
72 if in_buf.len() != out_buf.len() {
73 Err(NotEqualError)
74 } else {
75 Ok(Self {
76 in_ptr: in_buf.as_ptr(),
77 out_ptr: out_buf.as_mut_ptr(),
78 len: in_buf.len(),
79 _pd: Default::default(),
80 })
81 }
82 }
83
84 /// Get length of the inner buffers.
85 #[inline(always)]
86 pub fn len(&self) -> usize {
87 self.len
88 }
89
90 /// Returns `true` if the buffer has a length of 0.
91 #[inline(always)]
92 pub fn is_empty(&self) -> bool {
93 self.len == 0
94 }
95
96 /// Returns `InOut` for given position.
97 ///
98 /// # Panics
99 /// If `pos` greater or equal to buffer length.
100 #[inline(always)]
101 pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
102 assert!(pos < self.len);
103 unsafe {
104 InOut {
105 in_ptr: self.in_ptr.add(pos),
106 out_ptr: self.out_ptr.add(pos),
107 _pd: PhantomData,
108 }
109 }
110 }
111
112 /// Get input slice.
113 #[inline(always)]
114 pub fn get_in<'a>(&'a self) -> &'a [T] {
115 unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
116 }
117
118 /// Get output slice.
119 #[inline(always)]
120 pub fn get_out<'a>(&'a mut self) -> &'a mut [T] {
121 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
122 }
123
124 /// Consume self and return output slice with lifetime `'a`.
125 #[inline(always)]
126 pub fn into_out(self) -> &'out mut [T] {
127 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
128 }
129
130 /// Get raw input and output pointers.
131 #[inline(always)]
132 pub fn into_raw(self) -> (*const T, *mut T) {
133 (self.in_ptr, self.out_ptr)
134 }
135
136 /// Reborrow `self`.
137 #[inline(always)]
138 pub fn reborrow<'a>(&'a mut self) -> InOutBuf<'a, 'a, T> {
139 Self {
140 in_ptr: self.in_ptr,
141 out_ptr: self.out_ptr,
142 len: self.len,
143 _pd: PhantomData,
144 }
145 }
146
147 /// Create [`InOutBuf`] from raw input and output pointers.
148 ///
149 /// # Safety
150 /// Behavior is undefined if any of the following conditions are violated:
151 /// - `in_ptr` must point to a properly initialized value of type `T` and
152 /// must be valid for reads for `len * mem::size_of::<T>()` many bytes.
153 /// - `out_ptr` must point to a properly initialized value of type `T` and
154 /// must be valid for both reads and writes for `len * mem::size_of::<T>()`
155 /// many bytes.
156 /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
157 /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
158 /// them must not be accessed through any other pointer (not derived from
159 /// the return value) for the duration of lifetime 'a. Both read and write
160 /// accesses are forbidden.
161 /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
162 /// `out_ptr` must not be accessed through any other pointer (not derived from
163 /// the return value) for the duration of lifetime 'a. Both read and write
164 /// accesses are forbidden. The memory referenced by `in_ptr` must not be
165 /// mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
166 /// - The total size `len * mem::size_of::<T>()` must be no larger than `isize::MAX`.
167 #[inline(always)]
168 pub unsafe fn from_raw(
169 in_ptr: *const T,
170 out_ptr: *mut T,
171 len: usize,
172 ) -> InOutBuf<'inp, 'out, T> {
173 Self {
174 in_ptr,
175 out_ptr,
176 len,
177 _pd: PhantomData,
178 }
179 }
180
181 /// Divides one buffer into two at `mid` index.
182 ///
183 /// The first will contain all indices from `[0, mid)` (excluding
184 /// the index `mid` itself) and the second will contain all
185 /// indices from `[mid, len)` (excluding the index `len` itself).
186 ///
187 /// # Panics
188 ///
189 /// Panics if `mid > len`.
190 #[inline(always)]
191 pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
192 assert!(mid <= self.len);
193 let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
194 (
195 InOutBuf {
196 in_ptr: self.in_ptr,
197 out_ptr: self.out_ptr,
198 len: mid,
199 _pd: PhantomData,
200 },
201 InOutBuf {
202 in_ptr: tail_in_ptr,
203 out_ptr: tail_out_ptr,
204 len: self.len() - mid,
205 _pd: PhantomData,
206 },
207 )
208 }
209
210 /// Partition buffer into 2 parts: buffer of arrays and tail.
211 #[inline(always)]
212 pub fn into_chunks<N: ArrayLength<T>>(
213 self,
214 ) -> (
215 InOutBuf<'inp, 'out, GenericArray<T, N>>,
216 InOutBuf<'inp, 'out, T>,
217 ) {
218 let chunks = self.len() / N::USIZE;
219 let tail_pos = N::USIZE * chunks;
220 let tail_len = self.len() - tail_pos;
221 unsafe {
222 let chunks = InOutBuf {
223 in_ptr: self.in_ptr as *const GenericArray<T, N>,
224 out_ptr: self.out_ptr as *mut GenericArray<T, N>,
225 len: chunks,
226 _pd: PhantomData,
227 };
228 let tail = InOutBuf {
229 in_ptr: self.in_ptr.add(tail_pos),
230 out_ptr: self.out_ptr.add(tail_pos),
231 len: tail_len,
232 _pd: PhantomData,
233 };
234 (chunks, tail)
235 }
236 }
237}
238
239impl<'inp, 'out> InOutBuf<'inp, 'out, u8> {
240 /// XORs `data` with values behind the input slice and write
241 /// result to the output slice.
242 ///
243 /// # Panics
244 /// If `data` length is not equal to the buffer length.
245 #[inline(always)]
246 #[allow(clippy::needless_range_loop)]
247 pub fn xor_in2out(&mut self, data: &[u8]) {
248 assert_eq!(self.len(), data.len());
249 unsafe {
250 for i: usize in 0..data.len() {
251 let in_ptr: *const u8 = self.in_ptr.add(count:i);
252 let out_ptr: *mut u8 = self.out_ptr.add(count:i);
253 *out_ptr = *in_ptr ^ data[i];
254 }
255 }
256 }
257}
258
259impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, GenericArray<T, N>>> for InOutBuf<'inp, 'out, T>
260where
261 N: ArrayLength<T>,
262{
263 type Error = IntoArrayError;
264
265 #[inline(always)]
266 fn try_into(self) -> Result<InOut<'inp, 'out, GenericArray<T, N>>, Self::Error> {
267 if self.len() == N::USIZE {
268 Ok(InOut {
269 in_ptr: self.in_ptr as *const _,
270 out_ptr: self.out_ptr as *mut _,
271 _pd: PhantomData,
272 })
273 } else {
274 Err(IntoArrayError)
275 }
276 }
277}
278
279/// Iterator over [`InOutBuf`].
280pub struct InOutBufIter<'inp, 'out, T> {
281 buf: InOutBuf<'inp, 'out, T>,
282 pos: usize,
283}
284
285impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
286 type Item = InOut<'inp, 'out, T>;
287
288 #[inline(always)]
289 fn next(&mut self) -> Option<Self::Item> {
290 if self.buf.len() == self.pos {
291 return None;
292 }
293 let res: InOut<'_, '_, T> = unsafe {
294 InOut {
295 in_ptr: self.buf.in_ptr.add(self.pos),
296 out_ptr: self.buf.out_ptr.add(self.pos),
297 _pd: PhantomData,
298 }
299 };
300 self.pos += 1;
301 Some(res)
302 }
303}
304