1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::io::{self, IoSliceMut};
4use std::iter::FusedIterator;
5#[cfg(feature = "tokio")]
6use std::pin::Pin;
7#[cfg(feature = "tokio")]
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11#[cfg(feature = "tokio")]
12use tokio::io::{ReadBuf, SeekFrom};
13
14use crate::progress_bar::ProgressBar;
15use crate::state::ProgressFinish;
16use crate::style::ProgressStyle;
17
18/// Wraps an iterator to display its progress.
19pub trait ProgressIterator
20where
21 Self: Sized + Iterator,
22{
23 /// Wrap an iterator with default styling. Uses `Iterator::size_hint` to get length.
24 /// Returns `Some(..)` only if `size_hint.1` is `Some`. If you want to create a progress bar
25 /// even if `size_hint.1` returns `None` use `progress_count` or `progress_with` instead.
26 fn try_progress(self) -> Option<ProgressBarIter<Self>> {
27 self.size_hint()
28 .1
29 .map(|len| self.progress_count(u64::try_from(len).unwrap()))
30 }
31
32 /// Wrap an iterator with default styling.
33 fn progress(self) -> ProgressBarIter<Self>
34 where
35 Self: ExactSizeIterator,
36 {
37 let len = u64::try_from(self.len()).unwrap();
38 self.progress_count(len)
39 }
40
41 /// Wrap an iterator with an explicit element count.
42 fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
43 self.progress_with(ProgressBar::new(len))
44 }
45
46 /// Wrap an iterator with a custom progress bar.
47 fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
48
49 /// Wrap an iterator with a progress bar and style it.
50 fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
51 where
52 Self: ExactSizeIterator,
53 {
54 let len = u64::try_from(self.len()).unwrap();
55 let bar = ProgressBar::new(len).with_style(style);
56 self.progress_with(bar)
57 }
58}
59
60/// Wraps an iterator to display its progress.
61#[derive(Debug)]
62pub struct ProgressBarIter<T> {
63 pub(crate) it: T,
64 pub progress: ProgressBar,
65}
66
67impl<T> ProgressBarIter<T> {
68 /// Builder-like function for setting underlying progress bar's style.
69 ///
70 /// See [ProgressBar::with_style].
71 pub fn with_style(mut self, style: ProgressStyle) -> Self {
72 self.progress = self.progress.with_style(style);
73 self
74 }
75
76 /// Builder-like function for setting underlying progress bar's prefix.
77 ///
78 /// See [ProgressBar::with_prefix].
79 pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
80 self.progress = self.progress.with_prefix(prefix);
81 self
82 }
83
84 /// Builder-like function for setting underlying progress bar's message.
85 ///
86 /// See [ProgressBar::with_message].
87 pub fn with_message(mut self, message: impl Into<Cow<'static, str>>) -> Self {
88 self.progress = self.progress.with_message(message);
89 self
90 }
91
92 /// Builder-like function for setting underlying progress bar's position.
93 ///
94 /// See [ProgressBar::with_position].
95 pub fn with_position(mut self, position: u64) -> Self {
96 self.progress = self.progress.with_position(position);
97 self
98 }
99
100 /// Builder-like function for setting underlying progress bar's elapsed time.
101 ///
102 /// See [ProgressBar::with_elapsed].
103 pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
104 self.progress = self.progress.with_elapsed(elapsed);
105 self
106 }
107
108 /// Builder-like function for setting underlying progress bar's finish behavior.
109 ///
110 /// See [ProgressBar::with_finish].
111 pub fn with_finish(mut self, finish: ProgressFinish) -> Self {
112 self.progress = self.progress.with_finish(finish);
113 self
114 }
115}
116
117impl<S, T: Iterator<Item = S>> Iterator for ProgressBarIter<T> {
118 type Item = S;
119
120 fn next(&mut self) -> Option<Self::Item> {
121 let item: Option = self.it.next();
122
123 if item.is_some() {
124 self.progress.inc(delta:1);
125 } else if !self.progress.is_finished() {
126 self.progress.finish_using_style();
127 }
128
129 item
130 }
131}
132
133impl<T: ExactSizeIterator> ExactSizeIterator for ProgressBarIter<T> {
134 fn len(&self) -> usize {
135 self.it.len()
136 }
137}
138
139impl<T: DoubleEndedIterator> DoubleEndedIterator for ProgressBarIter<T> {
140 fn next_back(&mut self) -> Option<Self::Item> {
141 let item: Option<::Item> = self.it.next_back();
142
143 if item.is_some() {
144 self.progress.inc(delta:1);
145 } else if !self.progress.is_finished() {
146 self.progress.finish_using_style();
147 }
148
149 item
150 }
151}
152
153impl<T: FusedIterator> FusedIterator for ProgressBarIter<T> {}
154
155impl<R: io::Read> io::Read for ProgressBarIter<R> {
156 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
157 let inc = self.it.read(buf)?;
158 self.progress.inc(inc as u64);
159 Ok(inc)
160 }
161
162 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
163 let inc = self.it.read_vectored(bufs)?;
164 self.progress.inc(inc as u64);
165 Ok(inc)
166 }
167
168 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
169 let inc = self.it.read_to_string(buf)?;
170 self.progress.inc(inc as u64);
171 Ok(inc)
172 }
173
174 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
175 self.it.read_exact(buf)?;
176 self.progress.inc(buf.len() as u64);
177 Ok(())
178 }
179}
180
181impl<R: io::BufRead> io::BufRead for ProgressBarIter<R> {
182 fn fill_buf(&mut self) -> io::Result<&[u8]> {
183 self.it.fill_buf()
184 }
185
186 fn consume(&mut self, amt: usize) {
187 self.it.consume(amt);
188 self.progress.inc(delta:amt as u64);
189 }
190}
191
192impl<S: io::Seek> io::Seek for ProgressBarIter<S> {
193 fn seek(&mut self, f: io::SeekFrom) -> io::Result<u64> {
194 self.it.seek(f).map(|pos: u64| {
195 self.progress.set_position(pos);
196 pos
197 })
198 }
199 // Pass this through to preserve optimizations that the inner I/O object may use here
200 // Also avoid sending a set_position update when the position hasn't changed
201 fn stream_position(&mut self) -> io::Result<u64> {
202 self.it.stream_position()
203 }
204}
205
/// Forwards async writes to the inner writer, advancing the bar by the bytes written.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for ProgressBarIter<W> {
    /// Polls the inner writer; a successful write advances the bar by the byte count.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let outcome = Pin::new(&mut this.it).poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &outcome {
            this.progress.inc(*written as u64);
        }
        outcome
    }

    /// Delegates flushing to the inner writer; the bar is not touched.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.it).poll_flush(cx)
    }

    /// Delegates shutdown to the inner writer; the bar is not touched.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.it).poll_shutdown(cx)
    }
}
230
/// Forwards async reads to the inner reader, advancing the bar by the bytes read.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for ProgressBarIter<W> {
    /// Polls the inner reader; once it is ready the bar advances by the growth of the
    /// filled part of `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // Remember how much was already filled so only newly read bytes are counted.
        let filled_before = buf.filled().len() as u64;
        match Pin::new(&mut this.it).poll_read(cx, buf) {
            Poll::Ready(outcome) => {
                let filled_after = buf.filled().len() as u64;
                this.progress.inc(filled_after - filled_before);
                Poll::Ready(outcome)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
248
/// Forwards async seeks to the inner object, moving the bar to the resulting position.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncSeek + Unpin> tokio::io::AsyncSeek for ProgressBarIter<W> {
    /// Starts a seek on the inner object; the bar is updated in `poll_complete`.
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        let this = self.get_mut();
        Pin::new(&mut this.it).start_seek(position)
    }

    /// Polls the inner object for seek completion.
    ///
    /// When a position is successfully reported, the bar is set to it so that it stays
    /// in sync with the stream, matching the synchronous `io::Seek` implementation.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let this = self.get_mut();
        let poll = Pin::new(&mut this.it).poll_complete(cx);
        if let Poll::Ready(Ok(pos)) = &poll {
            // NOTE(review): the bar mirrors whatever position the inner object reports,
            // including when no seek is pending - confirm that is acceptable for the
            // wrapped types in use.
            this.progress.set_position(*pos);
        }
        poll
    }
}
260
/// Forwards async buffered reads to the inner reader, counting bytes as they are consumed.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncBufRead + Unpin + tokio::io::AsyncRead> tokio::io::AsyncBufRead
    for ProgressBarIter<W>
{
    /// Returns the inner reader's buffer without touching the bar.
    ///
    /// Progress is recorded in `consume` instead: the buffer returned here may be handed
    /// out again by later calls until it is consumed, so counting its length on every
    /// fill would count the same bytes more than once.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = self.get_mut();
        Pin::new(&mut this.it).poll_fill_buf(cx)
    }

    /// Marks `amt` bytes as consumed on the inner reader and advances the bar by `amt`,
    /// matching the synchronous `io::BufRead` implementation.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        Pin::new(&mut this.it).consume(amt);
        this.progress.inc(amt as u64);
    }
}
279
/// Forwards stream polling to the inner stream, advancing the bar once per yielded item.
#[cfg(feature = "futures")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
impl<S: futures_core::Stream + Unpin> futures_core::Stream for ProgressBarIter<S> {
    type Item = S::Item;

    /// Polls the inner stream.
    ///
    /// A yielded item increments the bar by one. When the stream ends the bar is finished
    /// according to its configured finish behavior, unless it has already been finished;
    /// this matches the `Iterator` implementation.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Imported locally because the file-level imports of these names are tied to the
        // `tokio` feature, which may be disabled when `futures` is enabled.
        use std::task::Poll;

        let this = self.get_mut();
        let item = std::pin::Pin::new(&mut this.it).poll_next(cx);
        match &item {
            Poll::Ready(Some(_)) => this.progress.inc(1),
            // Guard so that polling an ended stream again does not re-finish the bar.
            Poll::Ready(None) if !this.progress.is_finished() => {
                this.progress.finish_using_style();
            }
            Poll::Ready(None) | Poll::Pending => {}
        }
        item
    }

    /// Reports the inner stream's bounds; the wrapper yields exactly the same items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.it.size_hint()
    }
}
299
300impl<W: io::Write> io::Write for ProgressBarIter<W> {
301 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
302 self.it.write(buf).map(|inc: usize| {
303 self.progress.inc(delta:inc as u64);
304 inc
305 })
306 }
307
308 fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result<usize> {
309 self.it.write_vectored(bufs).map(|inc: usize| {
310 self.progress.inc(delta:inc as u64);
311 inc
312 })
313 }
314
315 fn flush(&mut self) -> io::Result<()> {
316 self.it.flush()
317 }
318
319 // write_fmt can not be captured with reasonable effort.
320 // as it uses write_all internally by default that should not be a problem.
321 // fn write_fmt(&mut self, fmt: fmt::Arguments) -> io::Result<()>;
322}
323
324impl<S, T: Iterator<Item = S>> ProgressIterator for T {
325 fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
326 ProgressBarIter { it: self, progress }
327 }
328}
329
#[cfg(test)]
mod test {
    use crate::iter::{ProgressBarIter, ProgressIterator};
    use crate::progress_bar::ProgressBar;
    use crate::ProgressStyle;

    /// Each way of attaching a bar must leave the yielded items unchanged.
    #[test]
    fn it_can_wrap_an_iterator() {
        let v = vec![1, 2, 3];
        let wrap = |it: ProgressBarIter<_>| {
            let doubled = it.map(|x| x * 2).collect::<Vec<_>>();
            assert_eq!(doubled, vec![2, 4, 6]);
        };

        // Length taken from the iterator itself.
        wrap(v.iter().progress());

        // Length supplied explicitly.
        wrap(v.iter().progress_count(3));

        // Caller-provided progress bar.
        let pb = ProgressBar::new(v.len() as u64);
        wrap(v.iter().progress_with(pb));

        // Caller-provided style.
        let style = ProgressStyle::default_bar()
            .template("{wide_bar:.red} {percent}/100%")
            .unwrap();
        wrap(v.iter().progress_with_style(style));
    }
}
357