1 | use std::fmt::{self, Debug}; |
2 | |
3 | use super::chunks::ChunkProducer; |
4 | use super::plumbing::*; |
5 | use super::*; |
6 | use crate::math::div_round_up; |
7 | |
8 | /// `FoldChunks` is an iterator that groups elements of an underlying iterator and applies a |
9 | /// function over them, producing a single value for each group. |
10 | /// |
11 | /// This struct is created by the [`fold_chunks()`] method on [`IndexedParallelIterator`] |
12 | /// |
13 | /// [`fold_chunks()`]: trait.IndexedParallelIterator.html#method.fold_chunks |
14 | /// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html |
15 | #[must_use = "iterator adaptors are lazy and do nothing unless consumed" ] |
16 | #[derive(Clone)] |
17 | pub struct FoldChunks<I, ID, F> |
18 | where |
19 | I: IndexedParallelIterator, |
20 | { |
21 | base: I, |
22 | chunk_size: usize, |
23 | fold_op: F, |
24 | identity: ID, |
25 | } |
26 | |
27 | impl<I: IndexedParallelIterator + Debug, ID, F> Debug for FoldChunks<I, ID, F> { |
28 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
29 | f.debug_struct("Fold" ) |
30 | .field("base" , &self.base) |
31 | .field("chunk_size" , &self.chunk_size) |
32 | .finish() |
33 | } |
34 | } |
35 | |
36 | impl<I, ID, U, F> FoldChunks<I, ID, F> |
37 | where |
38 | I: IndexedParallelIterator, |
39 | ID: Fn() -> U + Send + Sync, |
40 | F: Fn(U, I::Item) -> U + Send + Sync, |
41 | U: Send, |
42 | { |
43 | /// Creates a new `FoldChunks` iterator |
44 | pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self { |
45 | FoldChunks { |
46 | base, |
47 | chunk_size, |
48 | identity, |
49 | fold_op, |
50 | } |
51 | } |
52 | } |
53 | |
54 | impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F> |
55 | where |
56 | I: IndexedParallelIterator, |
57 | ID: Fn() -> U + Send + Sync, |
58 | F: Fn(U, I::Item) -> U + Send + Sync, |
59 | U: Send, |
60 | { |
61 | type Item = U; |
62 | |
63 | fn drive_unindexed<C>(self, consumer: C) -> C::Result |
64 | where |
65 | C: Consumer<U>, |
66 | { |
67 | bridge(self, consumer) |
68 | } |
69 | |
70 | fn opt_len(&self) -> Option<usize> { |
71 | Some(self.len()) |
72 | } |
73 | } |
74 | |
75 | impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F> |
76 | where |
77 | I: IndexedParallelIterator, |
78 | ID: Fn() -> U + Send + Sync, |
79 | F: Fn(U, I::Item) -> U + Send + Sync, |
80 | U: Send, |
81 | { |
82 | fn len(&self) -> usize { |
83 | div_round_up(self.base.len(), self.chunk_size) |
84 | } |
85 | |
86 | fn drive<C>(self, consumer: C) -> C::Result |
87 | where |
88 | C: Consumer<Self::Item>, |
89 | { |
90 | bridge(self, consumer) |
91 | } |
92 | |
93 | fn with_producer<CB>(self, callback: CB) -> CB::Output |
94 | where |
95 | CB: ProducerCallback<Self::Item>, |
96 | { |
97 | let len = self.base.len(); |
98 | return self.base.with_producer(Callback { |
99 | chunk_size: self.chunk_size, |
100 | len, |
101 | identity: self.identity, |
102 | fold_op: self.fold_op, |
103 | callback, |
104 | }); |
105 | |
106 | struct Callback<CB, ID, F> { |
107 | chunk_size: usize, |
108 | len: usize, |
109 | identity: ID, |
110 | fold_op: F, |
111 | callback: CB, |
112 | } |
113 | |
114 | impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F> |
115 | where |
116 | CB: ProducerCallback<U>, |
117 | ID: Fn() -> U + Send + Sync, |
118 | F: Fn(U, T) -> U + Send + Sync, |
119 | { |
120 | type Output = CB::Output; |
121 | |
122 | fn callback<P>(self, base: P) -> CB::Output |
123 | where |
124 | P: Producer<Item = T>, |
125 | { |
126 | let identity = &self.identity; |
127 | let fold_op = &self.fold_op; |
128 | let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op); |
129 | let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter); |
130 | self.callback.callback(producer) |
131 | } |
132 | } |
133 | } |
134 | } |
135 | |
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    // Named stand-ins for closures, shared by several tests below.
    fn id() -> i32 {
        0
    }

    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks() {
        let chars: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = chars
            .into_par_iter()
            .fold_chunks(4, String::new, |mut word, c| {
                word.push(c);
                word
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3]
            .into_par_iter()
            .fold_chunks(0, id, sum)
            .collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let sums: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], sums);
    }

    #[test]
    fn check_fold_chunks_empty() {
        let v: Vec<i32> = vec![];
        let folded: Vec<i32> = v.into_par_iter().fold_chunks(2, id, sum).collect();
        assert!(folded.is_empty());
    }

    #[test]
    fn check_fold_chunks_len() {
        assert_eq!(4, (0..8).into_par_iter().fold_chunks(2, id, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(1, (&[1]).par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks(3, id, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![0 + 1 + 2, 3]),
        ];

        // One output buffer, emptied before each collection.
        let mut res: Vec<u32> = Vec::new();
        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            res.clear();
            v.par_iter()
                .fold_chunks(n, || 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            res.clear();
            v.into_par_iter()
                .fold_chunks(n, || 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, res, "Case {} reversed failed", i);
        }
    }
}
237 | |