1//! Contains utilities for stdout and stderr.
2use crate::io::AsyncWrite;
3use std::pin::Pin;
4use std::task::{Context, Poll};
/// [`AsyncWrite`] adapter that keeps console output on UTF-8 character
/// boundaries.
///
/// # Windows
/// A buffer longer than `MAX_BUF` is cut down to `MAX_BUF` bytes. If the cut
/// buffer looks like UTF-8 text, its final character is withheld as well, so a
/// multi-byte sequence is never split between two writes. Otherwise, the buffer
/// is only cut to `MAX_BUF`. As a result, whenever the caller supplies
/// well-formed UTF-8, the wrapped writer also receives well-formed UTF-8.
///
/// # Other platforms
/// Passes data to `inner` unchanged.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    /// Destination writer that receives the (possibly shortened) buffers.
    inner: W,
}

impl<W> SplitByUtf8BoundaryIfWindows<W> {
    /// Wraps `inner` so that writes to it respect UTF-8 boundaries on Windows.
    pub(crate) fn new(inner: W) -> Self {
        SplitByUtf8BoundaryIfWindows { inner }
    }
}
21
/// Maximum length of one UTF-8 encoded character, in bytes.
///
/// The UTF-8 encoding (RFC 3629) never uses more than four bytes per code point.
const MAX_BYTES_PER_CHAR: usize = 4;

/// Number of maximal-length characters sampled by the text-detection
/// heuristic. The inspected prefix is `MAX_BYTES_PER_CHAR * MAGIC_CONST`
/// bytes long. This value may be tuned.
const MAGIC_CONST: usize = 8;
27
28impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
29where
30 W: AsyncWrite + Unpin,
31{
32 fn poll_write(
33 mut self: Pin<&mut Self>,
34 cx: &mut Context<'_>,
35 mut buf: &[u8],
36 ) -> Poll<Result<usize, std::io::Error>> {
37 // just a closure to avoid repetitive code
38 let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
39
40 // 1. Only windows stdio can suffer from non-utf8.
41 // We also check for `test` so that we can write some tests
42 // for further code. Since `AsyncWrite` can always shrink
43 // buffer at its discretion, excessive (i.e. in tests) shrinking
44 // does not break correctness.
45 // 2. If buffer is small, it will not be shrunk.
46 // That's why, it's "textness" will not change, so we don't have
47 // to fixup it.
48 if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF
49 {
50 return call_inner(buf);
51 }
52
53 buf = &buf[..crate::io::blocking::MAX_BUF];
54
55 // Now there are two possibilities.
56 // If caller gave is binary buffer, we **should not** shrink it
57 // anymore, because excessive shrinking hits performance.
58 // If caller gave as binary buffer, we **must** additionally
59 // shrink it to strip incomplete char at the end of buffer.
60 // that's why check we will perform now is allowed to have
61 // false-positive.
62
63 // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes.
64 // if they are (possibly incomplete) utf8, then we can be quite sure
65 // that input buffer was utf8.
66
67 let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
68 Ok(_) => true,
69 Err(err) => {
70 let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
71 incomplete_bytes < MAX_BYTES_PER_CHAR
72 }
73 };
74
75 if have_to_fix_up {
76 // We must pop several bytes at the end which form incomplete
77 // character. To achieve it, we exploit UTF8 encoding:
78 // for any code point, all bytes except first start with 0b10 prefix.
79 // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details
80 let trailing_incomplete_char_size = buf
81 .iter()
82 .rev()
83 .take(MAX_BYTES_PER_CHAR)
84 .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
85 .unwrap_or(0)
86 + 1;
87 buf = &buf[..buf.len() - trailing_incomplete_char_size];
88 }
89
90 call_inner(buf)
91 }
92
93 fn poll_flush(
94 mut self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 ) -> Poll<Result<(), std::io::Error>> {
97 Pin::new(&mut self.inner).poll_flush(cx)
98 }
99
100 fn poll_shutdown(
101 mut self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 ) -> Poll<Result<(), std::io::Error>> {
104 Pin::new(&mut self.inner).poll_shutdown(cx)
105 }
106}
107
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::MAX_BUF;
    use crate::io::AsyncWriteExt;
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Drives `fut` to completion on a freshly built current-thread runtime.
    fn run_on_current_thread<F: Future>(fut: F) -> F::Output {
        crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Writer that accepts every chunk in full. It asserts that each chunk is
    /// at most `MAX_BUF` bytes long and is valid UTF-8.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            assert!(buf.len() <= MAX_BUF);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Writer that accepts every chunk in full and asserts that it is at most
    /// `MAX_BUF` bytes long. It records the length of each write in order.
    #[derive(Default)]
    struct LoggingMockWriter {
        /// Length of every chunk passed to `poll_write`, in call order.
        write_history: Vec<usize>,
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            assert!(buf.len() <= MAX_BUF);
            self.get_mut().write_history.push(buf.len());
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        // `█` is three bytes long, so cutting at `MAX_BUF` would split a
        // character unless the splitter steps back to a boundary. The mock
        // asserts that every chunk it receives is valid UTF-8.
        let data = "█".repeat(MAX_BUF);
        let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        run_on_current_thread(async move {
            wr.write_all(data.as_bytes()).await.unwrap();
        });
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        // Binary data whose beginning happens to be text. Even in this corner
        // case the buffer must not be shrunk by more than one character.
        let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let mut data = vec![b'a'; checked_count];
        data.resize(MAX_BUF + 1, 0b1010_1010);

        let mut writer = LoggingMockWriter::default();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        run_on_current_thread(async {
            splitter.write_all(&data).await.unwrap();
        });

        // At most two writes were performed...
        assert!(writer.write_history.len() <= 2);
        // ...and together they covered the whole input.
        let total_written: usize = writer.write_history.iter().sum();
        assert_eq!(total_written, data.len());
        // The first write lost at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes:
        // one because it exceeded `MAX_BUF`, and up to one UTF-8 code point.
        assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}
222