1 | use core::ops::{Deref, DerefMut}; |
2 | use std::io::{BufRead, IoSlice, Read, Result, Write}; |
3 | |
4 | use crate::conn::{ConnectionCommon, SideData}; |
5 | |
/// A borrowed pairing of a rustls connection and a transport that together
/// implement `io::Read` and `io::Write`.
///
/// `C` is the TLS connection and `T` is the underlying byte transport (for
/// example a socket). Bundling the two lets a rustls Connection be used like
/// an ordinary stream.
#[derive(Debug)]
pub struct Stream<'a, C, T>
where
    C: 'a + ?Sized,
    T: 'a + Read + Write + ?Sized,
{
    /// The TLS connection being driven.
    pub conn: &'a mut C,

    /// The transport carrying the encrypted bytes, such as a socket.
    pub sock: &'a mut T,
}
18 | |
19 | impl<'a, C, T, S> Stream<'a, C, T> |
20 | where |
21 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
22 | T: 'a + Read + Write, |
23 | S: SideData, |
24 | { |
25 | /// Make a new Stream using the Connection `conn` and socket-like object |
26 | /// `sock`. This does not fail and does no IO. |
27 | pub fn new(conn: &'a mut C, sock: &'a mut T) -> Self { |
28 | Self { conn, sock } |
29 | } |
30 | |
31 | /// If we're handshaking, complete all the IO for that. |
32 | /// If we have data to write, write it all. |
33 | fn complete_prior_io(&mut self) -> Result<()> { |
34 | if self.conn.is_handshaking() { |
35 | self.conn.complete_io(self.sock)?; |
36 | } |
37 | |
38 | if self.conn.wants_write() { |
39 | self.conn.complete_io(self.sock)?; |
40 | } |
41 | |
42 | Ok(()) |
43 | } |
44 | |
45 | fn prepare_read(&mut self) -> Result<()> { |
46 | self.complete_prior_io()?; |
47 | |
48 | // We call complete_io() in a loop since a single call may read only |
49 | // a partial packet from the underlying transport. A full packet is |
50 | // needed to get more plaintext, which we must do if EOF has not been |
51 | // hit. |
52 | while self.conn.wants_read() { |
53 | if self.conn.complete_io(self.sock)?.0 == 0 { |
54 | break; |
55 | } |
56 | } |
57 | |
58 | Ok(()) |
59 | } |
60 | |
61 | // Implements `BufRead::fill_buf` but with more flexible lifetimes, so StreamOwned can reuse it |
62 | fn fill_buf(mut self) -> Result<&'a [u8]> |
63 | where |
64 | S: 'a, |
65 | { |
66 | self.prepare_read()?; |
67 | self.conn.reader().into_first_chunk() |
68 | } |
69 | } |
70 | |
71 | impl<'a, C, T, S> Read for Stream<'a, C, T> |
72 | where |
73 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
74 | T: 'a + Read + Write, |
75 | S: SideData, |
76 | { |
77 | fn read(&mut self, buf: &mut [u8]) -> Result<usize> { |
78 | self.prepare_read()?; |
79 | self.conn.reader().read(buf) |
80 | } |
81 | |
82 | #[cfg (read_buf)] |
83 | fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> { |
84 | self.prepare_read()?; |
85 | self.conn.reader().read_buf(cursor) |
86 | } |
87 | } |
88 | |
89 | impl<'a, C, T, S> BufRead for Stream<'a, C, T> |
90 | where |
91 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
92 | T: 'a + Read + Write, |
93 | S: 'a + SideData, |
94 | { |
95 | fn fill_buf(&mut self) -> Result<&[u8]> { |
96 | // reborrow to get an owned `Stream` |
97 | Stream { |
98 | conn: self.conn, |
99 | sock: self.sock, |
100 | } |
101 | .fill_buf() |
102 | } |
103 | |
104 | fn consume(&mut self, amt: usize) { |
105 | self.conn.reader().consume(amount:amt) |
106 | } |
107 | } |
108 | |
109 | impl<'a, C, T, S> Write for Stream<'a, C, T> |
110 | where |
111 | C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>, |
112 | T: 'a + Read + Write, |
113 | S: SideData, |
114 | { |
115 | fn write(&mut self, buf: &[u8]) -> Result<usize> { |
116 | self.complete_prior_io()?; |
117 | |
118 | let len = self.conn.writer().write(buf)?; |
119 | |
120 | // Try to write the underlying transport here, but don't let |
121 | // any errors mask the fact we've consumed `len` bytes. |
122 | // Callers will learn of permanent errors on the next call. |
123 | let _ = self.conn.complete_io(self.sock); |
124 | |
125 | Ok(len) |
126 | } |
127 | |
128 | fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> { |
129 | self.complete_prior_io()?; |
130 | |
131 | let len = self |
132 | .conn |
133 | .writer() |
134 | .write_vectored(bufs)?; |
135 | |
136 | // Try to write the underlying transport here, but don't let |
137 | // any errors mask the fact we've consumed `len` bytes. |
138 | // Callers will learn of permanent errors on the next call. |
139 | let _ = self.conn.complete_io(self.sock); |
140 | |
141 | Ok(len) |
142 | } |
143 | |
144 | fn flush(&mut self) -> Result<()> { |
145 | self.complete_prior_io()?; |
146 | |
147 | self.conn.writer().flush()?; |
148 | if self.conn.wants_write() { |
149 | self.conn.complete_io(self.sock)?; |
150 | } |
151 | Ok(()) |
152 | } |
153 | } |
154 | |
/// An owning pairing of a rustls connection `C` and a blocking transport `T`
/// (for example a socket) that together implement `io::Read` and `io::Write`.
///
/// It works like `Stream`, except that the connection and transport are held
/// by value rather than borrowed. This lets a rustls Connection be used like
/// an ordinary stream.
#[derive(Debug)]
pub struct StreamOwned<C, T>
where
    C: Sized,
    T: Read + Write + Sized,
{
    /// The TLS connection being driven.
    pub conn: C,

    /// The transport carrying the encrypted bytes, such as a socket.
    pub sock: T,
}
168 | |
169 | impl<C, T, S> StreamOwned<C, T> |
170 | where |
171 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
172 | T: Read + Write, |
173 | S: SideData, |
174 | { |
175 | /// Make a new StreamOwned taking the Connection `conn` and socket-like |
176 | /// object `sock`. This does not fail and does no IO. |
177 | /// |
178 | /// This is the same as `Stream::new` except `conn` and `sock` are |
179 | /// moved into the StreamOwned. |
180 | pub fn new(conn: C, sock: T) -> Self { |
181 | Self { conn, sock } |
182 | } |
183 | |
184 | /// Get a reference to the underlying socket |
185 | pub fn get_ref(&self) -> &T { |
186 | &self.sock |
187 | } |
188 | |
189 | /// Get a mutable reference to the underlying socket |
190 | pub fn get_mut(&mut self) -> &mut T { |
191 | &mut self.sock |
192 | } |
193 | |
194 | /// Extract the `conn` and `sock` parts from the `StreamOwned` |
195 | pub fn into_parts(self) -> (C, T) { |
196 | (self.conn, self.sock) |
197 | } |
198 | } |
199 | |
200 | impl<'a, C, T, S> StreamOwned<C, T> |
201 | where |
202 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
203 | T: Read + Write, |
204 | S: SideData, |
205 | { |
206 | fn as_stream(&'a mut self) -> Stream<'a, C, T> { |
207 | Stream { |
208 | conn: &mut self.conn, |
209 | sock: &mut self.sock, |
210 | } |
211 | } |
212 | } |
213 | |
214 | impl<C, T, S> Read for StreamOwned<C, T> |
215 | where |
216 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
217 | T: Read + Write, |
218 | S: SideData, |
219 | { |
220 | fn read(&mut self, buf: &mut [u8]) -> Result<usize> { |
221 | self.as_stream().read(buf) |
222 | } |
223 | |
224 | #[cfg (read_buf)] |
225 | fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> { |
226 | self.as_stream().read_buf(cursor) |
227 | } |
228 | } |
229 | |
230 | impl<C, T, S> BufRead for StreamOwned<C, T> |
231 | where |
232 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
233 | T: Read + Write, |
234 | S: 'static + SideData, |
235 | { |
236 | fn fill_buf(&mut self) -> Result<&[u8]> { |
237 | self.as_stream().fill_buf() |
238 | } |
239 | |
240 | fn consume(&mut self, amt: usize) { |
241 | self.as_stream().consume(amount:amt) |
242 | } |
243 | } |
244 | |
245 | impl<C, T, S> Write for StreamOwned<C, T> |
246 | where |
247 | C: DerefMut + Deref<Target = ConnectionCommon<S>>, |
248 | T: Read + Write, |
249 | S: SideData, |
250 | { |
251 | fn write(&mut self, buf: &[u8]) -> Result<usize> { |
252 | self.as_stream().write(buf) |
253 | } |
254 | |
255 | fn flush(&mut self) -> Result<()> { |
256 | self.as_stream().flush() |
257 | } |
258 | } |
259 | |
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Read, Write};
    use std::net::TcpStream;

    use super::{Stream, StreamOwned};
    use crate::client::ClientConnection;
    use crate::server::ServerConnection;

    /// Compiles only if `X` implements `Read`, `BufRead` and `Write`.
    ///
    /// Declaring a `type` alias never requires the named type to implement
    /// any trait. This call is what actually proves that the stream impls
    /// apply to the concrete connection and socket types.
    fn assert_stream_traits<X: Read + BufRead + Write>() {}

    #[test]
    fn stream_can_be_created_for_connection_and_tcpstream() {
        type _Test<'a> = Stream<'a, ClientConnection, TcpStream>;
        assert_stream_traits::<Stream<'_, ClientConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_client_and_tcpstream() {
        type _Test = StreamOwned<ClientConnection, TcpStream>;
        assert_stream_traits::<_Test>();
    }

    #[test]
    fn streamowned_can_be_created_for_server_and_tcpstream() {
        type _Test = StreamOwned<ServerConnection, TcpStream>;
        assert_stream_traits::<_Test>();
    }
}
283 | |