1use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
2
3use pin_project_lite::pin_project;
4use std::fmt;
5use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
pin_project! {
    /// Reader for the [`chain`](super::AsyncReadExt::chain) method.
    ///
    /// Produces all bytes from `first` until it reports end-of-file (a read
    /// that produces no bytes), and afterwards produces bytes from `second`.
    #[must_use = "streams do nothing unless polled"]
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct Chain<T, U> {
        // Reader that is drained first.
        #[pin]
        first: T,
        // Reader that is drained once `first` is exhausted.
        #[pin]
        second: U,
        // Set once `first` has signalled end-of-file; it is only ever set to
        // `true`, never reset.
        done_first: bool,
    }
}
21
22pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
23where
24 T: AsyncRead,
25 U: AsyncRead,
26{
27 Chain {
28 first,
29 second,
30 done_first: false,
31 }
32}
33
34impl<T, U> Chain<T, U>
35where
36 T: AsyncRead,
37 U: AsyncRead,
38{
39 /// Gets references to the underlying readers in this `Chain`.
40 pub fn get_ref(&self) -> (&T, &U) {
41 (&self.first, &self.second)
42 }
43
44 /// Gets mutable references to the underlying readers in this `Chain`.
45 ///
46 /// Care should be taken to avoid modifying the internal I/O state of the
47 /// underlying readers as doing so may corrupt the internal state of this
48 /// `Chain`.
49 pub fn get_mut(&mut self) -> (&mut T, &mut U) {
50 (&mut self.first, &mut self.second)
51 }
52
53 /// Gets pinned mutable references to the underlying readers in this `Chain`.
54 ///
55 /// Care should be taken to avoid modifying the internal I/O state of the
56 /// underlying readers as doing so may corrupt the internal state of this
57 /// `Chain`.
58 pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
59 let me = self.project();
60 (me.first, me.second)
61 }
62
63 /// Consumes the `Chain`, returning the wrapped readers.
64 pub fn into_inner(self) -> (T, U) {
65 (self.first, self.second)
66 }
67}
68
69impl<T, U> fmt::Debug for Chain<T, U>
70where
71 T: fmt::Debug,
72 U: fmt::Debug,
73{
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f&mut DebugStruct<'_, '_>.debug_struct("Chain")
76 .field("t", &self.first)
77 .field(name:"u", &self.second)
78 .finish()
79 }
80}
81
82impl<T, U> AsyncRead for Chain<T, U>
83where
84 T: AsyncRead,
85 U: AsyncRead,
86{
87 fn poll_read(
88 self: Pin<&mut Self>,
89 cx: &mut Context<'_>,
90 buf: &mut ReadBuf<'_>,
91 ) -> Poll<io::Result<()>> {
92 let me: Projection<'_, T, U> = self.project();
93
94 if !*me.done_first {
95 let rem: usize = buf.remaining();
96 ready!(me.first.poll_read(cx, buf))?;
97 if buf.remaining() == rem {
98 *me.done_first = true;
99 } else {
100 return Poll::Ready(Ok(()));
101 }
102 }
103 me.second.poll_read(cx, buf)
104 }
105}
106
107impl<T, U> AsyncBufRead for Chain<T, U>
108where
109 T: AsyncBufRead,
110 U: AsyncBufRead,
111{
112 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
113 let me: Projection<'_, T, U> = self.project();
114
115 if !*me.done_first {
116 match ready!(me.first.poll_fill_buf(cx)?) {
117 [] => {
118 *me.done_first = true;
119 }
120 buf: &[u8] => return Poll::Ready(Ok(buf)),
121 }
122 }
123 me.second.poll_fill_buf(cx)
124 }
125
126 fn consume(self: Pin<&mut Self>, amt: usize) {
127 let me: Projection<'_, T, U> = self.project();
128 if !*me.done_first {
129 me.first.consume(amt)
130 } else {
131 me.second.consume(amt)
132 }
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::Chain;

    // A `Chain` built from two `Unpin` halves must itself be `Unpin`.
    #[test]
    fn assert_unpin() {
        type UnitChain = Chain<(), ()>;
        crate::is_unpin::<UnitChain>();
    }
}
145