1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6use crate::math::div_round_up;
7
8/// `FoldChunks` is an iterator that groups elements of an underlying iterator and applies a
9/// function over them, producing a single value for each group.
10///
11/// This struct is created by the [`fold_chunks()`] method on [`IndexedParallelIterator`]
12///
13/// [`fold_chunks()`]: trait.IndexedParallelIterator.html#method.fold_chunks
14/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
15#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
16#[derive(Clone)]
17pub struct FoldChunks<I, ID, F>
18where
19 I: IndexedParallelIterator,
20{
21 base: I,
22 chunk_size: usize,
23 fold_op: F,
24 identity: ID,
25}
26
27impl<I: IndexedParallelIterator + Debug, ID, F> Debug for FoldChunks<I, ID, F> {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_struct("Fold")
30 .field("base", &self.base)
31 .field("chunk_size", &self.chunk_size)
32 .finish()
33 }
34}
35
36impl<I, ID, U, F> FoldChunks<I, ID, F>
37where
38 I: IndexedParallelIterator,
39 ID: Fn() -> U + Send + Sync,
40 F: Fn(U, I::Item) -> U + Send + Sync,
41 U: Send,
42{
43 /// Creates a new `FoldChunks` iterator
44 pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
45 FoldChunks {
46 base,
47 chunk_size,
48 identity,
49 fold_op,
50 }
51 }
52}
53
54impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
55where
56 I: IndexedParallelIterator,
57 ID: Fn() -> U + Send + Sync,
58 F: Fn(U, I::Item) -> U + Send + Sync,
59 U: Send,
60{
61 type Item = U;
62
63 fn drive_unindexed<C>(self, consumer: C) -> C::Result
64 where
65 C: Consumer<U>,
66 {
67 bridge(self, consumer)
68 }
69
70 fn opt_len(&self) -> Option<usize> {
71 Some(self.len())
72 }
73}
74
75impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
76where
77 I: IndexedParallelIterator,
78 ID: Fn() -> U + Send + Sync,
79 F: Fn(U, I::Item) -> U + Send + Sync,
80 U: Send,
81{
82 fn len(&self) -> usize {
83 div_round_up(self.base.len(), self.chunk_size)
84 }
85
86 fn drive<C>(self, consumer: C) -> C::Result
87 where
88 C: Consumer<Self::Item>,
89 {
90 bridge(self, consumer)
91 }
92
93 fn with_producer<CB>(self, callback: CB) -> CB::Output
94 where
95 CB: ProducerCallback<Self::Item>,
96 {
97 let len = self.base.len();
98 return self.base.with_producer(Callback {
99 chunk_size: self.chunk_size,
100 len,
101 identity: self.identity,
102 fold_op: self.fold_op,
103 callback,
104 });
105
106 struct Callback<CB, ID, F> {
107 chunk_size: usize,
108 len: usize,
109 identity: ID,
110 fold_op: F,
111 callback: CB,
112 }
113
114 impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
115 where
116 CB: ProducerCallback<U>,
117 ID: Fn() -> U + Send + Sync,
118 F: Fn(U, T) -> U + Send + Sync,
119 {
120 type Output = CB::Output;
121
122 fn callback<P>(self, base: P) -> CB::Output
123 where
124 P: Producer<Item = T>,
125 {
126 let identity = &self.identity;
127 let fold_op = &self.fold_op;
128 let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
129 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
130 self.callback.callback(producer)
131 }
132 }
133 }
134}
135
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    #[test]
    fn check_fold_chunks() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks(4, String::new, |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // Named stand-ins for the identity and fold closures used by the tests below.
    fn id() -> i32 {
        0
    }

    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks(0, id, sum).collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_empty() {
        let input: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks(2, id, sum).collect();
        assert_eq!(Vec::<i32>::new(), folded);
    }

    #[test]
    fn check_fold_chunks_len() {
        let chunk_count =
            |end: i32, size: usize| (0..end).into_par_iter().fold_chunks(size, id, sum).len();

        assert_eq!(4, chunk_count(8, 2));
        assert_eq!(3, chunk_count(9, 3));
        assert_eq!(3, chunk_count(8, 3));
        assert_eq!(1, (&[1]).par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, chunk_count(0, 3));
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![0 + 1 + 2, 3]),
        ];

        for (case, (input, size, expected)) in cases.into_iter().enumerate() {
            let mut out: Vec<u32> = Vec::new();
            input
                .par_iter()
                .fold_chunks(size, || 0, sum)
                .collect_into_vec(&mut out);
            assert_eq!(expected, out, "Case {} failed", case);

            out.clear();
            input
                .into_par_iter()
                .fold_chunks(size, || 0, sum)
                .rev()
                .collect_into_vec(&mut out);
            let expected_reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(expected_reversed, out, "Case {} reversed failed", case);
        }
    }
}
237