1 | use std::io; |
2 | use std::net::ToSocketAddrs; |
3 | use std::net::{SocketAddr, TcpListener, TcpStream}; |
4 | use std::sync::atomic::{AtomicBool, Ordering}; |
5 | use std::sync::Arc; |
6 | use std::thread; |
7 | use std::time::Duration; |
8 | |
9 | use crate::{Agent, AgentBuilder}; |
10 | |
/// Fallback connection handler used when the `testdeps` feature (and with it
/// the hootbin test server) is unavailable.
///
/// It serves nothing: the accepted connection is closed right away and the
/// handler reports success.
#[cfg(not(feature = "testdeps"))]
fn test_server_handler(stream: TcpStream) -> io::Result<()> {
    // Nothing to serve without hootbin; closing the socket is all we do.
    drop(stream);
    Ok(())
}
15 | |
/// Serves a single accepted connection with `hootbin::serve_single`, using
/// `https://hootbin.test/` as the base URL.
///
/// Always returns `Ok(())`. An I/O error of kind `UnexpectedEof` is silently
/// accepted; any other hootbin error is printed to stdout.
#[cfg(feature = "testdeps")]
fn test_server_handler(stream: TcpStream) -> io::Result<()> {
    use hootbin::serve_single;

    let writer = stream.try_clone().expect("TcpStream to be clonable");
    let reader = stream;

    let err = match serve_single(reader, writer, "https://hootbin.test/") {
        Ok(()) => return Ok(()),
        Err(err) => err,
    };

    // TestServer::new opens a throwaway connection to check the listener is up
    // and drops it without sending anything, so hootbin sees an early EOF on
    // that connection. That error is expected and not worth reporting.
    let is_preconnect_eof = matches!(
        &err,
        hootbin::Error::Io(ioe) if ioe.kind() == io::ErrorKind::UnexpectedEof
    );
    if !is_preconnect_eof {
        println!("TestServer error: {:?}", err);
    }
    Ok(())
}
36 | |
37 | // An agent to be installed by default for tests and doctests, such |
38 | // that all hostnames resolve to a TestServer on localhost. |
39 | pub(crate) fn test_agent() -> Agent { |
40 | #[cfg (test)] |
41 | let _ = env_logger::try_init(); |
42 | |
43 | let testserver: TestServer = TestServer::new(test_server_handler); |
44 | // Slightly tricky thing here: we want to make sure the TestServer lives |
45 | // as long as the agent. This is accomplished by `move`ing it into the |
46 | // closure, which becomes owned by the agent. |
47 | AgentBuilderAgentBuilder::new() |
48 | .resolver(move |h: &str| -> io::Result<Vec<SocketAddr>> { |
49 | // Don't override resolution for HTTPS requests yet, since we |
50 | // don't have a setup for an HTTPS testserver. Also, skip localhost |
51 | // resolutions since those may come from a unittest that set up |
52 | // its own, specific testserver. |
53 | if h.ends_with(":443" ) || h.starts_with("localhost:" ) { |
54 | return Ok(h.to_socket_addrs()?.collect::<Vec<_>>()); |
55 | } |
56 | let addr: SocketAddr = format!("127.0.0.1: {}" , testserver.port).parse().unwrap(); |
57 | Ok(vec![addr]) |
58 | }) |
59 | .build() |
60 | } |
61 | |
/// Handle to a throwaway HTTP test server bound to an ephemeral port on
/// 127.0.0.1.
pub struct TestServer {
    /// Port the listener was bound to on 127.0.0.1.
    pub port: u16,
    /// Shutdown flag shared with the accept loop; once it is `true`, the loop
    /// exits after the next connection it accepts.
    pub done: Arc<AtomicBool>,
}
66 | |
/// Lines read from one HTTP request head by `read_request`: the request line
/// first (if any was read), followed by the header lines.
pub struct TestHeaders(Vec<String>);

// Not every build configuration calls every accessor (e.g. `headers` only
// exists with the `cookies` feature), so unused-method warnings are allowed.
#[allow(dead_code)]
impl TestHeaders {
    /// Returns the path for a request, e.g. `/foo` from `"GET /foo HTTP/1.1"`.
    ///
    /// Returns `""` when no request line was read or the request line has no
    /// space-separated path component.
    pub fn path(&self) -> &str {
        self.0
            .first()
            .and_then(|request_line| request_line.split(' ').nth(1))
            .unwrap_or("")
    }

    /// Returns the header lines that follow the request line; empty if none
    /// were read.
    #[cfg(feature = "cookies")]
    pub fn headers(&self) -> &[String] {
        self.0.get(1..).unwrap_or(&[])
    }
}
85 | |
/// Reads `stream` until the blank line that ends the request headers, then
/// discards whatever request-body bytes are already available.
///
/// Returns the collected lines (request line first). A read error is logged to
/// stderr and ends header collection early. The stream is put back into
/// blocking mode before returning, so the caller can write a response
/// normally.
#[cfg(test)]
pub fn read_request(stream: &TcpStream) -> TestHeaders {
    use std::io::{BufRead, BufReader};

    // One reader for both phases: bytes it buffers past the blank line are
    // drained below instead of being dropped along with a temporary reader.
    let mut reader = BufReader::new(stream);

    let mut results = vec![];
    for line in (&mut reader).lines() {
        match line {
            Err(e) => {
                eprintln!("testserver: in read_request: {}", e);
                break;
            }
            Ok(line) if line.is_empty() => break,
            Ok(line) => results.push(line),
        }
    }

    // Consume rest of body. TODO maybe capture the body for inspection in the test?
    // In non-blocking mode fill_buf() returns an error (ending the loop) when no
    // more bytes are currently available, instead of blocking on a request that
    // has no body or has already been fully read.
    stream.set_nonblocking(true).ok();
    while let Ok(buf) = reader.fill_buf() {
        let amount = buf.len();
        if amount == 0 {
            break;
        }
        reader.consume(amount);
    }
    // Restore blocking mode: handlers write their response to this same socket,
    // and a non-blocking write could fail spuriously with WouldBlock.
    stream.set_nonblocking(false).ok();

    TestHeaders(results)
}
116 | |
117 | impl TestServer { |
118 | pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self { |
119 | let listener = TcpListener::bind("127.0.0.1:0" ).unwrap(); |
120 | let port = listener.local_addr().unwrap().port(); |
121 | let done = Arc::new(AtomicBool::new(false)); |
122 | let done_clone = done.clone(); |
123 | thread::spawn(move || { |
124 | for stream in listener.incoming() { |
125 | if let Err(e) = stream { |
126 | eprintln!("testserver: handling just-accepted stream: {}" , e); |
127 | break; |
128 | } |
129 | if done.load(Ordering::SeqCst) { |
130 | break; |
131 | } else { |
132 | thread::spawn(move || handler(stream.unwrap())); |
133 | } |
134 | } |
135 | }); |
136 | // before returning from new(), ensure the server is ready to accept connections |
137 | while let Err(e) = TcpStream::connect(format!("127.0.0.1: {}" , port)) { |
138 | match e.kind() { |
139 | io::ErrorKind::ConnectionRefused => { |
140 | std::thread::sleep(Duration::from_millis(100)); |
141 | continue; |
142 | } |
143 | _ => eprintln!("testserver: pre-connect with error {}" , e), |
144 | } |
145 | } |
146 | TestServer { |
147 | port, |
148 | done: done_clone, |
149 | } |
150 | } |
151 | } |
152 | |
153 | impl Drop for TestServer { |
154 | fn drop(&mut self) { |
155 | self.done.store(val:true, order:Ordering::SeqCst); |
156 | // Connect once to unblock the listen loop. |
157 | if let Err(e: Error) = TcpStream::connect(addr:format!("localhost: {}" , self.port)) { |
158 | eprintln!("error dropping testserver: {}" , e); |
159 | } |
160 | } |
161 | } |
162 | |