1 | //! Contains utilities for stdout and stderr. |
2 | use crate::io::AsyncWrite; |
3 | use std::pin::Pin; |
4 | use std::task::{Context, Poll}; |
/// [`AsyncWrite`] adapter that keeps large stdio writes from splitting UTF-8 characters.
///
/// # Windows
///
/// A buffer longer than `DEFAULT_MAX_BUF_SIZE` is truncated to that size. If the
/// truncated buffer looks like UTF-8 text (judged from its first few bytes), it is
/// additionally shortened so that it ends on a character boundary. The bytes that are
/// not written must be submitted again by the caller, as with any partial write.
/// Buffers that do not look like text are only truncated to `DEFAULT_MAX_BUF_SIZE`.
///
/// Consequently, as long as the caller writes well-formed UTF-8, the wrapped writer
/// only ever receives well-formed UTF-8. No guarantee is made for other input.
///
/// The same logic is also enabled under `cfg(test)` so it can be tested on any platform.
///
/// # Other platforms
///
/// Data is passed to `inner` unchanged.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    inner: W,
}
15 | |
16 | impl<W> SplitByUtf8BoundaryIfWindows<W> { |
17 | pub(crate) fn new(inner: W) -> Self { |
18 | Self { inner } |
19 | } |
20 | } |
21 | |
/// Longest possible UTF-8 encoding of a single code point, in bytes
/// (fixed by the UTF-8 encoding form, see RFC 3629).
const MAX_BYTES_PER_CHAR: usize = 4;

/// How many maximal-width characters' worth of leading bytes are sampled when guessing
/// whether a buffer is text. This is a tuning knob, not a correctness requirement.
const MAGIC_CONST: usize = 8;
27 | |
28 | impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W> |
29 | where |
30 | W: AsyncWrite + Unpin, |
31 | { |
32 | fn poll_write( |
33 | mut self: Pin<&mut Self>, |
34 | cx: &mut Context<'_>, |
35 | mut buf: &[u8], |
36 | ) -> Poll<Result<usize, std::io::Error>> { |
37 | // just a closure to avoid repetitive code |
38 | let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf); |
39 | |
40 | // 1. Only windows stdio can suffer from non-utf8. |
41 | // We also check for `test` so that we can write some tests |
42 | // for further code. Since `AsyncWrite` can always shrink |
43 | // buffer at its discretion, excessive (i.e. in tests) shrinking |
44 | // does not break correctness. |
45 | // 2. If buffer is small, it will not be shrunk. |
46 | // That's why, it's "textness" will not change, so we don't have |
47 | // to fixup it. |
48 | if cfg!(not(any(target_os = "windows" , test))) |
49 | || buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE |
50 | { |
51 | return call_inner(buf); |
52 | } |
53 | |
54 | buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE]; |
55 | |
56 | // Now there are two possibilities. |
57 | // If caller gave is binary buffer, we **should not** shrink it |
58 | // anymore, because excessive shrinking hits performance. |
59 | // If caller gave as binary buffer, we **must** additionally |
60 | // shrink it to strip incomplete char at the end of buffer. |
61 | // that's why check we will perform now is allowed to have |
62 | // false-positive. |
63 | |
64 | // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes. |
65 | // if they are (possibly incomplete) utf8, then we can be quite sure |
66 | // that input buffer was utf8. |
67 | |
68 | let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) { |
69 | Ok(_) => true, |
70 | Err(err) => { |
71 | let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to(); |
72 | incomplete_bytes < MAX_BYTES_PER_CHAR |
73 | } |
74 | }; |
75 | |
76 | if have_to_fix_up { |
77 | // We must pop several bytes at the end which form incomplete |
78 | // character. To achieve it, we exploit UTF8 encoding: |
79 | // for any code point, all bytes except first start with 0b10 prefix. |
80 | // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details |
81 | let trailing_incomplete_char_size = buf |
82 | .iter() |
83 | .rev() |
84 | .take(MAX_BYTES_PER_CHAR) |
85 | .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000) |
86 | .unwrap_or(0) |
87 | + 1; |
88 | buf = &buf[..buf.len() - trailing_incomplete_char_size]; |
89 | } |
90 | |
91 | call_inner(buf) |
92 | } |
93 | |
94 | fn poll_flush( |
95 | mut self: Pin<&mut Self>, |
96 | cx: &mut Context<'_>, |
97 | ) -> Poll<Result<(), std::io::Error>> { |
98 | Pin::new(&mut self.inner).poll_flush(cx) |
99 | } |
100 | |
101 | fn poll_shutdown( |
102 | mut self: Pin<&mut Self>, |
103 | cx: &mut Context<'_>, |
104 | ) -> Poll<Result<(), std::io::Error>> { |
105 | Pin::new(&mut self.inner).poll_shutdown(cx) |
106 | } |
107 | } |
108 | |
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    /// Drives `fut` to completion on a freshly built current-thread runtime.
    fn run_to_completion(fut: impl std::future::Future<Output = ()>) {
        let runtime = crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        runtime.block_on(fut);
    }

    /// Inner writer that accepts everything it is given, after checking that each
    /// chunk fits in `DEFAULT_MAX_BUF_SIZE` and is well-formed UTF-8.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            assert!(accepted <= DEFAULT_MAX_BUF_SIZE);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Inner writer that accepts everything and records the size of every write.
    struct LoggingMockWriter {
        write_history: Vec<usize>,
    }

    impl LoggingMockWriter {
        fn new() -> Self {
            Self {
                write_history: Vec::new(),
            }
        }
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            assert!(accepted <= DEFAULT_MAX_BUF_SIZE);
            self.write_history.push(accepted);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        // "█" takes three bytes, so truncating this text at an arbitrary byte offset can
        // split a character; the mock writer asserts that this never reaches it.
        let data = "█".repeat(DEFAULT_MAX_BUF_SIZE);
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        run_to_completion(async move {
            splitter.write_all(data.as_bytes()).await.unwrap();
        });
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        // Binary data whose first bytes happen to look like text. Even in this corner
        // case the splitter must not shrink the first write too much.
        let text_prefix_len = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let binary_tail_len = DEFAULT_MAX_BUF_SIZE - text_prefix_len + 1;
        let data: Vec<u8> = std::iter::repeat(b'a')
            .take(text_prefix_len)
            .chain(std::iter::repeat(0b1010_1010).take(binary_tail_len))
            .collect();

        let mut writer = LoggingMockWriter::new();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        run_to_completion(async {
            splitter.write_all(&data).await.unwrap();
        });

        let history = &writer.write_history;
        // No more than two writes were needed...
        assert!(history.len() <= 2);
        // ...and together they covered every byte.
        assert_eq!(history.iter().sum::<usize>(), data.len());
        // The first write fell short by at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes:
        // one byte lay beyond DEFAULT_MAX_BUF_SIZE, plus up to one "utf8 code point".
        assert!(data.len() - history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}
223 | |