1 | use futures_io::AsyncWrite; |
2 | use futures_sink::Sink; |
3 | use std::{ |
4 | io::{self, IoSlice}, |
5 | pin::Pin, |
6 | task::{Context, Poll}, |
7 | }; |
8 | |
9 | /// Async wrapper that tracks whether it has been closed. |
10 | /// |
11 | /// See the `track_closed` methods on: |
12 | /// * [`SinkTestExt`](crate::sink::SinkTestExt::track_closed) |
13 | /// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::track_closed) |
14 | #[pin_project::pin_project ] |
15 | #[derive(Debug)] |
16 | pub struct TrackClosed<T> { |
17 | #[pin] |
18 | inner: T, |
19 | closed: bool, |
20 | } |
21 | |
22 | impl<T> TrackClosed<T> { |
23 | pub(crate) fn new(inner: T) -> Self { |
24 | Self { inner, closed: false } |
25 | } |
26 | |
27 | /// Check whether this object has been closed. |
28 | pub fn is_closed(&self) -> bool { |
29 | self.closed |
30 | } |
31 | |
32 | /// Acquires a reference to the underlying object that this adaptor is |
33 | /// wrapping. |
34 | pub fn get_ref(&self) -> &T { |
35 | &self.inner |
36 | } |
37 | |
38 | /// Acquires a mutable reference to the underlying object that this |
39 | /// adaptor is wrapping. |
40 | pub fn get_mut(&mut self) -> &mut T { |
41 | &mut self.inner |
42 | } |
43 | |
44 | /// Acquires a pinned mutable reference to the underlying object that |
45 | /// this adaptor is wrapping. |
46 | pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { |
47 | self.project().inner |
48 | } |
49 | |
50 | /// Consumes this adaptor returning the underlying object. |
51 | pub fn into_inner(self) -> T { |
52 | self.inner |
53 | } |
54 | } |
55 | |
56 | impl<T: AsyncWrite> AsyncWrite for TrackClosed<T> { |
57 | fn poll_write( |
58 | self: Pin<&mut Self>, |
59 | cx: &mut Context<'_>, |
60 | buf: &[u8], |
61 | ) -> Poll<io::Result<usize>> { |
62 | if self.is_closed() { |
63 | return Poll::Ready(Err(io::Error::new( |
64 | io::ErrorKind::Other, |
65 | "Attempted to write after stream was closed" , |
66 | ))); |
67 | } |
68 | self.project().inner.poll_write(cx, buf) |
69 | } |
70 | |
71 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
72 | if self.is_closed() { |
73 | return Poll::Ready(Err(io::Error::new( |
74 | io::ErrorKind::Other, |
75 | "Attempted to flush after stream was closed" , |
76 | ))); |
77 | } |
78 | assert!(!self.is_closed()); |
79 | self.project().inner.poll_flush(cx) |
80 | } |
81 | |
82 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
83 | if self.is_closed() { |
84 | return Poll::Ready(Err(io::Error::new( |
85 | io::ErrorKind::Other, |
86 | "Attempted to close after stream was closed" , |
87 | ))); |
88 | } |
89 | let this = self.project(); |
90 | match this.inner.poll_close(cx) { |
91 | Poll::Ready(Ok(())) => { |
92 | *this.closed = true; |
93 | Poll::Ready(Ok(())) |
94 | } |
95 | other => other, |
96 | } |
97 | } |
98 | |
99 | fn poll_write_vectored( |
100 | self: Pin<&mut Self>, |
101 | cx: &mut Context<'_>, |
102 | bufs: &[IoSlice<'_>], |
103 | ) -> Poll<io::Result<usize>> { |
104 | if self.is_closed() { |
105 | return Poll::Ready(Err(io::Error::new( |
106 | io::ErrorKind::Other, |
107 | "Attempted to write after stream was closed" , |
108 | ))); |
109 | } |
110 | self.project().inner.poll_write_vectored(cx, bufs) |
111 | } |
112 | } |
113 | |
114 | impl<Item, T: Sink<Item>> Sink<Item> for TrackClosed<T> { |
115 | type Error = T::Error; |
116 | |
117 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
118 | assert!(!self.is_closed()); |
119 | self.project().inner.poll_ready(cx) |
120 | } |
121 | |
122 | fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { |
123 | assert!(!self.is_closed()); |
124 | self.project().inner.start_send(item) |
125 | } |
126 | |
127 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
128 | assert!(!self.is_closed()); |
129 | self.project().inner.poll_flush(cx) |
130 | } |
131 | |
132 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
133 | assert!(!self.is_closed()); |
134 | let this = self.project(); |
135 | match this.inner.poll_close(cx) { |
136 | Poll::Ready(Ok(())) => { |
137 | *this.closed = true; |
138 | Poll::Ready(Ok(())) |
139 | } |
140 | other => other, |
141 | } |
142 | } |
143 | } |
144 | |