1use crate::InOutBuf;
2use core::{marker::PhantomData, ptr};
3use generic_array::{ArrayLength, GenericArray};
4
/// Custom pointer type which contains one immutable (input) and one mutable
/// (output) pointer, which are either equal or non-overlapping.
///
/// The type stores only the two raw pointers; the lifetimes `'inp` (input
/// borrow) and `'out` (output borrow) are tracked purely through
/// `PhantomData`. Because the fields are raw pointers, the struct is neither
/// `Send` nor `Sync` by default.
pub struct InOut<'inp, 'out, T> {
    /// Read-only pointer to the input value.
    pub(crate) in_ptr: *const T,
    /// Pointer to the output value; used for both reads and writes.
    pub(crate) out_ptr: *mut T,
    /// Marker tying the raw pointers to `&'inp T` and `&'out mut T` so the
    /// borrow checker (and variance) treats the struct like that pair of
    /// references.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
12
13impl<'inp, 'out, T> InOut<'inp, 'out, T> {
14 /// Reborrow `self`.
15 #[inline(always)]
16 pub fn reborrow<'a>(&'a mut self) -> InOut<'a, 'a, T> {
17 Self {
18 in_ptr: self.in_ptr,
19 out_ptr: self.out_ptr,
20 _pd: PhantomData,
21 }
22 }
23
24 /// Get immutable reference to the input value.
25 #[inline(always)]
26 pub fn get_in<'a>(&'a self) -> &'a T {
27 unsafe { &*self.in_ptr }
28 }
29
30 /// Get mutable reference to the output value.
31 #[inline(always)]
32 pub fn get_out<'a>(&'a mut self) -> &'a mut T {
33 unsafe { &mut *self.out_ptr }
34 }
35
36 /// Convert `self` to a pair of raw input and output pointers.
37 #[inline(always)]
38 pub fn into_raw(self) -> (*const T, *mut T) {
39 (self.in_ptr, self.out_ptr)
40 }
41
42 /// Create `InOut` from raw input and output pointers.
43 ///
44 /// # Safety
45 /// Behavior is undefined if any of the following conditions are violated:
46 /// - `in_ptr` must point to a properly initialized value of type `T` and
47 /// must be valid for reads.
48 /// - `out_ptr` must point to a properly initialized value of type `T` and
49 /// must be valid for both reads and writes.
50 /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
51 /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
52 /// them must not be accessed through any other pointer (not derived from
53 /// the return value) for the duration of lifetime 'a. Both read and write
54 /// accesses are forbidden.
55 /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
56 /// `out_ptr` must not be accessed through any other pointer (not derived from
57 /// the return value) for the duration of lifetime `'a`. Both read and write
58 /// accesses are forbidden. The memory referenced by `in_ptr` must not be
59 /// mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
60 #[inline(always)]
61 pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
62 Self {
63 in_ptr,
64 out_ptr,
65 _pd: PhantomData,
66 }
67 }
68}
69
70impl<'inp, 'out, T: Clone> InOut<'inp, 'out, T> {
71 /// Clone input value and return it.
72 #[inline(always)]
73 pub fn clone_in(&self) -> T {
74 unsafe { (&*self.in_ptr).clone() }
75 }
76}
77
78impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
79 #[inline(always)]
80 fn from(val: &'a mut T) -> Self {
81 let p: *mut T = val as *mut T;
82 Self {
83 in_ptr: p,
84 out_ptr: p,
85 _pd: PhantomData,
86 }
87 }
88}
89
90impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
91 #[inline(always)]
92 fn from((in_val: &T, out_val: &mut T): (&'inp T, &'out mut T)) -> Self {
93 Self {
94 in_ptr: in_val as *const T,
95 out_ptr: out_val as *mut T,
96 _pd: Default::default(),
97 }
98 }
99}
100
101impl<'inp, 'out, T, N: ArrayLength<T>> InOut<'inp, 'out, GenericArray<T, N>> {
102 /// Returns `InOut` for the given position.
103 ///
104 /// # Panics
105 /// If `pos` greater or equal to array length.
106 #[inline(always)]
107 pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
108 assert!(pos < N::USIZE);
109 unsafe {
110 InOut {
111 in_ptr: (self.in_ptr as *const T).add(pos),
112 out_ptr: (self.out_ptr as *mut T).add(pos),
113 _pd: PhantomData,
114 }
115 }
116 }
117
118 /// Convert `InOut` array to `InOutBuf`.
119 #[inline(always)]
120 pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
121 InOutBuf {
122 in_ptr: self.in_ptr as *const T,
123 out_ptr: self.out_ptr as *mut T,
124 len: N::USIZE,
125 _pd: PhantomData,
126 }
127 }
128}
129
130impl<'inp, 'out, N: ArrayLength<u8>> InOut<'inp, 'out, GenericArray<u8, N>> {
131 /// XOR `data` with values behind the input slice and write
132 /// result to the output slice.
133 ///
134 /// # Panics
135 /// If `data` length is not equal to the buffer length.
136 #[inline(always)]
137 #[allow(clippy::needless_range_loop)]
138 pub fn xor_in2out(&mut self, data: &GenericArray<u8, N>) {
139 unsafe {
140 let input: GenericArray = ptr::read(self.in_ptr);
141 let mut temp: GenericArray = GenericArray::<u8, N>::default();
142 for i: usize in 0..N::USIZE {
143 temp[i] = input[i] ^ data[i];
144 }
145 ptr::write(self.out_ptr, src:temp);
146 }
147 }
148}
149
150impl<'inp, 'out, N, M> InOut<'inp, 'out, GenericArray<GenericArray<u8, N>, M>>
151where
152 N: ArrayLength<u8>,
153 M: ArrayLength<GenericArray<u8, N>>,
154{
155 /// XOR `data` with values behind the input slice and write
156 /// result to the output slice.
157 ///
158 /// # Panics
159 /// If `data` length is not equal to the buffer length.
160 #[inline(always)]
161 #[allow(clippy::needless_range_loop)]
162 pub fn xor_in2out(&mut self, data: &GenericArray<GenericArray<u8, N>, M>) {
163 unsafe {
164 let input: GenericArray, …> = ptr::read(self.in_ptr);
165 let mut temp: GenericArray, …> = GenericArray::<GenericArray<u8, N>, M>::default();
166 for i: usize in 0..M::USIZE {
167 for j: usize in 0..N::USIZE {
168 temp[i][j] = input[i][j] ^ data[i][j];
169 }
170 }
171 ptr::write(self.out_ptr, src:temp);
172 }
173 }
174}
175