1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6use crate::math::div_round_up;
7
8/// `FoldChunksWith` is an iterator that groups elements of an underlying iterator and applies a
9/// function over them, producing a single value for each group.
10///
11/// This struct is created by the [`fold_chunks_with()`] method on [`IndexedParallelIterator`]
12///
13/// [`fold_chunks_with()`]: trait.IndexedParallelIterator.html#method.fold_chunks
14/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
15#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
16#[derive(Clone)]
17pub struct FoldChunksWith<I, U, F>
18where
19 I: IndexedParallelIterator,
20{
21 base: I,
22 chunk_size: usize,
23 item: U,
24 fold_op: F,
25}
26
27impl<I: IndexedParallelIterator + Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_struct("Fold")
30 .field("base", &self.base)
31 .field("chunk_size", &self.chunk_size)
32 .field("item", &self.item)
33 .finish()
34 }
35}
36
37impl<I, U, F> FoldChunksWith<I, U, F>
38where
39 I: IndexedParallelIterator,
40 U: Send + Clone,
41 F: Fn(U, I::Item) -> U + Send + Sync,
42{
43 /// Creates a new `FoldChunksWith` iterator
44 pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
45 FoldChunksWith {
46 base,
47 chunk_size,
48 item,
49 fold_op,
50 }
51 }
52}
53
54impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
55where
56 I: IndexedParallelIterator,
57 U: Send + Clone,
58 F: Fn(U, I::Item) -> U + Send + Sync,
59{
60 type Item = U;
61
62 fn drive_unindexed<C>(self, consumer: C) -> C::Result
63 where
64 C: Consumer<U>,
65 {
66 bridge(self, consumer)
67 }
68
69 fn opt_len(&self) -> Option<usize> {
70 Some(self.len())
71 }
72}
73
74impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
75where
76 I: IndexedParallelIterator,
77 U: Send + Clone,
78 F: Fn(U, I::Item) -> U + Send + Sync,
79{
80 fn len(&self) -> usize {
81 div_round_up(self.base.len(), self.chunk_size)
82 }
83
84 fn drive<C>(self, consumer: C) -> C::Result
85 where
86 C: Consumer<Self::Item>,
87 {
88 bridge(self, consumer)
89 }
90
91 fn with_producer<CB>(self, callback: CB) -> CB::Output
92 where
93 CB: ProducerCallback<Self::Item>,
94 {
95 let len = self.base.len();
96 return self.base.with_producer(Callback {
97 chunk_size: self.chunk_size,
98 len,
99 item: self.item,
100 fold_op: self.fold_op,
101 callback,
102 });
103
104 struct Callback<CB, T, F> {
105 chunk_size: usize,
106 len: usize,
107 item: T,
108 fold_op: F,
109 callback: CB,
110 }
111
112 impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
113 where
114 CB: ProducerCallback<U>,
115 U: Send + Clone,
116 F: Fn(U, T) -> U + Send + Sync,
117 {
118 type Output = CB::Output;
119
120 fn callback<P>(self, base: P) -> CB::Output
121 where
122 P: Producer<Item = T>,
123 {
124 let item = self.item;
125 let fold_op = &self.fold_op;
126 let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
127 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
128 self.callback.callback(producer)
129 }
130 }
131 }
132}
133
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    // Generic addition, used as the fold operation by the numeric tests below.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    // Folding characters into strings, four at a time, with a short final group.
    #[test]
    fn check_fold_chunks_with() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();

        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // A chunk size of zero is expected to panic with the message below.
    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks_with(0, 0, sum).collect();
    }

    // Input length is an exact multiple of the chunk size.
    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks_with(3, 0, sum).collect();

        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    // Empty input yields no groups at all.
    #[test]
    fn check_fold_chunks_with_empty() {
        let input: Vec<i32> = Vec::new();
        let expected: Vec<i32> = Vec::new();

        let folded: Vec<i32> = input.into_par_iter().fold_chunks_with(2, 0, sum).collect();

        assert_eq!(expected, folded);
    }

    // `len()` rounds up when the last group is short, and is zero for empty input.
    #[test]
    fn check_fold_chunks_len() {
        let exact_pairs = (0..8).into_par_iter().fold_chunks_with(2, 0, sum);
        assert_eq!(4, exact_pairs.len());

        let exact_triples = (0..9).into_par_iter().fold_chunks_with(3, 0, sum);
        assert_eq!(3, exact_triples.len());

        let ragged_triples = (0..8).into_par_iter().fold_chunks_with(3, 0, sum);
        assert_eq!(3, ragged_triples.len());

        let single = [1];
        assert_eq!(1, single.par_iter().fold_chunks_with(3, 0, sum).len());

        let nothing = (0..0).into_par_iter().fold_chunks_with(3, 0, sum);
        assert_eq!(0, nothing.len());
    }

    // Short trailing groups, checked in forward order and through `rev()`.
    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![0 + 1 + 2, 3]),
        ];

        for (i, (input, size, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = Vec::new();

            // Forward order, iterating by reference.
            input
                .par_iter()
                .fold_chunks_with(size, 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            // Reverse order, consuming the input.
            res.clear();
            input
                .into_par_iter()
                .fold_chunks_with(size, 0, sum)
                .rev()
                .collect_into_vec(&mut res);

            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, res, "Case {} reversed failed", i);
        }
    }
}
232