1#![cfg(not(loom))]
2
3//! A mock type implementing [`AsyncRead`] and [`AsyncWrite`].
4//!
5//!
6//! # Overview
7//!
//! Provides a type that implements [`AsyncRead`] + [`AsyncWrite`] that can be configured
//! to handle an arbitrary sequence of read and write operations. This is useful
//! for writing unit tests for networking services, since using an actual network
//! type is fairly non-deterministic.
12//!
13//! # Usage
14//!
15//! Attempting to write data that the mock isn't expecting will result in a
16//! panic.
17//!
18//! [`AsyncRead`]: tokio::io::AsyncRead
19//! [`AsyncWrite`]: tokio::io::AsyncWrite
20
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tokio::sync::mpsc;
23use tokio::time::{self, Duration, Instant, Sleep};
24use tokio_stream::wrappers::UnboundedReceiverStream;
25
26use futures_core::{ready, Stream};
27use std::collections::VecDeque;
28use std::fmt;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{self, Poll, Waker};
33use std::{cmp, io};
34
35/// An I/O object that follows a predefined script.
36///
37/// This value is created by `Builder` and implements `AsyncRead` + `AsyncWrite`. It
38/// follows the scenario described by the builder and panics otherwise.
39#[derive(Debug)]
40pub struct Mock {
41 inner: Inner,
42}
43
44/// A handle to send additional actions to the related `Mock`.
45#[derive(Debug)]
46pub struct Handle {
47 tx: mpsc::UnboundedSender<Action>,
48}
49
50/// Builds `Mock` instances.
51#[derive(Debug, Clone, Default)]
52pub struct Builder {
53 // Sequence of actions for the Mock to take
54 actions: VecDeque<Action>,
55}
56
/// One step of a mock's script.
#[derive(Debug, Clone)]
enum Action {
    /// Bytes to hand out to a reader.
    Read(Vec<u8>),
    /// Bytes a writer is expected to produce.
    Write(Vec<u8>),
    /// A pause during which neither reads nor writes make progress.
    Wait(Duration),
    // `io::Error` is not `Clone`, so scripted errors are wrapped in `Arc` to keep
    // `Action` (and therefore `Builder`) cloneable and `Send`. The `Option` is
    // taken when the error is delivered; `None` marks an already-consumed step.
    ReadError(Option<Arc<io::Error>>),
    WriteError(Option<Arc<io::Error>>),
}
67
68struct Inner {
69 actions: VecDeque<Action>,
70 waiting: Option<Instant>,
71 sleep: Option<Pin<Box<Sleep>>>,
72 read_wait: Option<Waker>,
73 rx: UnboundedReceiverStream<Action>,
74}
75
76impl Builder {
77 /// Return a new, empty `Builder`.
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 /// Sequence a `read` operation.
83 ///
84 /// The next operation in the mock's script will be to expect a `read` call
85 /// and return `buf`.
86 pub fn read(&mut self, buf: &[u8]) -> &mut Self {
87 self.actions.push_back(Action::Read(buf.into()));
88 self
89 }
90
91 /// Sequence a `read` operation that produces an error.
92 ///
93 /// The next operation in the mock's script will be to expect a `read` call
94 /// and return `error`.
95 pub fn read_error(&mut self, error: io::Error) -> &mut Self {
96 let error = Some(error.into());
97 self.actions.push_back(Action::ReadError(error));
98 self
99 }
100
101 /// Sequence a `write` operation.
102 ///
103 /// The next operation in the mock's script will be to expect a `write`
104 /// call.
105 pub fn write(&mut self, buf: &[u8]) -> &mut Self {
106 self.actions.push_back(Action::Write(buf.into()));
107 self
108 }
109
110 /// Sequence a `write` operation that produces an error.
111 ///
112 /// The next operation in the mock's script will be to expect a `write`
113 /// call that provides `error`.
114 pub fn write_error(&mut self, error: io::Error) -> &mut Self {
115 let error = Some(error.into());
116 self.actions.push_back(Action::WriteError(error));
117 self
118 }
119
120 /// Sequence a wait.
121 ///
122 /// The next operation in the mock's script will be to wait without doing so
123 /// for `duration` amount of time.
124 pub fn wait(&mut self, duration: Duration) -> &mut Self {
125 let duration = cmp::max(duration, Duration::from_millis(1));
126 self.actions.push_back(Action::Wait(duration));
127 self
128 }
129
130 /// Build a `Mock` value according to the defined script.
131 pub fn build(&mut self) -> Mock {
132 let (mock, _) = self.build_with_handle();
133 mock
134 }
135
136 /// Build a `Mock` value paired with a handle
137 pub fn build_with_handle(&mut self) -> (Mock, Handle) {
138 let (inner, handle) = Inner::new(self.actions.clone());
139
140 let mock = Mock { inner };
141
142 (mock, handle)
143 }
144}
145
146impl Handle {
147 /// Sequence a `read` operation.
148 ///
149 /// The next operation in the mock's script will be to expect a `read` call
150 /// and return `buf`.
151 pub fn read(&mut self, buf: &[u8]) -> &mut Self {
152 self.tx.send(Action::Read(buf.into())).unwrap();
153 self
154 }
155
156 /// Sequence a `read` operation error.
157 ///
158 /// The next operation in the mock's script will be to expect a `read` call
159 /// and return `error`.
160 pub fn read_error(&mut self, error: io::Error) -> &mut Self {
161 let error = Some(error.into());
162 self.tx.send(Action::ReadError(error)).unwrap();
163 self
164 }
165
166 /// Sequence a `write` operation.
167 ///
168 /// The next operation in the mock's script will be to expect a `write`
169 /// call.
170 pub fn write(&mut self, buf: &[u8]) -> &mut Self {
171 self.tx.send(Action::Write(buf.into())).unwrap();
172 self
173 }
174
175 /// Sequence a `write` operation error.
176 ///
177 /// The next operation in the mock's script will be to expect a `write`
178 /// call error.
179 pub fn write_error(&mut self, error: io::Error) -> &mut Self {
180 let error = Some(error.into());
181 self.tx.send(Action::WriteError(error)).unwrap();
182 self
183 }
184}
185
186impl Inner {
187 fn new(actions: VecDeque<Action>) -> (Inner, Handle) {
188 let (tx, rx) = mpsc::unbounded_channel();
189
190 let rx = UnboundedReceiverStream::new(rx);
191
192 let inner = Inner {
193 actions,
194 sleep: None,
195 read_wait: None,
196 rx,
197 waiting: None,
198 };
199
200 let handle = Handle { tx };
201
202 (inner, handle)
203 }
204
205 fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
206 Pin::new(&mut self.rx).poll_next(cx)
207 }
208
209 fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
210 match self.action() {
211 Some(&mut Action::Read(ref mut data)) => {
212 // Figure out how much to copy
213 let n = cmp::min(dst.remaining(), data.len());
214
215 // Copy the data into the `dst` slice
216 dst.put_slice(&data[..n]);
217
218 // Drain the data from the source
219 data.drain(..n);
220
221 Ok(())
222 }
223 Some(&mut Action::ReadError(ref mut err)) => {
224 // As the
225 let err = err.take().expect("Should have been removed from actions.");
226 let err = Arc::try_unwrap(err).expect("There are no other references.");
227 Err(err)
228 }
229 Some(_) => {
230 // Either waiting or expecting a write
231 Err(io::ErrorKind::WouldBlock.into())
232 }
233 None => Ok(()),
234 }
235 }
236
237 fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
238 let mut ret = 0;
239
240 if self.actions.is_empty() {
241 return Err(io::ErrorKind::BrokenPipe.into());
242 }
243
244 if let Some(&mut Action::Wait(..)) = self.action() {
245 return Err(io::ErrorKind::WouldBlock.into());
246 }
247
248 if let Some(&mut Action::WriteError(ref mut err)) = self.action() {
249 let err = err.take().expect("Should have been removed from actions.");
250 let err = Arc::try_unwrap(err).expect("There are no other references.");
251 return Err(err);
252 }
253
254 for i in 0..self.actions.len() {
255 match self.actions[i] {
256 Action::Write(ref mut expect) => {
257 let n = cmp::min(src.len(), expect.len());
258
259 assert_eq!(&src[..n], &expect[..n]);
260
261 // Drop data that was matched
262 expect.drain(..n);
263 src = &src[n..];
264
265 ret += n;
266
267 if src.is_empty() {
268 return Ok(ret);
269 }
270 }
271 Action::Wait(..) | Action::WriteError(..) => {
272 break;
273 }
274 _ => {}
275 }
276
277 // TODO: remove write
278 }
279
280 Ok(ret)
281 }
282
283 fn remaining_wait(&mut self) -> Option<Duration> {
284 match self.action() {
285 Some(&mut Action::Wait(dur)) => Some(dur),
286 _ => None,
287 }
288 }
289
290 fn action(&mut self) -> Option<&mut Action> {
291 loop {
292 if self.actions.is_empty() {
293 return None;
294 }
295
296 match self.actions[0] {
297 Action::Read(ref mut data) => {
298 if !data.is_empty() {
299 break;
300 }
301 }
302 Action::Write(ref mut data) => {
303 if !data.is_empty() {
304 break;
305 }
306 }
307 Action::Wait(ref mut dur) => {
308 if let Some(until) = self.waiting {
309 let now = Instant::now();
310
311 if now < until {
312 break;
313 } else {
314 self.waiting = None;
315 }
316 } else {
317 self.waiting = Some(Instant::now() + *dur);
318 break;
319 }
320 }
321 Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
322 if error.is_some() {
323 break;
324 }
325 }
326 }
327
328 let _action = self.actions.pop_front();
329 }
330
331 self.actions.front_mut()
332 }
333}
334
// ===== impl Mock =====
336
337impl Mock {
338 fn maybe_wakeup_reader(&mut self) {
339 match self.inner.action() {
340 Some(&mut Action::Read(_)) | Some(&mut Action::ReadError(_)) | None => {
341 if let Some(waker) = self.inner.read_wait.take() {
342 waker.wake();
343 }
344 }
345 _ => {}
346 }
347 }
348}
349
350impl AsyncRead for Mock {
351 fn poll_read(
352 mut self: Pin<&mut Self>,
353 cx: &mut task::Context<'_>,
354 buf: &mut ReadBuf<'_>,
355 ) -> Poll<io::Result<()>> {
356 loop {
357 if let Some(ref mut sleep) = self.inner.sleep {
358 ready!(Pin::new(sleep).poll(cx));
359 }
360
361 // If a sleep is set, it has already fired
362 self.inner.sleep = None;
363
364 // Capture 'filled' to monitor if it changed
365 let filled = buf.filled().len();
366
367 match self.inner.read(buf) {
368 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
369 if let Some(rem) = self.inner.remaining_wait() {
370 let until = Instant::now() + rem;
371 self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
372 } else {
373 self.inner.read_wait = Some(cx.waker().clone());
374 return Poll::Pending;
375 }
376 }
377 Ok(()) => {
378 if buf.filled().len() == filled {
379 match ready!(self.inner.poll_action(cx)) {
380 Some(action) => {
381 self.inner.actions.push_back(action);
382 continue;
383 }
384 None => {
385 return Poll::Ready(Ok(()));
386 }
387 }
388 } else {
389 return Poll::Ready(Ok(()));
390 }
391 }
392 Err(e) => return Poll::Ready(Err(e)),
393 }
394 }
395 }
396}
397
398impl AsyncWrite for Mock {
399 fn poll_write(
400 mut self: Pin<&mut Self>,
401 cx: &mut task::Context<'_>,
402 buf: &[u8],
403 ) -> Poll<io::Result<usize>> {
404 loop {
405 if let Some(ref mut sleep) = self.inner.sleep {
406 ready!(Pin::new(sleep).poll(cx));
407 }
408
409 // If a sleep is set, it has already fired
410 self.inner.sleep = None;
411
412 if self.inner.actions.is_empty() {
413 match self.inner.poll_action(cx) {
414 Poll::Pending => {
415 // do not propagate pending
416 }
417 Poll::Ready(Some(action)) => {
418 self.inner.actions.push_back(action);
419 }
420 Poll::Ready(None) => {
421 panic!("unexpected write");
422 }
423 }
424 }
425
426 match self.inner.write(buf) {
427 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
428 if let Some(rem) = self.inner.remaining_wait() {
429 let until = Instant::now() + rem;
430 self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
431 } else {
432 panic!("unexpected WouldBlock");
433 }
434 }
435 Ok(0) => {
436 // TODO: Is this correct?
437 if !self.inner.actions.is_empty() {
438 return Poll::Pending;
439 }
440
441 // TODO: Extract
442 match ready!(self.inner.poll_action(cx)) {
443 Some(action) => {
444 self.inner.actions.push_back(action);
445 continue;
446 }
447 None => {
448 panic!("unexpected write");
449 }
450 }
451 }
452 ret => {
453 self.maybe_wakeup_reader();
454 return Poll::Ready(ret);
455 }
456 }
457 }
458 }
459
460 fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
461 Poll::Ready(Ok(()))
462 }
463
464 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
465 Poll::Ready(Ok(()))
466 }
467}
468
469/// Ensures that Mock isn't dropped with data "inside".
470impl Drop for Mock {
471 fn drop(&mut self) {
472 // Avoid double panicking, since makes debugging much harder.
473 if std::thread::panicking() {
474 return;
475 }
476
477 self.inner.actions.iter().for_each(|a| match a {
478 Action::Read(data) => assert!(data.is_empty(), "There is still data left to read."),
479 Action::Write(data) => assert!(data.is_empty(), "There is still data left to write."),
480 _ => (),
481 });
482 }
483}
484/*
485/// Returns `true` if called from the context of a futures-rs Task
486fn is_task_ctx() -> bool {
487 use std::panic;
488
489 // Save the existing panic hook
490 let h = panic::take_hook();
491
492 // Install a new one that does nothing
493 panic::set_hook(Box::new(|_| {}));
494
495 // Attempt to call the fn
496 let r = panic::catch_unwind(|| task::current()).is_ok();
497
498 // Re-install the old one
499 panic::set_hook(h);
500
501 // Return the result
502 r
503}
504*/
505
506impl fmt::Debug for Inner {
507 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
508 write!(f, "Inner {{...}}")
509 }
510}
511