1use futures_io::AsyncWrite;
2use futures_sink::Sink;
3use std::{
4 io::{self, IoSlice},
5 pin::Pin,
6 task::{Context, Poll},
7};
8
/// Async wrapper that tracks whether it has been closed.
///
/// Every call is forwarded to the wrapped object. The wrapper only records
/// itself as closed once the inner `poll_close` completes with `Ok(())`; a
/// pending or failed close leaves it open. Any use after that point is treated
/// as misuse:
/// * through [`AsyncWrite`], writing, flushing and closing again return an
///   [`io::ErrorKind::Other`] error;
/// * through [`Sink`], every method panics.
///
/// See the `track_closed` methods on:
/// * [`SinkTestExt`](crate::sink::SinkTestExt::track_closed)
/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::track_closed)
#[pin_project::pin_project]
#[derive(Debug)]
pub struct TrackClosed<T> {
    /// The wrapped writer or sink. It is structurally pinned so that it can be
    /// polled through `Pin<&mut Self>`.
    #[pin]
    inner: T,
    /// Becomes `true` once the inner `poll_close` returned `Poll::Ready(Ok(()))`.
    closed: bool,
}
21
22impl<T> TrackClosed<T> {
23 pub(crate) fn new(inner: T) -> Self {
24 Self { inner, closed: false }
25 }
26
27 /// Check whether this object has been closed.
28 pub fn is_closed(&self) -> bool {
29 self.closed
30 }
31
32 /// Acquires a reference to the underlying object that this adaptor is
33 /// wrapping.
34 pub fn get_ref(&self) -> &T {
35 &self.inner
36 }
37
38 /// Acquires a mutable reference to the underlying object that this
39 /// adaptor is wrapping.
40 pub fn get_mut(&mut self) -> &mut T {
41 &mut self.inner
42 }
43
44 /// Acquires a pinned mutable reference to the underlying object that
45 /// this adaptor is wrapping.
46 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
47 self.project().inner
48 }
49
50 /// Consumes this adaptor returning the underlying object.
51 pub fn into_inner(self) -> T {
52 self.inner
53 }
54}
55
56impl<T: AsyncWrite> AsyncWrite for TrackClosed<T> {
57 fn poll_write(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &[u8],
61 ) -> Poll<io::Result<usize>> {
62 if self.is_closed() {
63 return Poll::Ready(Err(io::Error::new(
64 io::ErrorKind::Other,
65 "Attempted to write after stream was closed",
66 )));
67 }
68 self.project().inner.poll_write(cx, buf)
69 }
70
71 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72 if self.is_closed() {
73 return Poll::Ready(Err(io::Error::new(
74 io::ErrorKind::Other,
75 "Attempted to flush after stream was closed",
76 )));
77 }
78 assert!(!self.is_closed());
79 self.project().inner.poll_flush(cx)
80 }
81
82 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83 if self.is_closed() {
84 return Poll::Ready(Err(io::Error::new(
85 io::ErrorKind::Other,
86 "Attempted to close after stream was closed",
87 )));
88 }
89 let this = self.project();
90 match this.inner.poll_close(cx) {
91 Poll::Ready(Ok(())) => {
92 *this.closed = true;
93 Poll::Ready(Ok(()))
94 }
95 other => other,
96 }
97 }
98
99 fn poll_write_vectored(
100 self: Pin<&mut Self>,
101 cx: &mut Context<'_>,
102 bufs: &[IoSlice<'_>],
103 ) -> Poll<io::Result<usize>> {
104 if self.is_closed() {
105 return Poll::Ready(Err(io::Error::new(
106 io::ErrorKind::Other,
107 "Attempted to write after stream was closed",
108 )));
109 }
110 self.project().inner.poll_write_vectored(cx, bufs)
111 }
112}
113
114impl<Item, T: Sink<Item>> Sink<Item> for TrackClosed<T> {
115 type Error = T::Error;
116
117 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118 assert!(!self.is_closed());
119 self.project().inner.poll_ready(cx)
120 }
121
122 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
123 assert!(!self.is_closed());
124 self.project().inner.start_send(item)
125 }
126
127 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128 assert!(!self.is_closed());
129 self.project().inner.poll_flush(cx)
130 }
131
132 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133 assert!(!self.is_closed());
134 let this = self.project();
135 match this.inner.poll_close(cx) {
136 Poll::Ready(Ok(())) => {
137 *this.closed = true;
138 Poll::Ready(Ok(()))
139 }
140 other => other,
141 }
142 }
143}
144