| 1 | use std::collections::VecDeque; |
| 2 | use tracing::{instrument, trace}; |
| 3 | |
| 4 | use super::{AuthMechanism, BoxedSplit, Command}; |
| 5 | use crate::{Error, Result}; |
| 6 | |
| 7 | // Common code for the client and server side of the handshake. |
| 8 | #[derive (Debug)] |
| 9 | pub(super) struct Common { |
| 10 | socket: BoxedSplit, |
| 11 | recv_buffer: Vec<u8>, |
| 12 | #[cfg (unix)] |
| 13 | received_fds: Vec<std::os::fd::OwnedFd>, |
| 14 | cap_unix_fd: bool, |
| 15 | // the current AUTH mechanism is front, ordered by priority |
| 16 | mechanisms: VecDeque<AuthMechanism>, |
| 17 | first_command: bool, |
| 18 | } |
| 19 | |
| 20 | impl Common { |
| 21 | /// Start a handshake on this client socket |
| 22 | pub fn new(socket: BoxedSplit, mechanisms: VecDeque<AuthMechanism>) -> Self { |
| 23 | Self { |
| 24 | socket, |
| 25 | recv_buffer: Vec::new(), |
| 26 | #[cfg (unix)] |
| 27 | received_fds: Vec::new(), |
| 28 | cap_unix_fd: false, |
| 29 | mechanisms, |
| 30 | first_command: true, |
| 31 | } |
| 32 | } |
| 33 | |
| 34 | #[cfg (all(unix, feature = "p2p" ))] |
| 35 | pub fn socket(&self) -> &BoxedSplit { |
| 36 | &self.socket |
| 37 | } |
| 38 | |
| 39 | pub fn socket_mut(&mut self) -> &mut BoxedSplit { |
| 40 | &mut self.socket |
| 41 | } |
| 42 | |
| 43 | pub fn set_cap_unix_fd(&mut self, cap_unix_fd: bool) { |
| 44 | self.cap_unix_fd = cap_unix_fd; |
| 45 | } |
| 46 | |
| 47 | #[cfg (feature = "p2p" )] |
| 48 | pub fn mechanisms(&self) -> &VecDeque<AuthMechanism> { |
| 49 | &self.mechanisms |
| 50 | } |
| 51 | |
| 52 | pub fn into_components(self) -> IntoComponentsReturn { |
| 53 | ( |
| 54 | self.socket, |
| 55 | self.recv_buffer, |
| 56 | #[cfg (unix)] |
| 57 | self.received_fds, |
| 58 | self.cap_unix_fd, |
| 59 | self.mechanisms, |
| 60 | ) |
| 61 | } |
| 62 | |
| 63 | #[instrument (skip(self))] |
| 64 | pub async fn write_command(&mut self, command: Command) -> Result<()> { |
| 65 | self.write_commands(&[command], None).await |
| 66 | } |
| 67 | |
| 68 | #[instrument (skip(self))] |
| 69 | pub async fn write_commands( |
| 70 | &mut self, |
| 71 | commands: &[Command], |
| 72 | extra_bytes: Option<&[u8]>, |
| 73 | ) -> Result<()> { |
| 74 | let mut send_buffer = |
| 75 | commands |
| 76 | .iter() |
| 77 | .map(Vec::<u8>::from) |
| 78 | .fold(vec![], |mut acc, mut c| { |
| 79 | if self.first_command { |
| 80 | // The first command is sent by the client so we can assume it's the client. |
| 81 | self.first_command = false; |
| 82 | // leading 0 is sent separately for `freebsd` and `dragonfly`. |
| 83 | #[cfg (not(any(target_os = "freebsd" , target_os = "dragonfly" )))] |
| 84 | acc.push(b' \0' ); |
| 85 | } |
| 86 | acc.append(&mut c); |
| 87 | acc.extend_from_slice(b" \r\n" ); |
| 88 | acc |
| 89 | }); |
| 90 | if let Some(extra_bytes) = extra_bytes { |
| 91 | send_buffer.extend_from_slice(extra_bytes); |
| 92 | } |
| 93 | while !send_buffer.is_empty() { |
| 94 | let written = self |
| 95 | .socket |
| 96 | .write_mut() |
| 97 | .sendmsg( |
| 98 | &send_buffer, |
| 99 | #[cfg (unix)] |
| 100 | &[], |
| 101 | ) |
| 102 | .await?; |
| 103 | send_buffer.drain(..written); |
| 104 | } |
| 105 | trace!("Wrote all commands" ); |
| 106 | Ok(()) |
| 107 | } |
| 108 | |
| 109 | #[instrument (skip(self))] |
| 110 | pub async fn read_command(&mut self) -> Result<Command> { |
| 111 | self.read_commands(1) |
| 112 | .await |
| 113 | .map(|cmds| cmds.into_iter().next().unwrap()) |
| 114 | } |
| 115 | |
| 116 | #[instrument (skip(self))] |
| 117 | pub async fn read_commands(&mut self, n_commands: usize) -> Result<Vec<Command>> { |
| 118 | let mut commands = Vec::with_capacity(n_commands); |
| 119 | let mut n_received_commands = 0; |
| 120 | 'outer: loop { |
| 121 | while let Some(lf_index) = self.recv_buffer.iter().position(|b| *b == b' \n' ) { |
| 122 | if self.recv_buffer[lf_index - 1] != b' \r' { |
| 123 | return Err(Error::Handshake("Invalid line ending in handshake" .into())); |
| 124 | } |
| 125 | |
| 126 | #[allow (unused_mut)] |
| 127 | let mut start_index = 0; |
| 128 | if self.first_command { |
| 129 | // The first command is sent by the client so we can assume it's the server. |
| 130 | self.first_command = false; |
| 131 | if self.recv_buffer[0] != b' \0' { |
| 132 | return Err(Error::Handshake( |
| 133 | "First client byte is not NUL!" .to_string(), |
| 134 | )); |
| 135 | } |
| 136 | |
| 137 | start_index = 1; |
| 138 | }; |
| 139 | |
| 140 | let line_bytes = self.recv_buffer.drain(..=lf_index); |
| 141 | let line = std::str::from_utf8(&line_bytes.as_slice()[start_index..]) |
| 142 | .map_err(|e| Error::Handshake(e.to_string()))?; |
| 143 | |
| 144 | trace!("Reading {line}" ); |
| 145 | commands.push(line.parse()?); |
| 146 | n_received_commands += 1; |
| 147 | |
| 148 | if n_received_commands == n_commands { |
| 149 | break 'outer; |
| 150 | } |
| 151 | } |
| 152 | |
| 153 | let mut buf = vec![0; 1024]; |
| 154 | let res = self.socket.read_mut().recvmsg(&mut buf).await?; |
| 155 | let read = { |
| 156 | #[cfg (unix)] |
| 157 | { |
| 158 | let (read, fds) = res; |
| 159 | if !fds.is_empty() { |
| 160 | // Most likely belonging to the messages already received. |
| 161 | self.received_fds.extend(fds); |
| 162 | } |
| 163 | read |
| 164 | } |
| 165 | #[cfg (not(unix))] |
| 166 | { |
| 167 | res |
| 168 | } |
| 169 | }; |
| 170 | if read == 0 { |
| 171 | return Err(Error::Handshake("Unexpected EOF during handshake" .into())); |
| 172 | } |
| 173 | self.recv_buffer.extend(&buf[..read]); |
| 174 | } |
| 175 | |
| 176 | Ok(commands) |
| 177 | } |
| 178 | |
| 179 | pub fn next_mechanism(&mut self) -> Result<AuthMechanism> { |
| 180 | self.mechanisms |
| 181 | .pop_front() |
| 182 | .ok_or_else(|| Error::Handshake("Exhausted available AUTH mechanisms" .into())) |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | #[cfg (unix)] |
| 187 | type IntoComponentsReturn = ( |
| 188 | BoxedSplit, |
| 189 | Vec<u8>, |
| 190 | Vec<std::os::fd::OwnedFd>, |
| 191 | bool, |
| 192 | VecDeque<AuthMechanism>, |
| 193 | ); |
| 194 | #[cfg (not(unix))] |
| 195 | type IntoComponentsReturn = (BoxedSplit, Vec<u8>, bool, VecDeque<AuthMechanism>); |
| 196 | |