| 1 | use crate::InOutBuf; |
| 2 | use core::{marker::PhantomData, ptr}; |
| 3 | use generic_array::{ArrayLength, GenericArray}; |
| 4 | |
/// Custom pointer type which contains one immutable (input) and one mutable
/// (output) pointer, which are either equal or non-overlapping.
///
/// The `'inp` lifetime belongs to the shared borrow behind the input pointer
/// and `'out` to the exclusive borrow behind the output pointer.
pub struct InOut<'inp, 'out, T> {
    /// Pointer the input value is read from.
    pub(crate) in_ptr: *const T,
    /// Pointer the output value is written to.
    pub(crate) out_ptr: *mut T,
    /// Zero-sized marker tying the raw pointers to a `&'inp T` and a
    /// `&'out mut T` borrow.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
| 12 | |
| 13 | impl<'inp, 'out, T> InOut<'inp, 'out, T> { |
| 14 | /// Reborrow `self`. |
| 15 | #[inline (always)] |
| 16 | pub fn reborrow<'a>(&'a mut self) -> InOut<'a, 'a, T> { |
| 17 | Self { |
| 18 | in_ptr: self.in_ptr, |
| 19 | out_ptr: self.out_ptr, |
| 20 | _pd: PhantomData, |
| 21 | } |
| 22 | } |
| 23 | |
| 24 | /// Get immutable reference to the input value. |
| 25 | #[inline (always)] |
| 26 | pub fn get_in<'a>(&'a self) -> &'a T { |
| 27 | unsafe { &*self.in_ptr } |
| 28 | } |
| 29 | |
| 30 | /// Get mutable reference to the output value. |
| 31 | #[inline (always)] |
| 32 | pub fn get_out<'a>(&'a mut self) -> &'a mut T { |
| 33 | unsafe { &mut *self.out_ptr } |
| 34 | } |
| 35 | |
| 36 | /// Convert `self` to a pair of raw input and output pointers. |
| 37 | #[inline (always)] |
| 38 | pub fn into_raw(self) -> (*const T, *mut T) { |
| 39 | (self.in_ptr, self.out_ptr) |
| 40 | } |
| 41 | |
| 42 | /// Create `InOut` from raw input and output pointers. |
| 43 | /// |
| 44 | /// # Safety |
| 45 | /// Behavior is undefined if any of the following conditions are violated: |
| 46 | /// - `in_ptr` must point to a properly initialized value of type `T` and |
| 47 | /// must be valid for reads. |
| 48 | /// - `out_ptr` must point to a properly initialized value of type `T` and |
| 49 | /// must be valid for both reads and writes. |
| 50 | /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping. |
| 51 | /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by |
| 52 | /// them must not be accessed through any other pointer (not derived from |
| 53 | /// the return value) for the duration of lifetime 'a. Both read and write |
| 54 | /// accesses are forbidden. |
| 55 | /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by |
| 56 | /// `out_ptr` must not be accessed through any other pointer (not derived from |
| 57 | /// the return value) for the duration of lifetime `'a`. Both read and write |
| 58 | /// accesses are forbidden. The memory referenced by `in_ptr` must not be |
| 59 | /// mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`. |
| 60 | #[inline (always)] |
| 61 | pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> { |
| 62 | Self { |
| 63 | in_ptr, |
| 64 | out_ptr, |
| 65 | _pd: PhantomData, |
| 66 | } |
| 67 | } |
| 68 | } |
| 69 | |
| 70 | impl<'inp, 'out, T: Clone> InOut<'inp, 'out, T> { |
| 71 | /// Clone input value and return it. |
| 72 | #[inline (always)] |
| 73 | pub fn clone_in(&self) -> T { |
| 74 | unsafe { (&*self.in_ptr).clone() } |
| 75 | } |
| 76 | } |
| 77 | |
| 78 | impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> { |
| 79 | #[inline (always)] |
| 80 | fn from(val: &'a mut T) -> Self { |
| 81 | let p: *mut T = val as *mut T; |
| 82 | Self { |
| 83 | in_ptr: p, |
| 84 | out_ptr: p, |
| 85 | _pd: PhantomData, |
| 86 | } |
| 87 | } |
| 88 | } |
| 89 | |
| 90 | impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> { |
| 91 | #[inline (always)] |
| 92 | fn from((in_val: &'inp T, out_val: &'out mut T): (&'inp T, &'out mut T)) -> Self { |
| 93 | Self { |
| 94 | in_ptr: in_val as *const T, |
| 95 | out_ptr: out_val as *mut T, |
| 96 | _pd: Default::default(), |
| 97 | } |
| 98 | } |
| 99 | } |
| 100 | |
| 101 | impl<'inp, 'out, T, N: ArrayLength<T>> InOut<'inp, 'out, GenericArray<T, N>> { |
| 102 | /// Returns `InOut` for the given position. |
| 103 | /// |
| 104 | /// # Panics |
| 105 | /// If `pos` greater or equal to array length. |
| 106 | #[inline (always)] |
| 107 | pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> { |
| 108 | assert!(pos < N::USIZE); |
| 109 | unsafe { |
| 110 | InOut { |
| 111 | in_ptr: (self.in_ptr as *const T).add(pos), |
| 112 | out_ptr: (self.out_ptr as *mut T).add(pos), |
| 113 | _pd: PhantomData, |
| 114 | } |
| 115 | } |
| 116 | } |
| 117 | |
| 118 | /// Convert `InOut` array to `InOutBuf`. |
| 119 | #[inline (always)] |
| 120 | pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> { |
| 121 | InOutBuf { |
| 122 | in_ptr: self.in_ptr as *const T, |
| 123 | out_ptr: self.out_ptr as *mut T, |
| 124 | len: N::USIZE, |
| 125 | _pd: PhantomData, |
| 126 | } |
| 127 | } |
| 128 | } |
| 129 | |
| 130 | impl<'inp, 'out, N: ArrayLength<u8>> InOut<'inp, 'out, GenericArray<u8, N>> { |
| 131 | /// XOR `data` with values behind the input slice and write |
| 132 | /// result to the output slice. |
| 133 | /// |
| 134 | /// # Panics |
| 135 | /// If `data` length is not equal to the buffer length. |
| 136 | #[inline (always)] |
| 137 | #[allow (clippy::needless_range_loop)] |
| 138 | pub fn xor_in2out(&mut self, data: &GenericArray<u8, N>) { |
| 139 | unsafe { |
| 140 | let input: GenericArray = ptr::read(self.in_ptr); |
| 141 | let mut temp: GenericArray = GenericArray::<u8, N>::default(); |
| 142 | for i: usize in 0..N::USIZE { |
| 143 | temp[i] = input[i] ^ data[i]; |
| 144 | } |
| 145 | ptr::write(self.out_ptr, src:temp); |
| 146 | } |
| 147 | } |
| 148 | } |
| 149 | |
| 150 | impl<'inp, 'out, N, M> InOut<'inp, 'out, GenericArray<GenericArray<u8, N>, M>> |
| 151 | where |
| 152 | N: ArrayLength<u8>, |
| 153 | M: ArrayLength<GenericArray<u8, N>>, |
| 154 | { |
| 155 | /// XOR `data` with values behind the input slice and write |
| 156 | /// result to the output slice. |
| 157 | /// |
| 158 | /// # Panics |
| 159 | /// If `data` length is not equal to the buffer length. |
| 160 | #[inline (always)] |
| 161 | #[allow (clippy::needless_range_loop)] |
| 162 | pub fn xor_in2out(&mut self, data: &GenericArray<GenericArray<u8, N>, M>) { |
| 163 | unsafe { |
| 164 | let input: GenericArray, …> = ptr::read(self.in_ptr); |
| 165 | let mut temp: GenericArray, …> = GenericArray::<GenericArray<u8, N>, M>::default(); |
| 166 | for i: usize in 0..M::USIZE { |
| 167 | for j: usize in 0..N::USIZE { |
| 168 | temp[i][j] = input[i][j] ^ data[i][j]; |
| 169 | } |
| 170 | } |
| 171 | ptr::write(self.out_ptr, src:temp); |
| 172 | } |
| 173 | } |
| 174 | } |
| 175 | |