1 | use std::collections::VecDeque; |
2 | use tracing::{instrument, trace}; |
3 | |
4 | use super::{AuthMechanism, BoxedSplit, Command}; |
5 | use crate::{Error, Result}; |
6 | |
/// Common code for the client and server side of the handshake.
#[derive(Debug)]
pub(super) struct Common {
    /// The socket the handshake is performed on.
    socket: BoxedSplit,
    /// Bytes received from the socket that have not yet been parsed into commands.
    recv_buffer: Vec<u8>,
    /// File descriptors received alongside handshake data.
    #[cfg(unix)]
    received_fds: Vec<std::os::fd::OwnedFd>,
    /// Unix FD-passing capability flag; starts out `false` and is changed via `set_cap_unix_fd`.
    cap_unix_fd: bool,
    // the current AUTH mechanism is front, ordered by priority
    mechanisms: VecDeque<AuthMechanism>,
    /// `true` until the first command has been written or read. While set, the next command
    /// written is prefixed with a NUL byte (except on FreeBSD/DragonFly), and the next line read
    /// must start with a NUL byte.
    first_command: bool,
}
19 | |
20 | impl Common { |
21 | /// Start a handshake on this client socket |
22 | pub fn new(socket: BoxedSplit, mechanisms: VecDeque<AuthMechanism>) -> Self { |
23 | Self { |
24 | socket, |
25 | recv_buffer: Vec::new(), |
26 | #[cfg (unix)] |
27 | received_fds: Vec::new(), |
28 | cap_unix_fd: false, |
29 | mechanisms, |
30 | first_command: true, |
31 | } |
32 | } |
33 | |
34 | #[cfg (all(unix, feature = "p2p" ))] |
35 | pub fn socket(&self) -> &BoxedSplit { |
36 | &self.socket |
37 | } |
38 | |
39 | pub fn socket_mut(&mut self) -> &mut BoxedSplit { |
40 | &mut self.socket |
41 | } |
42 | |
43 | pub fn set_cap_unix_fd(&mut self, cap_unix_fd: bool) { |
44 | self.cap_unix_fd = cap_unix_fd; |
45 | } |
46 | |
47 | #[cfg (feature = "p2p" )] |
48 | pub fn mechanisms(&self) -> &VecDeque<AuthMechanism> { |
49 | &self.mechanisms |
50 | } |
51 | |
52 | pub fn into_components(self) -> IntoComponentsReturn { |
53 | ( |
54 | self.socket, |
55 | self.recv_buffer, |
56 | #[cfg (unix)] |
57 | self.received_fds, |
58 | self.cap_unix_fd, |
59 | self.mechanisms, |
60 | ) |
61 | } |
62 | |
63 | #[instrument (skip(self))] |
64 | pub async fn write_command(&mut self, command: Command) -> Result<()> { |
65 | self.write_commands(&[command], None).await |
66 | } |
67 | |
68 | #[instrument (skip(self))] |
69 | pub async fn write_commands( |
70 | &mut self, |
71 | commands: &[Command], |
72 | extra_bytes: Option<&[u8]>, |
73 | ) -> Result<()> { |
74 | let mut send_buffer = |
75 | commands |
76 | .iter() |
77 | .map(Vec::<u8>::from) |
78 | .fold(vec![], |mut acc, mut c| { |
79 | if self.first_command { |
80 | // The first command is sent by the client so we can assume it's the client. |
81 | self.first_command = false; |
82 | // leading 0 is sent separately for `freebsd` and `dragonfly`. |
83 | #[cfg (not(any(target_os = "freebsd" , target_os = "dragonfly" )))] |
84 | acc.push(b' \0' ); |
85 | } |
86 | acc.append(&mut c); |
87 | acc.extend_from_slice(b" \r\n" ); |
88 | acc |
89 | }); |
90 | if let Some(extra_bytes) = extra_bytes { |
91 | send_buffer.extend_from_slice(extra_bytes); |
92 | } |
93 | while !send_buffer.is_empty() { |
94 | let written = self |
95 | .socket |
96 | .write_mut() |
97 | .sendmsg( |
98 | &send_buffer, |
99 | #[cfg (unix)] |
100 | &[], |
101 | ) |
102 | .await?; |
103 | send_buffer.drain(..written); |
104 | } |
105 | trace!("Wrote all commands" ); |
106 | Ok(()) |
107 | } |
108 | |
109 | #[instrument (skip(self))] |
110 | pub async fn read_command(&mut self) -> Result<Command> { |
111 | self.read_commands(1) |
112 | .await |
113 | .map(|cmds| cmds.into_iter().next().unwrap()) |
114 | } |
115 | |
116 | #[instrument (skip(self))] |
117 | pub async fn read_commands(&mut self, n_commands: usize) -> Result<Vec<Command>> { |
118 | let mut commands = Vec::with_capacity(n_commands); |
119 | let mut n_received_commands = 0; |
120 | 'outer: loop { |
121 | while let Some(lf_index) = self.recv_buffer.iter().position(|b| *b == b' \n' ) { |
122 | if self.recv_buffer[lf_index - 1] != b' \r' { |
123 | return Err(Error::Handshake("Invalid line ending in handshake" .into())); |
124 | } |
125 | |
126 | #[allow (unused_mut)] |
127 | let mut start_index = 0; |
128 | if self.first_command { |
129 | // The first command is sent by the client so we can assume it's the server. |
130 | self.first_command = false; |
131 | if self.recv_buffer[0] != b' \0' { |
132 | return Err(Error::Handshake( |
133 | "First client byte is not NUL!" .to_string(), |
134 | )); |
135 | } |
136 | |
137 | start_index = 1; |
138 | }; |
139 | |
140 | let line_bytes = self.recv_buffer.drain(..=lf_index); |
141 | let line = std::str::from_utf8(&line_bytes.as_slice()[start_index..]) |
142 | .map_err(|e| Error::Handshake(e.to_string()))?; |
143 | |
144 | trace!("Reading {line}" ); |
145 | commands.push(line.parse()?); |
146 | n_received_commands += 1; |
147 | |
148 | if n_received_commands == n_commands { |
149 | break 'outer; |
150 | } |
151 | } |
152 | |
153 | let mut buf = vec![0; 1024]; |
154 | let res = self.socket.read_mut().recvmsg(&mut buf).await?; |
155 | let read = { |
156 | #[cfg (unix)] |
157 | { |
158 | let (read, fds) = res; |
159 | if !fds.is_empty() { |
160 | // Most likely belonging to the messages already received. |
161 | self.received_fds.extend(fds); |
162 | } |
163 | read |
164 | } |
165 | #[cfg (not(unix))] |
166 | { |
167 | res |
168 | } |
169 | }; |
170 | if read == 0 { |
171 | return Err(Error::Handshake("Unexpected EOF during handshake" .into())); |
172 | } |
173 | self.recv_buffer.extend(&buf[..read]); |
174 | } |
175 | |
176 | Ok(commands) |
177 | } |
178 | |
179 | pub fn next_mechanism(&mut self) -> Result<AuthMechanism> { |
180 | self.mechanisms |
181 | .pop_front() |
182 | .ok_or_else(|| Error::Handshake("Exhausted available AUTH mechanisms" .into())) |
183 | } |
184 | } |
185 | |
186 | #[cfg (unix)] |
187 | type IntoComponentsReturn = ( |
188 | BoxedSplit, |
189 | Vec<u8>, |
190 | Vec<std::os::fd::OwnedFd>, |
191 | bool, |
192 | VecDeque<AuthMechanism>, |
193 | ); |
194 | #[cfg (not(unix))] |
195 | type IntoComponentsReturn = (BoxedSplit, Vec<u8>, bool, VecDeque<AuthMechanism>); |
196 | |