| 1 | //! Module defining an Either type. |
| 2 | use std::{ |
| 3 | future::Future, |
| 4 | io::SeekFrom, |
| 5 | pin::Pin, |
| 6 | task::{Context, Poll}, |
| 7 | }; |
| 8 | use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result}; |
| 9 | |
/// Combines two different futures, streams, or sinks having the same associated types into a single type.
///
/// This type implements common asynchronous traits such as [`Future`] and those in Tokio
/// (`AsyncRead`, `AsyncBufRead`, `AsyncSeek`, `AsyncWrite`), as well as `Stream` and `Sink`.
/// Every trait method is forwarded to whichever variant is currently held.
///
/// [`Future`]: std::future::Future
///
/// # Example
///
/// The following code will not work:
///
/// ```compile_fail
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         some_async_function()
///     } else {
///         other_async_function() // <- error: "`if` and `else` have incompatible types"
///     };
///
///     println!("Result is {}", result.await);
/// }
/// ```
///
/// This is because, although the output types of both futures are the same, the exact
/// future types are different, and the compiler must be able to choose a single type for
/// the `result` variable.
///
/// When the output type is the same, we can wrap each future in `Either` to avoid the
/// issue:
///
/// ```
/// use tokio_util::either::Either;
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
///
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         Either::Left(some_async_function())
///     } else {
///         Either::Right(other_async_function())
///     };
///
///     let value = result.await;
///     println!("Result is {}", value);
///     # assert_eq!(value, 10);
/// }
/// ```
#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
#[derive(Debug, Clone)]
pub enum Either<L, R> {
    Left(L),
    Right(R),
}
| 68 | |
/// A small helper macro which reduces the amount of boilerplate in the actual trait method
/// implementations. It takes a method invocation as an argument (e.g. `self.poll(cx)`) and
/// redirects it to whichever enum variant is held in `self`.
///
/// Requirements at the expansion site:
/// - the receiver ident must be a `Pin<&mut Self>` (it is consumed by `get_unchecked_mut`);
/// - at least one argument is required (the `+` repetition), and each must be a plain ident;
/// - it must be expanded inside an `impl` for `Either`, since the patterns name
///   `Self::Left` / `Self::Right`, with `Pin` in scope.
macro_rules! delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {
        // SAFETY: this is a structural pin projection from `Pin<&mut Either<L, R>>` to
        // `Pin<&mut L>` / `Pin<&mut R>`. It is sound because:
        // - the `&mut` obtained from `get_unchecked_mut` is never used to move the value;
        //   the active field is immediately re-pinned and only accessed through that `Pin`;
        // - `Either` does not implement `Drop`, so nothing moves a pinned field out on drop;
        // - `Either` relies on the auto `Unpin` impl, so it is `Unpin` only when both `L` and
        //   `R` are, and safe code cannot otherwise get `&mut Either` out of the pin.
        // NOTE: adding a `Drop` impl, `#[repr(packed)]`, or a manual `Unpin` impl to
        // `Either` would invalidate this reasoning.
        unsafe {
            match $self.get_unchecked_mut() {
                Self::Left(l) => Pin::new_unchecked(l).$method($($args),+),
                Self::Right(r) => Pin::new_unchecked(r).$method($($args),+),
            }
        }
    }
}
| 82 | |
| 83 | impl<L, R, O> Future for Either<L, R> |
| 84 | where |
| 85 | L: Future<Output = O>, |
| 86 | R: Future<Output = O>, |
| 87 | { |
| 88 | type Output = O; |
| 89 | |
| 90 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 91 | delegate_call!(self.poll(cx)) |
| 92 | } |
| 93 | } |
| 94 | |
| 95 | impl<L, R> AsyncRead for Either<L, R> |
| 96 | where |
| 97 | L: AsyncRead, |
| 98 | R: AsyncRead, |
| 99 | { |
| 100 | fn poll_read( |
| 101 | self: Pin<&mut Self>, |
| 102 | cx: &mut Context<'_>, |
| 103 | buf: &mut ReadBuf<'_>, |
| 104 | ) -> Poll<Result<()>> { |
| 105 | delegate_call!(self.poll_read(cx, buf)) |
| 106 | } |
| 107 | } |
| 108 | |
| 109 | impl<L, R> AsyncBufRead for Either<L, R> |
| 110 | where |
| 111 | L: AsyncBufRead, |
| 112 | R: AsyncBufRead, |
| 113 | { |
| 114 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> { |
| 115 | delegate_call!(self.poll_fill_buf(cx)) |
| 116 | } |
| 117 | |
| 118 | fn consume(self: Pin<&mut Self>, amt: usize) { |
| 119 | delegate_call!(self.consume(amt)); |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | impl<L, R> AsyncSeek for Either<L, R> |
| 124 | where |
| 125 | L: AsyncSeek, |
| 126 | R: AsyncSeek, |
| 127 | { |
| 128 | fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> { |
| 129 | delegate_call!(self.start_seek(position)) |
| 130 | } |
| 131 | |
| 132 | fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> { |
| 133 | delegate_call!(self.poll_complete(cx)) |
| 134 | } |
| 135 | } |
| 136 | |
| 137 | impl<L, R> AsyncWrite for Either<L, R> |
| 138 | where |
| 139 | L: AsyncWrite, |
| 140 | R: AsyncWrite, |
| 141 | { |
| 142 | fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { |
| 143 | delegate_call!(self.poll_write(cx, buf)) |
| 144 | } |
| 145 | |
| 146 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { |
| 147 | delegate_call!(self.poll_flush(cx)) |
| 148 | } |
| 149 | |
| 150 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { |
| 151 | delegate_call!(self.poll_shutdown(cx)) |
| 152 | } |
| 153 | |
| 154 | fn poll_write_vectored( |
| 155 | self: Pin<&mut Self>, |
| 156 | cx: &mut Context<'_>, |
| 157 | bufs: &[std::io::IoSlice<'_>], |
| 158 | ) -> Poll<std::result::Result<usize, std::io::Error>> { |
| 159 | delegate_call!(self.poll_write_vectored(cx, bufs)) |
| 160 | } |
| 161 | |
| 162 | fn is_write_vectored(&self) -> bool { |
| 163 | match self { |
| 164 | Self::Left(l) => l.is_write_vectored(), |
| 165 | Self::Right(r) => r.is_write_vectored(), |
| 166 | } |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | impl<L, R> futures_core::stream::Stream for Either<L, R> |
| 171 | where |
| 172 | L: futures_core::stream::Stream, |
| 173 | R: futures_core::stream::Stream<Item = L::Item>, |
| 174 | { |
| 175 | type Item = L::Item; |
| 176 | |
| 177 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 178 | delegate_call!(self.poll_next(cx)) |
| 179 | } |
| 180 | } |
| 181 | |
| 182 | impl<L, R, Item, Error> futures_sink::Sink<Item> for Either<L, R> |
| 183 | where |
| 184 | L: futures_sink::Sink<Item, Error = Error>, |
| 185 | R: futures_sink::Sink<Item, Error = Error>, |
| 186 | { |
| 187 | type Error = Error; |
| 188 | |
| 189 | fn poll_ready( |
| 190 | self: Pin<&mut Self>, |
| 191 | cx: &mut Context<'_>, |
| 192 | ) -> Poll<std::result::Result<(), Self::Error>> { |
| 193 | delegate_call!(self.poll_ready(cx)) |
| 194 | } |
| 195 | |
| 196 | fn start_send(self: Pin<&mut Self>, item: Item) -> std::result::Result<(), Self::Error> { |
| 197 | delegate_call!(self.start_send(item)) |
| 198 | } |
| 199 | |
| 200 | fn poll_flush( |
| 201 | self: Pin<&mut Self>, |
| 202 | cx: &mut Context<'_>, |
| 203 | ) -> Poll<std::result::Result<(), Self::Error>> { |
| 204 | delegate_call!(self.poll_flush(cx)) |
| 205 | } |
| 206 | |
| 207 | fn poll_close( |
| 208 | self: Pin<&mut Self>, |
| 209 | cx: &mut Context<'_>, |
| 210 | ) -> Poll<std::result::Result<(), Self::Error>> { |
| 211 | delegate_call!(self.poll_close(cx)) |
| 212 | } |
| 213 | } |
| 214 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::{ready, Ready};
    use tokio::io::{repeat, AsyncReadExt, AsyncWriteExt, Repeat};
    use tokio_stream::{once, Once, StreamExt};

    #[tokio::test]
    async fn either_is_stream() {
        let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1));

        assert_eq!(Some(1u32), either.next().await);
    }

    #[tokio::test]
    async fn either_is_async_read() {
        let mut buffer = [0; 3];
        let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101));

        either.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, [0b101, 0b101, 0b101]);
    }

    // Both variants must resolve through the `Future` impl.
    #[tokio::test]
    async fn either_is_future() {
        let left: Either<Ready<u32>, Ready<u32>> = Either::Left(ready(1));
        let right: Either<Ready<u32>, Ready<u32>> = Either::Right(ready(2));

        assert_eq!(left.await, 1);
        assert_eq!(right.await, 2);
    }

    // Bytes written through `Either` must land in the held writer.
    #[tokio::test]
    async fn either_is_async_write() {
        let mut either: Either<Vec<u8>, Vec<u8>> = Either::Left(Vec::new());

        either.write_all(b"abc").await.unwrap();
        either.flush().await.unwrap();

        match either {
            Either::Left(written) => assert_eq!(written, b"abc"),
            Either::Right(_) => panic!("variant changed while writing"),
        }
    }
}
| 237 | |