1use super::buf_writer::BufWriter;
2use futures_core::ready;
3use futures_core::task::{Context, Poll};
4use futures_io::AsyncWrite;
5use futures_io::IoSlice;
6use pin_project_lite::pin_project;
7use std::io;
8use std::pin::Pin;
9
pin_project! {
    /// Wraps a writer and buffers its output like [`BufWriter`], but hands
    /// complete lines to the underlying writer eagerly: whenever written data
    /// contains a newline, everything up to the last newline is written
    /// through, and only the trailing partial line is kept in the buffer.
    ///
    /// This is modelled on `std::io::LineWriter`, whose implementation
    /// explains the algorithm in more detail.
    ///
    /// The buffering itself is delegated to an inner `BufWriter`; this type
    /// only decides when that buffer is flushed and when data bypasses it.
    #[derive(Debug)]
    pub struct LineWriter<W: AsyncWrite> {
        // Structurally pinned: the wrapped `BufWriter` is polled through `Pin`.
        #[pin]
        buf_writer: BufWriter<W>,
    }
}
24
25impl<W: AsyncWrite> LineWriter<W> {
26 /// Create a new `LineWriter` with default buffer capacity. The default is currently 1KB
27 /// which was taken from `std::io::LineWriter`
28 pub fn new(inner: W) -> LineWriter<W> {
29 LineWriter::with_capacity(1024, inner)
30 }
31
32 /// Creates a new `LineWriter` with the specified buffer capacity.
33 pub fn with_capacity(capacity: usize, inner: W) -> LineWriter<W> {
34 LineWriter { buf_writer: BufWriter::with_capacity(capacity, inner) }
35 }
36
37 /// Flush `buf_writer` if last char is "new line"
38 fn flush_if_completed_line(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
39 let this = self.project();
40 match this.buf_writer.buffer().last().copied() {
41 Some(b'\n') => this.buf_writer.flush_buf(cx),
42 _ => Poll::Ready(Ok(())),
43 }
44 }
45
46 /// Returns a reference to `buf_writer`'s internally buffered data.
47 pub fn buffer(&self) -> &[u8] {
48 self.buf_writer.buffer()
49 }
50
51 /// Acquires a reference to the underlying sink or stream that this combinator is
52 /// pulling from.
53 pub fn get_ref(&self) -> &W {
54 self.buf_writer.get_ref()
55 }
56}
57
58impl<W: AsyncWrite> AsyncWrite for LineWriter<W> {
59 fn poll_write(
60 mut self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 buf: &[u8],
63 ) -> Poll<io::Result<usize>> {
64 let mut this = self.as_mut().project();
65 let newline_index = match memchr::memrchr(b'\n', buf) {
66 None => {
67 ready!(self.as_mut().flush_if_completed_line(cx)?);
68 return self.project().buf_writer.poll_write(cx, buf);
69 }
70 Some(newline_index) => newline_index + 1,
71 };
72
73 ready!(this.buf_writer.as_mut().poll_flush(cx)?);
74
75 let lines = &buf[..newline_index];
76
77 let flushed = { ready!(this.buf_writer.as_mut().inner_poll_write(cx, lines))? };
78
79 if flushed == 0 {
80 return Poll::Ready(Ok(0));
81 }
82
83 let tail = if flushed >= newline_index {
84 &buf[flushed..]
85 } else if newline_index - flushed <= this.buf_writer.capacity() {
86 &buf[flushed..newline_index]
87 } else {
88 let scan_area = &buf[flushed..];
89 let scan_area = &scan_area[..this.buf_writer.capacity()];
90 match memchr::memrchr(b'\n', scan_area) {
91 Some(newline_index) => &scan_area[..newline_index + 1],
92 None => scan_area,
93 }
94 };
95
96 let buffered = this.buf_writer.as_mut().write_to_buf(tail);
97 Poll::Ready(Ok(flushed + buffered))
98 }
99
100 fn poll_write_vectored(
101 mut self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 bufs: &[IoSlice<'_>],
104 ) -> Poll<io::Result<usize>> {
105 let mut this = self.as_mut().project();
106 // `is_write_vectored()` is handled in original code, but not in this crate
107 // see https://github.com/rust-lang/rust/issues/70436
108
109 let last_newline_buf_idx = bufs
110 .iter()
111 .enumerate()
112 .rev()
113 .find_map(|(i, buf)| memchr::memchr(b'\n', buf).map(|_| i));
114 let last_newline_buf_idx = match last_newline_buf_idx {
115 None => {
116 ready!(self.as_mut().flush_if_completed_line(cx)?);
117 return self.project().buf_writer.poll_write_vectored(cx, bufs);
118 }
119 Some(i) => i,
120 };
121
122 ready!(this.buf_writer.as_mut().poll_flush(cx)?);
123
124 let (lines, tail) = bufs.split_at(last_newline_buf_idx + 1);
125
126 let flushed = { ready!(this.buf_writer.as_mut().inner_poll_write_vectored(cx, lines))? };
127 if flushed == 0 {
128 return Poll::Ready(Ok(0));
129 }
130
131 let lines_len = lines.iter().map(|buf| buf.len()).sum();
132 if flushed < lines_len {
133 return Poll::Ready(Ok(flushed));
134 }
135
136 let buffered: usize = tail
137 .iter()
138 .filter(|buf| !buf.is_empty())
139 .map(|buf| this.buf_writer.as_mut().write_to_buf(buf))
140 .take_while(|&n| n > 0)
141 .sum();
142
143 Poll::Ready(Ok(flushed + buffered))
144 }
145
146 /// Forward to `buf_writer` 's `BufWriter::poll_flush()`
147 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148 self.as_mut().project().buf_writer.poll_flush(cx)
149 }
150
151 /// Forward to `buf_writer` 's `BufWriter::poll_close()`
152 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
153 self.as_mut().project().buf_writer.poll_close(cx)
154 }
155}
156