1 | use crate::conn::{ConnectionCommon, SideData}; |
2 | |
3 | use core::ops::{Deref, DerefMut}; |
4 | use std::io::{IoSlice, Read, Result, Write}; |
5 | |
/// A borrowing adapter that presents a TLS connection together with its
/// transport as one object implementing `io::Read` and `io::Write`.
///
/// `C` is the Connection type and `T` is the underlying transport, such as
/// a socket. Both are held by mutable reference for the lifetime `'a`.
///
/// This allows you to use a rustls Connection like a normal stream.
#[derive(Debug)]
pub struct Stream<'a, C, T>
where
    C: 'a + ?Sized,
    T: 'a + Read + Write + ?Sized,
{
    /// The TLS connection this stream reads plaintext from and writes
    /// plaintext to
    pub conn: &'a mut C,

    /// The underlying transport, like a socket
    pub sock: &'a mut T,
}
18 | |
19 | impl<'a, C, T, S> Stream<'a, C, T> |
20 | where |
21 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
22 | T: 'a + Read + Write, |
23 | S: SideData, |
24 | { |
25 | /// Make a new Stream using the Connection `conn` and socket-like object |
26 | /// `sock`. This does not fail and does no IO. |
27 | pub fn new(conn: &'a mut C, sock: &'a mut T) -> Self { |
28 | Self { conn, sock } |
29 | } |
30 | |
31 | /// If we're handshaking, complete all the IO for that. |
32 | /// If we have data to write, write it all. |
33 | fn complete_prior_io(&mut self) -> Result<()> { |
34 | if self.conn.is_handshaking() { |
35 | self.conn.complete_io(self.sock)?; |
36 | } |
37 | |
38 | if self.conn.wants_write() { |
39 | self.conn.complete_io(self.sock)?; |
40 | } |
41 | |
42 | Ok(()) |
43 | } |
44 | } |
45 | |
46 | impl<'a, C, T, S> Read for Stream<'a, C, T> |
47 | where |
48 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
49 | T: 'a + Read + Write, |
50 | S: SideData, |
51 | { |
52 | fn read(&mut self, buf: &mut [u8]) -> Result<usize> { |
53 | self.complete_prior_io()?; |
54 | |
55 | // We call complete_io() in a loop since a single call may read only |
56 | // a partial packet from the underlying transport. A full packet is |
57 | // needed to get more plaintext, which we must do if EOF has not been |
58 | // hit. |
59 | while self.conn.wants_read() { |
60 | if self.conn.complete_io(self.sock)?.0 == 0 { |
61 | break; |
62 | } |
63 | } |
64 | |
65 | self.conn.reader().read(buf) |
66 | } |
67 | |
68 | #[cfg (read_buf)] |
69 | fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> { |
70 | self.complete_prior_io()?; |
71 | |
72 | // We call complete_io() in a loop since a single call may read only |
73 | // a partial packet from the underlying transport. A full packet is |
74 | // needed to get more plaintext, which we must do if EOF has not been |
75 | // hit. |
76 | while self.conn.wants_read() { |
77 | if self.conn.complete_io(self.sock)?.0 == 0 { |
78 | break; |
79 | } |
80 | } |
81 | |
82 | self.conn.reader().read_buf(cursor) |
83 | } |
84 | } |
85 | |
86 | impl<'a, C, T, S> Write for Stream<'a, C, T> |
87 | where |
88 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
89 | T: 'a + Read + Write, |
90 | S: SideData, |
91 | { |
92 | fn write(&mut self, buf: &[u8]) -> Result<usize> { |
93 | self.complete_prior_io()?; |
94 | |
95 | let len = self.conn.writer().write(buf)?; |
96 | |
97 | // Try to write the underlying transport here, but don't let |
98 | // any errors mask the fact we've consumed `len` bytes. |
99 | // Callers will learn of permanent errors on the next call. |
100 | let _ = self.conn.complete_io(self.sock); |
101 | |
102 | Ok(len) |
103 | } |
104 | |
105 | fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> { |
106 | self.complete_prior_io()?; |
107 | |
108 | let len = self |
109 | .conn |
110 | .writer() |
111 | .write_vectored(bufs)?; |
112 | |
113 | // Try to write the underlying transport here, but don't let |
114 | // any errors mask the fact we've consumed `len` bytes. |
115 | // Callers will learn of permanent errors on the next call. |
116 | let _ = self.conn.complete_io(self.sock); |
117 | |
118 | Ok(len) |
119 | } |
120 | |
121 | fn flush(&mut self) -> Result<()> { |
122 | self.complete_prior_io()?; |
123 | |
124 | self.conn.writer().flush()?; |
125 | if self.conn.wants_write() { |
126 | self.conn.complete_io(self.sock)?; |
127 | } |
128 | Ok(()) |
129 | } |
130 | } |
131 | |
/// An owning adapter that presents a TLS connection together with its
/// transport as one object implementing `io::Read` and `io::Write`.
///
/// `C` is the Connection type and `T` is the underlying blocking
/// transport, such as a socket. Both are stored by value.
///
/// This allows you to use a rustls Connection like a normal stream.
#[derive(Debug)]
pub struct StreamOwned<C, T>
where
    T: Read + Write,
{
    /// Our connection
    pub conn: C,

    /// The underlying transport, like a socket
    pub sock: T,
}
145 | |
146 | impl<C, T, S> StreamOwned<C, T> |
147 | where |
148 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
149 | T: Read + Write, |
150 | S: SideData, |
151 | { |
152 | /// Make a new StreamOwned taking the Connection `conn` and socket-like |
153 | /// object `sock`. This does not fail and does no IO. |
154 | /// |
155 | /// This is the same as `Stream::new` except `conn` and `sock` are |
156 | /// moved into the StreamOwned. |
157 | pub fn new(conn: C, sock: T) -> Self { |
158 | Self { conn, sock } |
159 | } |
160 | |
161 | /// Get a reference to the underlying socket |
162 | pub fn get_ref(&self) -> &T { |
163 | &self.sock |
164 | } |
165 | |
166 | /// Get a mutable reference to the underlying socket |
167 | pub fn get_mut(&mut self) -> &mut T { |
168 | &mut self.sock |
169 | } |
170 | |
171 | /// Extract the `conn` and `sock` parts from the `StreamOwned` |
172 | pub fn into_parts(self) -> (C, T) { |
173 | (self.conn, self.sock) |
174 | } |
175 | } |
176 | |
177 | impl<'a, C, T, S> StreamOwned<C, T> |
178 | where |
179 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
180 | T: Read + Write, |
181 | S: SideData, |
182 | { |
183 | fn as_stream(&'a mut self) -> Stream<'a, C, T> { |
184 | Stream { |
185 | conn: &mut self.conn, |
186 | sock: &mut self.sock, |
187 | } |
188 | } |
189 | } |
190 | |
191 | impl<C, T, S> Read for StreamOwned<C, T> |
192 | where |
193 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
194 | T: Read + Write, |
195 | S: SideData, |
196 | { |
197 | fn read(&mut self, buf: &mut [u8]) -> Result<usize> { |
198 | self.as_stream().read(buf) |
199 | } |
200 | |
201 | #[cfg (read_buf)] |
202 | fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> { |
203 | self.as_stream().read_buf(cursor) |
204 | } |
205 | } |
206 | |
207 | impl<C, T, S> Write for StreamOwned<C, T> |
208 | where |
209 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
210 | T: Read + Write, |
211 | S: SideData, |
212 | { |
213 | fn write(&mut self, buf: &[u8]) -> Result<usize> { |
214 | self.as_stream().write(buf) |
215 | } |
216 | |
217 | fn flush(&mut self) -> Result<()> { |
218 | self.as_stream().flush() |
219 | } |
220 | } |
221 | |
#[cfg(test)]
mod tests {
    use std::net::TcpStream;

    use super::{Stream, StreamOwned};
    use crate::{client::ClientConnection, server::ServerConnection};

    // These tests contain no runtime assertions: each one only names a
    // concrete instantiation of the stream types through a local alias.

    #[test]
    fn stream_can_be_created_for_connection_and_tcpstream() {
        type _Borrowed<'a> = Stream<'a, ClientConnection, TcpStream>;
    }

    #[test]
    fn streamowned_can_be_created_for_client_and_tcpstream() {
        type _OwnedClient = StreamOwned<ClientConnection, TcpStream>;
    }

    #[test]
    fn streamowned_can_be_created_for_server_and_tcpstream() {
        type _OwnedServer = StreamOwned<ServerConnection, TcpStream>;
    }
}
244 | |