| 1 | //! Contains utilities for stdout and stderr. |
| 2 | use crate::io::AsyncWrite; |
| 3 | use std::pin::Pin; |
| 4 | use std::task::{Context, Poll}; |
/// [`AsyncWrite`] adapter that keeps oversized text writes aligned to UTF-8
/// character boundaries on Windows.
///
/// # Windows
///
/// A buffer longer than `DEFAULT_MAX_BUF_SIZE` is cut down to that size before
/// it is handed to the wrapped writer. If the start of the buffer looks like
/// UTF-8, the cut is additionally moved back to a character boundary and the
/// remaining bytes are left unwritten, so a text stream reaches the wrapped
/// writer only as well-formed UTF-8 pieces. Data that does not look like text
/// is merely trimmed to `DEFAULT_MAX_BUF_SIZE`.
///
/// # Other platforms
///
/// Every buffer is forwarded to `inner` untouched.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    /// The wrapped writer that receives the (possibly shortened) buffers.
    inner: W,
}
| 15 | |
| 16 | impl<W> SplitByUtf8BoundaryIfWindows<W> { |
| 17 | pub(crate) fn new(inner: W) -> Self { |
| 18 | Self { inner } |
| 19 | } |
| 20 | } |
| 21 | |
/// Largest number of bytes a single character can occupy in UTF-8.
/// The value is fixed by the Unicode standard and is not a tunable.
const MAX_BYTES_PER_CHAR: usize = 4;
| 24 | |
/// Tunable heuristic knob: the first `MAX_BYTES_PER_CHAR * MAGIC_CONST` bytes
/// of an oversized buffer are inspected to guess whether it holds UTF-8 text.
const MAGIC_CONST: usize = 8;
| 27 | |
| 28 | impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W> |
| 29 | where |
| 30 | W: AsyncWrite + Unpin, |
| 31 | { |
| 32 | fn poll_write( |
| 33 | mut self: Pin<&mut Self>, |
| 34 | cx: &mut Context<'_>, |
| 35 | mut buf: &[u8], |
| 36 | ) -> Poll<Result<usize, std::io::Error>> { |
| 37 | // just a closure to avoid repetitive code |
| 38 | let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf); |
| 39 | |
| 40 | // 1. Only windows stdio can suffer from non-utf8. |
| 41 | // We also check for `test` so that we can write some tests |
| 42 | // for further code. Since `AsyncWrite` can always shrink |
| 43 | // buffer at its discretion, excessive (i.e. in tests) shrinking |
| 44 | // does not break correctness. |
| 45 | // 2. If buffer is small, it will not be shrunk. |
| 46 | // That's why, it's "textness" will not change, so we don't have |
| 47 | // to fixup it. |
| 48 | if cfg!(not(any(target_os = "windows" , test))) |
| 49 | || buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE |
| 50 | { |
| 51 | return call_inner(buf); |
| 52 | } |
| 53 | |
| 54 | buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE]; |
| 55 | |
| 56 | // Now there are two possibilities. |
| 57 | // If caller gave is binary buffer, we **should not** shrink it |
| 58 | // anymore, because excessive shrinking hits performance. |
| 59 | // If caller gave as binary buffer, we **must** additionally |
| 60 | // shrink it to strip incomplete char at the end of buffer. |
| 61 | // that's why check we will perform now is allowed to have |
| 62 | // false-positive. |
| 63 | |
| 64 | // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes. |
| 65 | // if they are (possibly incomplete) utf8, then we can be quite sure |
| 66 | // that input buffer was utf8. |
| 67 | |
| 68 | let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) { |
| 69 | Ok(_) => true, |
| 70 | Err(err) => { |
| 71 | let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to(); |
| 72 | incomplete_bytes < MAX_BYTES_PER_CHAR |
| 73 | } |
| 74 | }; |
| 75 | |
| 76 | if have_to_fix_up { |
| 77 | // We must pop several bytes at the end which form incomplete |
| 78 | // character. To achieve it, we exploit UTF8 encoding: |
| 79 | // for any code point, all bytes except first start with 0b10 prefix. |
| 80 | // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details |
| 81 | let trailing_incomplete_char_size = buf |
| 82 | .iter() |
| 83 | .rev() |
| 84 | .take(MAX_BYTES_PER_CHAR) |
| 85 | .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000) |
| 86 | .unwrap_or(0) |
| 87 | + 1; |
| 88 | buf = &buf[..buf.len() - trailing_incomplete_char_size]; |
| 89 | } |
| 90 | |
| 91 | call_inner(buf) |
| 92 | } |
| 93 | |
| 94 | fn poll_flush( |
| 95 | mut self: Pin<&mut Self>, |
| 96 | cx: &mut Context<'_>, |
| 97 | ) -> Poll<Result<(), std::io::Error>> { |
| 98 | Pin::new(&mut self.inner).poll_flush(cx) |
| 99 | } |
| 100 | |
| 101 | fn poll_shutdown( |
| 102 | mut self: Pin<&mut Self>, |
| 103 | cx: &mut Context<'_>, |
| 104 | ) -> Poll<Result<(), std::io::Error>> { |
| 105 | Pin::new(&mut self.inner).poll_shutdown(cx) |
| 106 | } |
| 107 | } |
| 108 | |
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Runs `fut` to completion on a freshly built current-thread runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Sink that panics unless every chunk it receives is valid UTF-8 and no
    /// longer than `DEFAULT_MAX_BUF_SIZE`.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Sink that accepts everything and records the size of each chunk, while
    /// still asserting that no chunk exceeds `DEFAULT_MAX_BUF_SIZE`.
    struct LoggingMockWriter {
        write_history: Vec<usize>,
    }

    impl LoggingMockWriter {
        fn new() -> Self {
            Self {
                write_history: Vec::new(),
            }
        }
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let written = buf.len();
            assert!(written <= DEFAULT_MAX_BUF_SIZE);
            self.write_history.push(written);
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// A long run of multi-byte characters must reach the inner writer only as
    /// valid UTF-8 chunks; `TextMockWriter` panics otherwise.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        let text = "█".repeat(DEFAULT_MAX_BUF_SIZE);
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        block_on(async move {
            splitter.write_all(text.as_bytes()).await.unwrap();
        });
    }

    /// Binary data whose beginning happens to be text: even in this corner
    /// case the buffer must not be shrunk by more than a few bytes.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        let probe_len = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let padding = DEFAULT_MAX_BUF_SIZE - probe_len + 1;

        // ASCII prefix covering the inspected bytes, then continuation-like
        // bytes so that the total is one byte over the size limit.
        let mut data = vec![b'a'; probe_len];
        data.resize(probe_len + padding, 0b1010_1010);

        let mut sink = LoggingMockWriter::new();
        {
            let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut sink);
            block_on(async {
                splitter.write_all(&data).await.unwrap();
            });
        }
        let history = &sink.write_history;

        // At most two writes were performed.
        assert!(history.len() <= 2);
        // Everything has been written.
        assert_eq!(history.iter().sum::<usize>(), data.len());
        // The first write lost at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes:
        // one because it lay beyond the DEFAULT_MAX_BUF_SIZE boundary, plus up
        // to one "utf8 code point".
        assert!(data.len() - history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}
| 223 | |