| 1 | use std::io; |
| 2 | use std::net::ToSocketAddrs; |
| 3 | use std::net::{SocketAddr, TcpListener, TcpStream}; |
| 4 | use std::sync::atomic::{AtomicBool, Ordering}; |
| 5 | use std::sync::Arc; |
| 6 | use std::thread; |
| 7 | use std::time::Duration; |
| 8 | |
| 9 | use crate::{Agent, AgentBuilder}; |
| 10 | |
/// Connection handler used when the `testdeps` feature is disabled.
///
/// There is no test service to run, so the accepted connection is simply
/// dropped (which closes it) and success is reported.
#[cfg(not(feature = "testdeps"))]
fn test_server_handler(stream: TcpStream) -> io::Result<()> {
    drop(stream);
    Ok(())
}
| 15 | |
/// Connection handler used when the `testdeps` feature is enabled.
///
/// Hands the connection to `hootbin::serve_single`, which takes the input
/// and output halves of the socket separately, with `https://hootbin.test/`
/// as its third argument (presumably the base URL hootbin serves as;
/// confirm against hootbin's docs).
///
/// An `UnexpectedEof` I/O error is accepted silently: the readiness probe
/// made by `TestServer::new` connects and immediately drops its stream, so
/// that connection always ends this way. Any other hootbin error is logged
/// to stderr and swallowed so one bad request does not fail the handler.
///
/// # Errors
///
/// Returns an error only if the stream cannot be cloned into separate
/// read and write handles.
#[cfg(feature = "testdeps")]
fn test_server_handler(stream: TcpStream) -> io::Result<()> {
    use hootbin::serve_single;

    // Report a failed clone (e.g. descriptor exhaustion) instead of
    // panicking the handler thread; the signature already allows it.
    let output = stream.try_clone()?;
    let input = stream;

    if let Err(e) = serve_single(input, output, "https://hootbin.test/") {
        let is_probe_eof = matches!(
            &e,
            hootbin::Error::Io(ioe) if ioe.kind() == io::ErrorKind::UnexpectedEof
        );
        if !is_probe_eof {
            eprintln!("TestServer error: {:?}", e);
        }
    }
    Ok(())
}
| 36 | |
| 37 | // An agent to be installed by default for tests and doctests, such |
| 38 | // that all hostnames resolve to a TestServer on localhost. |
| 39 | pub(crate) fn test_agent() -> Agent { |
| 40 | #[cfg (test)] |
| 41 | let _ = env_logger::try_init(); |
| 42 | |
| 43 | let testserver: TestServer = TestServer::new(test_server_handler); |
| 44 | // Slightly tricky thing here: we want to make sure the TestServer lives |
| 45 | // as long as the agent. This is accomplished by `move`ing it into the |
| 46 | // closure, which becomes owned by the agent. |
| 47 | AgentBuilderAgentBuilder::new() |
| 48 | .resolver(move |h: &str| -> io::Result<Vec<SocketAddr>> { |
| 49 | // Don't override resolution for HTTPS requests yet, since we |
| 50 | // don't have a setup for an HTTPS testserver. Also, skip localhost |
| 51 | // resolutions since those may come from a unittest that set up |
| 52 | // its own, specific testserver. |
| 53 | if h.ends_with(":443" ) || h.starts_with("localhost:" ) { |
| 54 | return Ok(h.to_socket_addrs()?.collect::<Vec<_>>()); |
| 55 | } |
| 56 | let addr: SocketAddr = format!("127.0.0.1: {}" , testserver.port).parse().unwrap(); |
| 57 | Ok(vec![addr]) |
| 58 | }) |
| 59 | .build() |
| 60 | } |
| 61 | |
/// A minimal TCP server listening on an ephemeral port of 127.0.0.1.
///
/// Every accepted connection is passed to a handler function on its own
/// thread. Dropping the value asks the accept loop to stop.
pub struct TestServer {
    /// Port the listener is bound to on 127.0.0.1.
    pub port: u16,
    /// Shutdown flag shared with the accept-loop thread; once it is set,
    /// the loop exits after its next accepted connection.
    pub done: Arc<AtomicBool>,
}
| 66 | |
/// Request head as collected by `read_request`: the request line followed by
/// each header line, in the order received and without line terminators.
pub struct TestHeaders(Vec<String>);

// Accessors may go unused depending on which tests/features are compiled.
#[allow(dead_code)]
impl TestHeaders {
    /// Returns the path for a request, e.g. `/foo` from `GET /foo HTTP/1.1`.
    ///
    /// Returns `""` when nothing was captured or when the request line has
    /// no second space-separated field (previously the latter panicked).
    pub fn path(&self) -> &str {
        self.0
            .first()
            .and_then(|request_line| request_line.split(' ').nth(1))
            .unwrap_or("")
    }

    /// Returns the header lines, i.e. everything after the request line.
    ///
    /// Returns an empty slice when nothing was captured (previously this
    /// panicked on the out-of-range `[1..]` slice).
    #[cfg(feature = "cookies")]
    pub fn headers(&self) -> &[String] {
        self.0.get(1..).unwrap_or(&[])
    }
}
| 85 | |
// Read a stream until reaching a blank line, in order to consume
// request headers, then discard whatever body bytes are already readable.
#[cfg(test)]
pub fn read_request(stream: &TcpStream) -> TestHeaders {
    use std::io::{BufRead, BufReader};

    // Request line and header lines; `lines()` strips the terminators.
    // Any bytes this reader buffered past the blank line are dropped with it.
    let mut results = vec![];
    for line in BufReader::new(stream).lines() {
        match line {
            Err(e) => {
                eprintln!("testserver: in read_request: {}", e);
                break;
            }
            Ok(line) if line.is_empty() => break,
            Ok(line) => results.push(line),
        }
    }

    // Consume rest of body. TODO maybe capture the body for inspection in the test?
    // The client may have nothing more to send, in which case a blocking
    // fill_buf() would hang. Non-blocking mode makes the drain stop at the
    // first "no data available right now" (or any other error).
    stream.set_nonblocking(true).ok();
    let mut reader = BufReader::new(stream);
    while let Ok(buf) = reader.fill_buf() {
        let amount = buf.len();
        if amount == 0 {
            break;
        }
        reader.consume(amount);
    }

    // The caller still owns this stream. Put it back into blocking mode so
    // later reads/writes on it (e.g. `write_all` of a response) wait as usual
    // instead of failing with `WouldBlock`.
    stream.set_nonblocking(false).ok();
    TestHeaders(results)
}
| 116 | |
| 117 | impl TestServer { |
| 118 | pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self { |
| 119 | let listener = TcpListener::bind("127.0.0.1:0" ).unwrap(); |
| 120 | let port = listener.local_addr().unwrap().port(); |
| 121 | let done = Arc::new(AtomicBool::new(false)); |
| 122 | let done_clone = done.clone(); |
| 123 | thread::spawn(move || { |
| 124 | for stream in listener.incoming() { |
| 125 | if let Err(e) = stream { |
| 126 | eprintln!("testserver: handling just-accepted stream: {}" , e); |
| 127 | break; |
| 128 | } |
| 129 | if done.load(Ordering::SeqCst) { |
| 130 | break; |
| 131 | } else { |
| 132 | thread::spawn(move || handler(stream.unwrap())); |
| 133 | } |
| 134 | } |
| 135 | }); |
| 136 | // before returning from new(), ensure the server is ready to accept connections |
| 137 | while let Err(e) = TcpStream::connect(format!("127.0.0.1: {}" , port)) { |
| 138 | match e.kind() { |
| 139 | io::ErrorKind::ConnectionRefused => { |
| 140 | std::thread::sleep(Duration::from_millis(100)); |
| 141 | continue; |
| 142 | } |
| 143 | _ => eprintln!("testserver: pre-connect with error {}" , e), |
| 144 | } |
| 145 | } |
| 146 | TestServer { |
| 147 | port, |
| 148 | done: done_clone, |
| 149 | } |
| 150 | } |
| 151 | } |
| 152 | |
| 153 | impl Drop for TestServer { |
| 154 | fn drop(&mut self) { |
| 155 | self.done.store(val:true, order:Ordering::SeqCst); |
| 156 | // Connect once to unblock the listen loop. |
| 157 | if let Err(e: Error) = TcpStream::connect(addr:format!("localhost: {}" , self.port)) { |
| 158 | eprintln!("error dropping testserver: {}" , e); |
| 159 | } |
| 160 | } |
| 161 | } |
| 162 | |