1 | //! Contains utilities for stdout and stderr. |
2 | use crate::io::AsyncWrite; |
3 | use std::pin::Pin; |
4 | use std::task::{Context, Poll}; |
/// [`AsyncWrite`] adapter used for stdout and stderr.
///
/// # Windows
///
/// A write whose buffer is longer than `MAX_BUF` is truncated to `MAX_BUF`
/// bytes. If the beginning of the buffer looks like UTF-8 text, the truncation
/// point is additionally moved back far enough that no partially encoded
/// character is written. Therefore, when the caller writes well-formed UTF-8,
/// `inner` also receives well-formed UTF-8 chunks. Buffers that do not look
/// like text are only truncated to `MAX_BUF`. Buffers of at most `MAX_BUF`
/// bytes are passed through unchanged.
///
/// The same logic is also enabled when compiling this crate's tests, so that
/// it can be exercised on every platform.
///
/// # Other platforms
///
/// Passes data to `inner` as is.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    /// The writer that receives the (possibly shortened) buffers.
    inner: W,
}
15 | |
16 | impl<W> SplitByUtf8BoundaryIfWindows<W> { |
17 | pub(crate) fn new(inner: W) -> Self { |
18 | Self { inner } |
19 | } |
20 | } |
21 | |
/// Maximum number of bytes a single Unicode scalar value occupies when encoded
/// as UTF-8. This bound is fixed by the UTF-8 encoding itself, as defined in the
/// Unicode Standard and RFC 3629.
const MAX_BYTES_PER_CHAR: usize = 4;
24 | |
/// Number of maximum-width characters' worth of bytes sampled from the start of
/// an oversized buffer to guess whether it is UTF-8 text. The sample is
/// `MAX_BYTES_PER_CHAR * MAGIC_CONST` bytes long.
///
/// This value can be tuned. A larger sample makes it less likely that binary data
/// is mistaken for text, at the cost of validating more bytes per write.
const MAGIC_CONST: usize = 8;
27 | |
28 | impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W> |
29 | where |
30 | W: AsyncWrite + Unpin, |
31 | { |
32 | fn poll_write( |
33 | mut self: Pin<&mut Self>, |
34 | cx: &mut Context<'_>, |
35 | mut buf: &[u8], |
36 | ) -> Poll<Result<usize, std::io::Error>> { |
37 | // just a closure to avoid repetitive code |
38 | let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf); |
39 | |
40 | // 1. Only windows stdio can suffer from non-utf8. |
41 | // We also check for `test` so that we can write some tests |
42 | // for further code. Since `AsyncWrite` can always shrink |
43 | // buffer at its discretion, excessive (i.e. in tests) shrinking |
44 | // does not break correctness. |
45 | // 2. If buffer is small, it will not be shrunk. |
46 | // That's why, it's "textness" will not change, so we don't have |
47 | // to fixup it. |
48 | if cfg!(not(any(target_os = "windows" , test))) || buf.len() <= crate::io::blocking::MAX_BUF |
49 | { |
50 | return call_inner(buf); |
51 | } |
52 | |
53 | buf = &buf[..crate::io::blocking::MAX_BUF]; |
54 | |
55 | // Now there are two possibilities. |
56 | // If caller gave is binary buffer, we **should not** shrink it |
57 | // anymore, because excessive shrinking hits performance. |
58 | // If caller gave as binary buffer, we **must** additionally |
59 | // shrink it to strip incomplete char at the end of buffer. |
60 | // that's why check we will perform now is allowed to have |
61 | // false-positive. |
62 | |
63 | // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes. |
64 | // if they are (possibly incomplete) utf8, then we can be quite sure |
65 | // that input buffer was utf8. |
66 | |
67 | let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) { |
68 | Ok(_) => true, |
69 | Err(err) => { |
70 | let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to(); |
71 | incomplete_bytes < MAX_BYTES_PER_CHAR |
72 | } |
73 | }; |
74 | |
75 | if have_to_fix_up { |
76 | // We must pop several bytes at the end which form incomplete |
77 | // character. To achieve it, we exploit UTF8 encoding: |
78 | // for any code point, all bytes except first start with 0b10 prefix. |
79 | // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details |
80 | let trailing_incomplete_char_size = buf |
81 | .iter() |
82 | .rev() |
83 | .take(MAX_BYTES_PER_CHAR) |
84 | .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000) |
85 | .unwrap_or(0) |
86 | + 1; |
87 | buf = &buf[..buf.len() - trailing_incomplete_char_size]; |
88 | } |
89 | |
90 | call_inner(buf) |
91 | } |
92 | |
93 | fn poll_flush( |
94 | mut self: Pin<&mut Self>, |
95 | cx: &mut Context<'_>, |
96 | ) -> Poll<Result<(), std::io::Error>> { |
97 | Pin::new(&mut self.inner).poll_flush(cx) |
98 | } |
99 | |
100 | fn poll_shutdown( |
101 | mut self: Pin<&mut Self>, |
102 | cx: &mut Context<'_>, |
103 | ) -> Poll<Result<(), std::io::Error>> { |
104 | Pin::new(&mut self.inner).poll_shutdown(cx) |
105 | } |
106 | } |
107 | |
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::MAX_BUF;
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Drives `fut` to completion on a fresh current-thread runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Accepts every chunk in full, asserting that it fits in `MAX_BUF` and is
    /// valid UTF-8.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            assert!(accepted <= MAX_BUF);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Accepts every chunk in full and records the length of each one.
    struct LoggingMockWriter {
        write_history: Vec<usize>,
    }

    impl LoggingMockWriter {
        fn new() -> Self {
            Self {
                write_history: Vec::new(),
            }
        }
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            assert!(accepted <= MAX_BUF);
            self.write_history.push(accepted);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        // Three-byte characters far beyond MAX_BUF: every chunk the inner
        // writer sees must still be valid UTF-8.
        let text = "█".repeat(MAX_BUF);
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        block_on(async move {
            splitter.write_all(text.as_bytes()).await.unwrap();
        });
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        // Binary data whose beginning happens to look like text: even in this
        // corner case the buffer must not be shrunk too much.
        let text_len = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let mut data = "a".repeat(text_len).into_bytes();
        data.resize(MAX_BUF + 1, 0b1010_1010);

        let mut writer = LoggingMockWriter::new();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        block_on(async {
            splitter.write_all(&data).await.unwrap();
        });

        let history = &writer.write_history;
        // No more than two writes were needed...
        assert!(history.len() <= 2);
        // ...together they covered the whole input...
        assert_eq!(history.iter().sum::<usize>(), data.len());
        // ...and the first one was short by at most MAX_BYTES_PER_CHAR + 1
        // (i.e. 5) bytes: one byte beyond MAX_BUF plus up to one code point.
        assert!(data.len() - history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}
222 | |