1 | mod util; |
2 | |
3 | use std::ops::{Add, AddAssign}; |
4 | |
5 | use anyhow::anyhow; |
6 | use arrayvec::ArrayVec; |
7 | use v_frame::{frame::Frame, math::clamp, plane::Plane}; |
8 | |
9 | use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat}; |
10 | use super::{NoiseStatus, BLOCK_SIZE, BLOCK_SIZE_SQUARED}; |
11 | use crate::{ |
12 | diff::solver::util::normalized_cross_correlation, GrainTableSegment, DEFAULT_GRAIN_SEED, |
13 | NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS, NUM_Y_POINTS, |
14 | }; |
15 | |
16 | const LOW_POLY_NUM_PARAMS: usize = 3; |
17 | const NOISE_MODEL_LAG: usize = 3; |
18 | const BLOCK_NORMALIZATION: f64 = 255.0f64; |
19 | |
20 | #[derive (Debug, Clone, Copy)] |
21 | pub(super) struct FlatBlockFinder { |
22 | a: [f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED], |
23 | a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS], |
24 | } |
25 | |
26 | impl FlatBlockFinder { |
27 | #[must_use ] |
28 | pub fn new() -> Self { |
29 | let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS); |
30 | let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS]; |
31 | let mut a = [0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED]; |
32 | |
33 | let bs_half = (BLOCK_SIZE / 2) as f64; |
34 | (0..BLOCK_SIZE).for_each(|y| { |
35 | let yd = (y as f64 - bs_half) / bs_half; |
36 | (0..BLOCK_SIZE).for_each(|x| { |
37 | let xd = (x as f64 - bs_half) / bs_half; |
38 | let coords = [yd, xd, 1.0f64]; |
39 | let row = y * BLOCK_SIZE + x; |
40 | a[LOW_POLY_NUM_PARAMS * row] = yd; |
41 | a[LOW_POLY_NUM_PARAMS * row + 1] = xd; |
42 | a[LOW_POLY_NUM_PARAMS * row + 2] = 1.0f64; |
43 | |
44 | (0..LOW_POLY_NUM_PARAMS).for_each(|i| { |
45 | (0..LOW_POLY_NUM_PARAMS).for_each(|j| { |
46 | eqns.a[LOW_POLY_NUM_PARAMS * i + j] += coords[i] * coords[j]; |
47 | }); |
48 | }); |
49 | }); |
50 | }); |
51 | |
52 | // Lazy inverse using existing equation solver. |
53 | (0..LOW_POLY_NUM_PARAMS).for_each(|i| { |
54 | eqns.b.fill(0.0f64); |
55 | eqns.b[i] = 1.0f64; |
56 | eqns.solve(); |
57 | |
58 | (0..LOW_POLY_NUM_PARAMS).for_each(|j| { |
59 | a_t_a_inv[j * LOW_POLY_NUM_PARAMS + i] = eqns.x[j]; |
60 | }); |
61 | }); |
62 | |
63 | FlatBlockFinder { a, a_t_a_inv } |
64 | } |
65 | |
66 | // The gradient-based features used in this code are based on: |
67 | // A. Kokaram, D. Kelly, H. Denman and A. Crawford, "Measuring noise |
68 | // correlation for improved video denoising," 2012 19th, ICIP. |
69 | // The thresholds are more lenient to allow for correct grain modeling |
70 | // in extreme cases. |
71 | #[must_use ] |
72 | #[allow (clippy::too_many_lines)] |
73 | pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) { |
74 | const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64; |
75 | const RATIO_THRESHOLD: f64 = 1.25f64; |
76 | const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64; |
77 | const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64; |
78 | |
79 | // The following weights are used to combine the above features to give |
80 | // a sigmoid score for flatness. If the input was normalized to [0,100] |
81 | // the magnitude of these values would be close to 1 (e.g., weights |
82 | // corresponding to variance would be a factor of 10000x smaller). |
83 | const VAR_WEIGHT: f64 = -6682f64; |
84 | const RATIO_WEIGHT: f64 = -0.2056f64; |
85 | const TRACE_WEIGHT: f64 = 13087f64; |
86 | const NORM_WEIGHT: f64 = -12434f64; |
87 | const OFFSET: f64 = 2.5694f64; |
88 | |
89 | let num_blocks_w = (plane.cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE; |
90 | let num_blocks_h = (plane.cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE; |
91 | let num_blocks = num_blocks_w * num_blocks_h; |
92 | let mut flat_blocks = vec![0u8; num_blocks]; |
93 | let mut num_flat = 0; |
94 | let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED]; |
95 | let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED]; |
96 | let mut scores = vec![IndexAndScore::default(); num_blocks]; |
97 | |
98 | for by in 0..num_blocks_h { |
99 | for bx in 0..num_blocks_w { |
100 | // Compute gradient covariance matrix. |
101 | let mut gxx = 0f64; |
102 | let mut gxy = 0f64; |
103 | let mut gyy = 0f64; |
104 | let mut var = 0f64; |
105 | let mut mean = 0f64; |
106 | |
107 | self.extract_block( |
108 | plane, |
109 | bx * BLOCK_SIZE, |
110 | by * BLOCK_SIZE, |
111 | &mut plane_result, |
112 | &mut block_result, |
113 | ); |
114 | for yi in 1..(BLOCK_SIZE - 1) { |
115 | for xi in 1..(BLOCK_SIZE - 1) { |
116 | // SAFETY: We know the size of `block_result` and that we cannot exceed the bounds of it |
117 | unsafe { |
118 | let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi); |
119 | |
120 | let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64; |
121 | let gy = |
122 | (*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64; |
123 | gxx += gx * gx; |
124 | gxy += gx * gy; |
125 | gyy += gy * gy; |
126 | |
127 | let block_val = *result_ptr; |
128 | mean += block_val; |
129 | var += block_val * block_val; |
130 | } |
131 | } |
132 | } |
133 | let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64; |
134 | mean /= block_size_norm_factor; |
135 | |
136 | // Normalize gradients by block_size. |
137 | gxx /= block_size_norm_factor; |
138 | gxy /= block_size_norm_factor; |
139 | gyy /= block_size_norm_factor; |
140 | var = mean.mul_add(-mean, var / block_size_norm_factor); |
141 | |
142 | let trace = gxx + gyy; |
143 | let det = gxx.mul_add(gyy, -gxy.powi(2)); |
144 | let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt(); |
145 | let e1 = (trace + e_sub) / 2.0f64; |
146 | let e2 = (trace - e_sub) / 2.0f64; |
147 | // Spectral norm |
148 | let norm = e1; |
149 | let ratio = e1 / e2.max(1.0e-6_f64); |
150 | let is_flat = trace < TRACE_THRESHOLD |
151 | && ratio < RATIO_THRESHOLD |
152 | && norm < NORM_THRESHOLD |
153 | && var > VAR_THRESHOLD; |
154 | |
155 | let sum_weights = NORM_WEIGHT.mul_add( |
156 | norm, |
157 | TRACE_WEIGHT.mul_add( |
158 | trace, |
159 | VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)), |
160 | ), |
161 | ); |
162 | // clamp the value to [-25.0, 100.0] to prevent overflow |
163 | let sum_weights = clamp(sum_weights, -25.0f64, 100.0f64); |
164 | let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32; |
165 | // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it |
166 | unsafe { |
167 | let index = by * num_blocks_w + bx; |
168 | *flat_blocks.get_unchecked_mut(index) = if is_flat { 255 } else { 0 }; |
169 | *scores.get_unchecked_mut(index) = IndexAndScore { |
170 | score: if var > VAR_THRESHOLD { score } else { 0f32 }, |
171 | index, |
172 | }; |
173 | } |
174 | if is_flat { |
175 | num_flat += 1; |
176 | } |
177 | } |
178 | } |
179 | |
180 | scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN" )); |
181 | // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it |
182 | unsafe { |
183 | let top_nth_percentile = num_blocks * 90 / 100; |
184 | let score_threshold = scores.get_unchecked(top_nth_percentile).score; |
185 | for score in &scores { |
186 | if score.score >= score_threshold { |
187 | let block_ref = flat_blocks.get_unchecked_mut(score.index); |
188 | if *block_ref == 0 { |
189 | num_flat += 1; |
190 | } |
191 | *block_ref |= 1; |
192 | } |
193 | } |
194 | } |
195 | |
196 | (flat_blocks, num_flat) |
197 | } |
198 | |
199 | fn extract_block( |
200 | &self, |
201 | plane: &Plane<u8>, |
202 | offset_x: usize, |
203 | offset_y: usize, |
204 | plane_result: &mut [f64; BLOCK_SIZE_SQUARED], |
205 | block_result: &mut [f64; BLOCK_SIZE_SQUARED], |
206 | ) { |
207 | let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS]; |
208 | let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS]; |
209 | let plane_origin = plane.data_origin(); |
210 | |
211 | for yi in 0..BLOCK_SIZE { |
212 | let y = clamp(offset_y + yi, 0, plane.cfg.height - 1); |
213 | for xi in 0..BLOCK_SIZE { |
214 | let x = clamp(offset_x + xi, 0, plane.cfg.width - 1); |
215 | // SAFETY: We know the bounds of the plane data and `block_result` |
216 | // and do not exceed them. |
217 | unsafe { |
218 | *block_result.get_unchecked_mut(yi * BLOCK_SIZE + xi) = |
219 | f64::from(*plane_origin.get_unchecked(y * plane.cfg.stride + x)) |
220 | / BLOCK_NORMALIZATION; |
221 | } |
222 | } |
223 | } |
224 | |
225 | multiply_mat( |
226 | block_result, |
227 | &self.a, |
228 | &mut a_t_a_inv_b, |
229 | 1, |
230 | BLOCK_SIZE_SQUARED, |
231 | LOW_POLY_NUM_PARAMS, |
232 | ); |
233 | multiply_mat( |
234 | &self.a_t_a_inv, |
235 | &a_t_a_inv_b, |
236 | &mut plane_coords, |
237 | LOW_POLY_NUM_PARAMS, |
238 | LOW_POLY_NUM_PARAMS, |
239 | 1, |
240 | ); |
241 | multiply_mat( |
242 | &self.a, |
243 | &plane_coords, |
244 | plane_result, |
245 | BLOCK_SIZE_SQUARED, |
246 | LOW_POLY_NUM_PARAMS, |
247 | 1, |
248 | ); |
249 | |
250 | for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) { |
251 | *block_res -= *plane_res; |
252 | } |
253 | } |
254 | } |
255 | |
256 | #[derive (Debug, Clone, Copy, Default)] |
257 | struct IndexAndScore { |
258 | pub index: usize, |
259 | pub score: f32, |
260 | } |
261 | |
262 | /// Wrapper of data required to represent linear system of eqns and soln. |
263 | #[derive (Debug, Clone)] |
264 | struct EquationSystem { |
265 | a: Vec<f64>, |
266 | b: Vec<f64>, |
267 | x: Vec<f64>, |
268 | n: usize, |
269 | } |
270 | |
271 | impl EquationSystem { |
272 | #[must_use ] |
273 | pub fn new(n: usize) -> Self { |
274 | Self { |
275 | a: vec![0.0f64; n * n], |
276 | b: vec![0.0f64; n], |
277 | x: vec![0.0f64; n], |
278 | n, |
279 | } |
280 | } |
281 | |
282 | pub fn solve(&mut self) -> bool { |
283 | let n = self.n; |
284 | let mut a = self.a.clone(); |
285 | let mut b = self.b.clone(); |
286 | |
287 | linsolve(n, &mut a, self.n, &mut b, &mut self.x) |
288 | } |
289 | |
290 | pub fn set_chroma_coefficient_fallback_solution(&mut self) { |
291 | const TOLERANCE: f64 = 1.0e-6f64; |
292 | let last = self.n - 1; |
293 | // Set all of the AR coefficients to zero, but try to solve for correlation |
294 | // with the luma channel |
295 | self.x.fill(0f64); |
296 | if self.a[last * self.n + last].abs() > TOLERANCE { |
297 | self.x[last] = self.b[last] / self.a[last * self.n + last]; |
298 | } |
299 | } |
300 | |
301 | pub fn copy_from(&mut self, other: &Self) { |
302 | assert_eq!(self.n, other.n); |
303 | self.a.copy_from_slice(&other.a); |
304 | self.x.copy_from_slice(&other.x); |
305 | self.b.copy_from_slice(&other.b); |
306 | } |
307 | |
308 | pub fn clear(&mut self) { |
309 | self.a.fill(0f64); |
310 | self.b.fill(0f64); |
311 | self.x.fill(0f64); |
312 | } |
313 | } |
314 | |
315 | impl Add<&EquationSystem> for EquationSystem { |
316 | type Output = EquationSystem; |
317 | |
318 | #[must_use ] |
319 | fn add(self, addend: &EquationSystem) -> Self::Output { |
320 | let mut dest: EquationSystem = self.clone(); |
321 | let n: usize = self.n; |
322 | for i: usize in 0..n { |
323 | for j: usize in 0..n { |
324 | dest.a[i * n + j] += addend.a[i * n + j]; |
325 | } |
326 | dest.b[i] += addend.b[i]; |
327 | } |
328 | dest |
329 | } |
330 | } |
331 | |
332 | impl AddAssign<&EquationSystem> for EquationSystem { |
333 | fn add_assign(&mut self, rhs: &EquationSystem) { |
334 | *self = self.clone() + rhs; |
335 | } |
336 | } |
337 | |
338 | /// Representation of a piecewise linear curve |
339 | /// |
340 | /// Holds n points as (x, y) pairs, that store the curve. |
341 | struct NoiseStrengthLut { |
342 | points: Vec<[f64; 2]>, |
343 | } |
344 | |
345 | impl NoiseStrengthLut { |
346 | #[must_use ] |
347 | pub fn new(num_bins: usize) -> Self { |
348 | assert!(num_bins > 0); |
349 | Self { |
350 | points: vec![[0f64; 2]; num_bins], |
351 | } |
352 | } |
353 | } |
354 | |
355 | #[derive (Debug, Clone)] |
356 | pub(super) struct NoiseModel { |
357 | combined_state: [NoiseModelState; 3], |
358 | latest_state: [NoiseModelState; 3], |
359 | n: usize, |
360 | coords: Vec<[isize; 2]>, |
361 | } |
362 | |
363 | impl NoiseModel { |
364 | #[must_use ] |
365 | pub fn new() -> Self { |
366 | let n = Self::num_coeffs(); |
367 | let combined_state = [ |
368 | NoiseModelState::new(n), |
369 | NoiseModelState::new(n + 1), |
370 | NoiseModelState::new(n + 1), |
371 | ]; |
372 | let latest_state = [ |
373 | NoiseModelState::new(n), |
374 | NoiseModelState::new(n + 1), |
375 | NoiseModelState::new(n + 1), |
376 | ]; |
377 | let mut coords = Vec::new(); |
378 | |
379 | let neg_lag = -(NOISE_MODEL_LAG as isize); |
380 | for y in neg_lag..=0 { |
381 | let max_x = if y == 0 { |
382 | -1isize |
383 | } else { |
384 | NOISE_MODEL_LAG as isize |
385 | }; |
386 | for x in neg_lag..=max_x { |
387 | coords.push([x, y]); |
388 | } |
389 | } |
390 | assert!(n == coords.len()); |
391 | |
392 | Self { |
393 | combined_state, |
394 | latest_state, |
395 | n, |
396 | coords, |
397 | } |
398 | } |
399 | |
400 | pub fn update( |
401 | &mut self, |
402 | source: &Frame<u8>, |
403 | denoised: &Frame<u8>, |
404 | flat_blocks: &[u8], |
405 | ) -> NoiseStatus { |
406 | let num_blocks_w = (source.planes[0].cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE; |
407 | let num_blocks_h = (source.planes[0].cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE; |
408 | let mut y_model_different = false; |
409 | |
410 | // Clear the latest equation system |
411 | for i in 0..3 { |
412 | self.latest_state[i].eqns.clear(); |
413 | self.latest_state[i].num_observations = 0; |
414 | self.latest_state[i].strength_solver.clear(); |
415 | } |
416 | |
417 | // Check that we have enough flat blocks |
418 | let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count(); |
419 | if num_blocks <= 1 { |
420 | return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate" )); |
421 | } |
422 | |
423 | let frame_dims = (source.planes[0].cfg.width, source.planes[0].cfg.height); |
424 | for channel in 0..3 { |
425 | if source.planes[channel].data.is_empty() { |
426 | // Monochrome source |
427 | break; |
428 | } |
429 | let is_chroma = channel > 0; |
430 | let alt_source = (channel > 0).then(|| &source.planes[0]); |
431 | let alt_denoised = (channel > 0).then(|| &denoised.planes[0]); |
432 | self.add_block_observations( |
433 | channel, |
434 | &source.planes[channel], |
435 | &denoised.planes[channel], |
436 | alt_source, |
437 | alt_denoised, |
438 | frame_dims, |
439 | flat_blocks, |
440 | num_blocks_w, |
441 | num_blocks_h, |
442 | ); |
443 | if !self.latest_state[channel].ar_equation_system_solve(is_chroma) { |
444 | if is_chroma { |
445 | self.latest_state[channel] |
446 | .eqns |
447 | .set_chroma_coefficient_fallback_solution(); |
448 | } else { |
449 | return NoiseStatus::Error(anyhow!( |
450 | "Solving latest noise equation system failed on plane {}" , |
451 | channel |
452 | )); |
453 | } |
454 | } |
455 | self.add_noise_std_observations( |
456 | channel, |
457 | &source.planes[channel], |
458 | &denoised.planes[channel], |
459 | alt_source, |
460 | frame_dims, |
461 | flat_blocks, |
462 | num_blocks_w, |
463 | num_blocks_h, |
464 | ); |
465 | if !self.latest_state[channel].strength_solver.solve() { |
466 | return NoiseStatus::Error(anyhow!( |
467 | "Failed to solve strength solver for latest state" |
468 | )); |
469 | } |
470 | |
471 | // Check noise characteristics and return if error |
472 | if channel == 0 |
473 | && self.combined_state[channel].strength_solver.num_equations > 0 |
474 | && self.is_different() |
475 | { |
476 | y_model_different = true; |
477 | } |
478 | |
479 | if y_model_different { |
480 | continue; |
481 | } |
482 | |
483 | self.combined_state[channel].num_observations += |
484 | self.latest_state[channel].num_observations; |
485 | self.combined_state[channel].eqns += &self.latest_state[channel].eqns; |
486 | if !self.combined_state[channel].ar_equation_system_solve(is_chroma) { |
487 | if is_chroma { |
488 | self.combined_state[channel] |
489 | .eqns |
490 | .set_chroma_coefficient_fallback_solution(); |
491 | } else { |
492 | return NoiseStatus::Error(anyhow!( |
493 | "Solving combined noise equation system failed on plane {}" , |
494 | channel |
495 | )); |
496 | } |
497 | } |
498 | |
499 | self.combined_state[channel].strength_solver += |
500 | &self.latest_state[channel].strength_solver; |
501 | |
502 | if !self.combined_state[channel].strength_solver.solve() { |
503 | return NoiseStatus::Error(anyhow!( |
504 | "Failed to solve strength solver for combined state" |
505 | )); |
506 | }; |
507 | } |
508 | |
509 | if y_model_different { |
510 | return NoiseStatus::DifferentType; |
511 | } |
512 | |
513 | NoiseStatus::Ok |
514 | } |
515 | |
516 | #[allow (clippy::too_many_lines)] |
517 | #[must_use ] |
518 | pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment { |
519 | // Both the domain and the range of the scaling functions in the film_grain |
520 | // are normalized to 8-bit (e.g., they are implicitly scaled during grain |
521 | // synthesis). |
522 | let scaling_points_y = self.combined_state[0] |
523 | .strength_solver |
524 | .fit_piecewise(NUM_Y_POINTS) |
525 | .points; |
526 | let scaling_points_cb = self.combined_state[1] |
527 | .strength_solver |
528 | .fit_piecewise(NUM_UV_POINTS) |
529 | .points; |
530 | let scaling_points_cr = self.combined_state[2] |
531 | .strength_solver |
532 | .fit_piecewise(NUM_UV_POINTS) |
533 | .points; |
534 | |
535 | let mut max_scaling_value: f64 = 1.0e-4f64; |
536 | for p in scaling_points_y |
537 | .iter() |
538 | .chain(scaling_points_cb.iter()) |
539 | .chain(scaling_points_cr.iter()) |
540 | .map(|p| p[1]) |
541 | { |
542 | if p > max_scaling_value { |
543 | max_scaling_value = p; |
544 | } |
545 | } |
546 | |
547 | // Scaling_shift values are in the range [8,11] |
548 | let max_scaling_value_log2 = |
549 | clamp((max_scaling_value.log2() + 1f64).floor() as u8, 2u8, 5u8); |
550 | let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2)); |
551 | let map_scaling_point = |p: [f64; 2]| { |
552 | [ |
553 | (p[0] + 0.5f64) as u8, |
554 | clamp(scale_factor.mul_add(p[1], 0.5f64) as i32, 0i32, 255i32) as u8, |
555 | ] |
556 | }; |
557 | |
558 | let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y |
559 | .into_iter() |
560 | .map(map_scaling_point) |
561 | .collect(); |
562 | let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb |
563 | .into_iter() |
564 | .map(map_scaling_point) |
565 | .collect(); |
566 | let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr |
567 | .into_iter() |
568 | .map(map_scaling_point) |
569 | .collect(); |
570 | |
571 | // Convert the ar_coeffs into 8-bit values |
572 | let n_coeff = self.combined_state[0].eqns.n; |
573 | let mut max_coeff = 1.0e-4f64; |
574 | let mut min_coeff = 1.0e-4f64; |
575 | let mut y_corr = [0f64; 2]; |
576 | let mut avg_luma_strength = 0f64; |
577 | for c in 0..3 { |
578 | let eqns = &self.combined_state[c].eqns; |
579 | for i in 0..n_coeff { |
580 | if eqns.x[i] > max_coeff { |
581 | max_coeff = eqns.x[i]; |
582 | } |
583 | if eqns.x[i] < min_coeff { |
584 | min_coeff = eqns.x[i]; |
585 | } |
586 | } |
587 | |
588 | // Since the correlation between luma/chroma was computed in an already |
589 | // scaled space, we adjust it in the un-scaled space. |
590 | let solver = &self.combined_state[c].strength_solver; |
591 | // Compute a weighted average of the strength for the channel. |
592 | let mut average_strength = 0f64; |
593 | let mut total_weight = 0f64; |
594 | for i in 0..solver.eqns.n { |
595 | let mut w = 0f64; |
596 | for j in 0..solver.eqns.n { |
597 | w += solver.eqns.a[i * solver.eqns.n + j]; |
598 | } |
599 | w = w.sqrt(); |
600 | average_strength += solver.eqns.x[i] * w; |
601 | total_weight += w; |
602 | } |
603 | if total_weight.abs() < f64::EPSILON { |
604 | average_strength = 1f64; |
605 | } else { |
606 | average_strength /= total_weight; |
607 | } |
608 | if c == 0 { |
609 | avg_luma_strength = average_strength; |
610 | } else { |
611 | y_corr[c - 1] = avg_luma_strength * eqns.x[n_coeff] / average_strength; |
612 | max_coeff = max_coeff.max(y_corr[c - 1]); |
613 | min_coeff = min_coeff.min(y_corr[c - 1]); |
614 | } |
615 | } |
616 | |
617 | // Shift value: AR coeffs range (values 6-9) |
618 | // 6: [-2, 2), 7: [-1, 1), 8: [-0.5, 0.5), 9: [-0.25, 0.25) |
619 | let ar_coeff_shift = clamp( |
620 | 7i32 - (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32, |
621 | 6i32, |
622 | 9i32, |
623 | ) as u8; |
624 | let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift); |
625 | let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff); |
626 | let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr); |
627 | let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr); |
628 | |
629 | GrainTableSegment { |
630 | random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 }, |
631 | start_time: start_ts, |
632 | end_time: end_ts, |
633 | ar_coeff_lag: NOISE_MODEL_LAG as u8, |
634 | scaling_points_y, |
635 | scaling_points_cb, |
636 | scaling_points_cr, |
637 | scaling_shift: 5 + (8 - max_scaling_value_log2), |
638 | ar_coeff_shift, |
639 | ar_coeffs_y, |
640 | ar_coeffs_cb, |
641 | ar_coeffs_cr, |
642 | // At the moment, the noise modeling code assumes that the chroma scaling |
643 | // functions are a function of luma. |
644 | cb_mult: 128, |
645 | cb_luma_mult: 192, |
646 | cb_offset: 256, |
647 | cr_mult: 128, |
648 | cr_luma_mult: 192, |
649 | cr_offset: 256, |
650 | chroma_scaling_from_luma: false, |
651 | grain_scale_shift: 0, |
652 | overlap_flag: true, |
653 | } |
654 | } |
655 | |
656 | pub fn save_latest(&mut self) { |
657 | for c in 0..3 { |
658 | let latest_state = &self.latest_state[c]; |
659 | let combined_state = &mut self.combined_state[c]; |
660 | combined_state.eqns.copy_from(&latest_state.eqns); |
661 | combined_state |
662 | .strength_solver |
663 | .eqns |
664 | .copy_from(&latest_state.strength_solver.eqns); |
665 | combined_state.strength_solver.num_equations = |
666 | latest_state.strength_solver.num_equations; |
667 | combined_state.num_observations = latest_state.num_observations; |
668 | combined_state.ar_gain = latest_state.ar_gain; |
669 | } |
670 | } |
671 | |
672 | #[must_use ] |
673 | const fn num_coeffs() -> usize { |
674 | let n = 2 * NOISE_MODEL_LAG + 1; |
675 | (n * n) / 2 |
676 | } |
677 | |
678 | #[must_use ] |
679 | fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> { |
680 | assert!(n_coeff <= NUM_Y_COEFFS); |
681 | let mut coeffs = ArrayVec::new(); |
682 | let eqns = &self.combined_state[0].eqns; |
683 | for i in 0..n_coeff { |
684 | coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8); |
685 | } |
686 | coeffs |
687 | } |
688 | |
689 | #[must_use ] |
690 | fn get_ar_coeffs_uv( |
691 | &self, |
692 | channel: usize, |
693 | n_coeff: usize, |
694 | scale_ar_coeff: f64, |
695 | y_corr: [f64; 2], |
696 | ) -> ArrayVec<i8, NUM_UV_COEFFS> { |
697 | assert!(n_coeff <= NUM_Y_COEFFS); |
698 | let mut coeffs = ArrayVec::new(); |
699 | let eqns = &self.combined_state[channel].eqns; |
700 | for i in 0..n_coeff { |
701 | coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8); |
702 | } |
703 | coeffs.push(clamp( |
704 | (scale_ar_coeff * y_corr[channel - 1]).round() as i32, |
705 | -128i32, |
706 | 127i32, |
707 | ) as i8); |
708 | coeffs |
709 | } |
710 | |
711 | // Return true if the noise estimate appears to be different from the combined |
712 | // (multi-frame) estimate. The difference is measured by checking whether the |
713 | // AR coefficients have diverged (using a threshold on normalized cross |
714 | // correlation), or whether the noise strength has changed. |
715 | #[must_use ] |
716 | fn is_different(&self) -> bool { |
717 | const COEFF_THRESHOLD: f64 = 0.9f64; |
718 | const STRENGTH_THRESHOLD: f64 = 0.005f64; |
719 | |
720 | let latest = &self.latest_state[0]; |
721 | let combined = &self.combined_state[0]; |
722 | let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n); |
723 | if corr < COEFF_THRESHOLD { |
724 | return true; |
725 | } |
726 | |
727 | let dx = 1.0f64 / latest.strength_solver.num_bins as f64; |
728 | let latest_eqns = &latest.strength_solver.eqns; |
729 | let combined_eqns = &combined.strength_solver.eqns; |
730 | let mut diff = 0.0f64; |
731 | let mut total_weight = 0.0f64; |
732 | for j in 0..latest_eqns.n { |
733 | let mut weight = 0.0f64; |
734 | for i in 0..latest_eqns.n { |
735 | weight += latest_eqns.a[i * latest_eqns.n + j]; |
736 | } |
737 | weight = weight.sqrt(); |
738 | diff += weight * (latest_eqns.x[j] - combined_eqns.x[j]).abs(); |
739 | total_weight += weight; |
740 | } |
741 | |
742 | diff * dx / total_weight > STRENGTH_THRESHOLD |
743 | } |
744 | |
745 | #[allow (clippy::too_many_arguments)] |
746 | fn add_block_observations( |
747 | &mut self, |
748 | channel: usize, |
749 | source: &Plane<u8>, |
750 | denoised: &Plane<u8>, |
751 | alt_source: Option<&Plane<u8>>, |
752 | alt_denoised: Option<&Plane<u8>>, |
753 | frame_dims: (usize, usize), |
754 | flat_blocks: &[u8], |
755 | num_blocks_w: usize, |
756 | num_blocks_h: usize, |
757 | ) { |
758 | let num_coords = self.n; |
759 | let state = &mut self.latest_state[channel]; |
760 | let a = &mut state.eqns.a; |
761 | let b = &mut state.eqns.b; |
762 | let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice(); |
763 | let n = state.eqns.n; |
764 | let block_w = BLOCK_SIZE >> source.cfg.xdec; |
765 | let block_h = BLOCK_SIZE >> source.cfg.ydec; |
766 | |
767 | let dec = (source.cfg.xdec, source.cfg.ydec); |
768 | let stride = source.cfg.stride; |
769 | let source_origin = source.data_origin(); |
770 | let denoised_origin = denoised.data_origin(); |
771 | let alt_stride = alt_source.map_or(0, |s| s.cfg.stride); |
772 | let alt_source_origin = alt_source.map(|s| s.data_origin()); |
773 | let alt_denoised_origin = alt_denoised.map(|s| s.data_origin()); |
774 | |
775 | for by in 0..num_blocks_h { |
776 | let y_o = by * block_h; |
777 | for bx in 0..num_blocks_w { |
778 | // SAFETY: We know the indexes we provide do not overflow the data bounds |
779 | unsafe { |
780 | let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx); |
781 | let x_o = bx * block_w; |
782 | if *flat_block_ptr == 0 { |
783 | continue; |
784 | } |
785 | let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 { |
786 | 0 |
787 | } else { |
788 | NOISE_MODEL_LAG |
789 | }; |
790 | let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 { |
791 | 0 |
792 | } else { |
793 | NOISE_MODEL_LAG |
794 | }; |
795 | let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h); |
796 | let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min( |
797 | if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 { |
798 | block_w |
799 | } else { |
800 | block_w - NOISE_MODEL_LAG |
801 | }, |
802 | ); |
803 | for y in y_start..y_end { |
804 | for x in x_start..x_end { |
805 | let val = extract_ar_row( |
806 | &self.coords, |
807 | num_coords, |
808 | source_origin, |
809 | denoised_origin, |
810 | stride, |
811 | dec, |
812 | alt_source_origin, |
813 | alt_denoised_origin, |
814 | alt_stride, |
815 | x + x_o, |
816 | y + y_o, |
817 | &mut buffer, |
818 | ); |
819 | for i in 0..n { |
820 | for j in 0..n { |
821 | *a.get_unchecked_mut(i * n + j) += (*buffer.get_unchecked(i) |
822 | * *buffer.get_unchecked(j)) |
823 | / BLOCK_NORMALIZATION.powi(2); |
824 | } |
825 | *b.get_unchecked_mut(i) += |
826 | (*buffer.get_unchecked(i) * val) / BLOCK_NORMALIZATION.powi(2); |
827 | } |
828 | state.num_observations += 1; |
829 | } |
830 | } |
831 | } |
832 | } |
833 | } |
834 | } |
835 | |
836 | #[allow (clippy::too_many_arguments)] |
837 | fn add_noise_std_observations( |
838 | &mut self, |
839 | channel: usize, |
840 | source: &Plane<u8>, |
841 | denoised: &Plane<u8>, |
842 | alt_source: Option<&Plane<u8>>, |
843 | frame_dims: (usize, usize), |
844 | flat_blocks: &[u8], |
845 | num_blocks_w: usize, |
846 | num_blocks_h: usize, |
847 | ) { |
848 | let coeffs = &self.latest_state[channel].eqns.x; |
849 | let num_coords = self.n; |
850 | let luma_gain = self.latest_state[0].ar_gain; |
851 | let noise_gain = self.latest_state[channel].ar_gain; |
852 | let block_w = BLOCK_SIZE >> source.cfg.xdec; |
853 | let block_h = BLOCK_SIZE >> source.cfg.ydec; |
854 | |
855 | for by in 0..num_blocks_h { |
856 | let y_o = by * block_h; |
857 | for bx in 0..num_blocks_w { |
858 | let x_o = bx * block_w; |
859 | if flat_blocks[by * num_blocks_w + bx] == 0 { |
860 | continue; |
861 | } |
862 | let num_samples_h = ((frame_dims.1 >> source.cfg.ydec) - by * block_h).min(block_h); |
863 | let num_samples_w = ((frame_dims.0 >> source.cfg.xdec) - bx * block_w).min(block_w); |
864 | // Make sure that we have a reasonable amount of samples to consider the |
865 | // block |
866 | if num_samples_w * num_samples_h > BLOCK_SIZE { |
867 | let block_mean = get_block_mean( |
868 | alt_source.unwrap_or(source), |
869 | frame_dims, |
870 | x_o << source.cfg.xdec, |
871 | y_o << source.cfg.ydec, |
872 | ); |
873 | let noise_var = get_noise_var( |
874 | source, |
875 | denoised, |
876 | ( |
877 | frame_dims.0 >> source.cfg.xdec, |
878 | frame_dims.1 >> source.cfg.ydec, |
879 | ), |
880 | x_o, |
881 | y_o, |
882 | block_w, |
883 | block_h, |
884 | ); |
885 | // We want to remove the part of the noise that came from being |
886 | // correlated with luma. Note that the noise solver for luma must |
887 | // have already been run. |
888 | let luma_strength = if channel > 0 { |
889 | luma_gain * self.latest_state[0].strength_solver.get_value(block_mean) |
890 | } else { |
891 | 0f64 |
892 | }; |
893 | let corr = if channel > 0 { |
894 | coeffs[num_coords] |
895 | } else { |
896 | 0f64 |
897 | }; |
898 | // Chroma noise: |
899 | // N(0, noise_var) = N(0, uncorr_var) + corr * N(0, luma_strength^2) |
900 | // The uncorrelated component: |
901 | // uncorr_var = noise_var - (corr * luma_strength)^2 |
902 | // But don't allow fully correlated noise (hence the max), since the |
903 | // synthesis cannot model it. |
904 | let uncorr_std = (noise_var / 16f64) |
905 | .max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var)) |
906 | .sqrt(); |
907 | let adjusted_strength = uncorr_std / noise_gain; |
908 | self.latest_state[channel] |
909 | .strength_solver |
910 | .add_measurement(block_mean, adjusted_strength); |
911 | } |
912 | } |
913 | } |
914 | } |
915 | } |
916 | |
917 | #[derive (Debug, Clone)] |
918 | struct NoiseModelState { |
919 | eqns: EquationSystem, |
920 | ar_gain: f64, |
921 | num_observations: usize, |
922 | strength_solver: StrengthSolver, |
923 | } |
924 | |
925 | impl NoiseModelState { |
926 | #[must_use ] |
927 | pub fn new(n: usize) -> Self { |
928 | const NUM_BINS: usize = 20; |
929 | |
930 | Self { |
931 | eqns: EquationSystem::new(n), |
932 | ar_gain: 1.0f64, |
933 | num_observations: 0usize, |
934 | strength_solver: StrengthSolver::new(NUM_BINS), |
935 | } |
936 | } |
937 | |
938 | pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool { |
939 | let ret = self.eqns.solve(); |
940 | self.ar_gain = 1.0f64; |
941 | if !ret { |
942 | return ret; |
943 | } |
944 | |
945 | // Update the AR gain from the equation system as it will be used to fit |
946 | // the noise strength as a function of intensity. In the Yule-Walker |
947 | // equations, the diagonal should be the variance of the correlated noise. |
948 | // In the case of the least squares estimate, there will be some variability |
949 | // in the diagonal. So use the mean of the diagonal as the estimate of |
950 | // overall variance (this works for least squares or Yule-Walker formulation). |
951 | let mut var = 0f64; |
952 | let n_adjusted = self.eqns.n - usize::from(is_chroma); |
953 | for i in 0..n_adjusted { |
954 | var += self.eqns.a[i * self.eqns.n + i] / self.num_observations as f64; |
955 | } |
956 | var /= n_adjusted as f64; |
957 | |
958 | // Keep track of E(Y^2) = <b, x> + E(X^2) |
959 | // In the case that we are using chroma and have an estimate of correlation |
960 | // with luma we adjust that estimate slightly to remove the correlated bits by |
961 | // subtracting out the last column of a scaled by our correlation estimate |
962 | // from b. E(y^2) = <b - A(:, end)*x(end), x> |
963 | let mut sum_covar = 0f64; |
964 | for i in 0..n_adjusted { |
965 | let mut bi = self.eqns.b[i]; |
966 | if is_chroma { |
967 | bi -= self.eqns.a[i * self.eqns.n + n_adjusted] * self.eqns.x[n_adjusted]; |
968 | } |
969 | sum_covar += (bi * self.eqns.x[i]) / self.num_observations as f64; |
970 | } |
971 | |
972 | // Now, get an estimate of the variance of uncorrelated noise signal and use |
973 | // it to determine the gain of the AR filter. |
974 | let noise_var = (var - sum_covar).max(1e-6f64); |
975 | self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt()); |
976 | ret |
977 | } |
978 | } |
979 | |
980 | #[derive (Debug, Clone)] |
981 | struct StrengthSolver { |
982 | eqns: EquationSystem, |
983 | num_bins: usize, |
984 | num_equations: usize, |
985 | total: f64, |
986 | } |
987 | |
988 | impl StrengthSolver { |
989 | #[must_use ] |
990 | pub fn new(num_bins: usize) -> Self { |
991 | Self { |
992 | eqns: EquationSystem::new(num_bins), |
993 | num_bins, |
994 | num_equations: 0usize, |
995 | total: 0f64, |
996 | } |
997 | } |
998 | |
999 | pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) { |
1000 | let bin = self.get_bin_index(block_mean); |
1001 | let bin_i0 = bin.floor() as usize; |
1002 | let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1); |
1003 | let a = bin - bin_i0 as f64; |
1004 | let n = self.num_bins; |
1005 | let eqns = &mut self.eqns; |
1006 | eqns.a[bin_i0 * n + bin_i0] += (1f64 - a).powi(2); |
1007 | eqns.a[bin_i1 * n + bin_i0] += a * (1f64 - a); |
1008 | eqns.a[bin_i1 * n + bin_i1] += a.powi(2); |
1009 | eqns.a[bin_i0 * n + bin_i1] += (1f64 - a) * a; |
1010 | eqns.b[bin_i0] += (1f64 - a) * noise_std; |
1011 | eqns.b[bin_i1] += a * noise_std; |
1012 | self.total += noise_std; |
1013 | self.num_equations += 1; |
1014 | } |
1015 | |
1016 | pub fn solve(&mut self) -> bool { |
1017 | // Add regularization proportional to the number of constraints |
1018 | let n = self.num_bins; |
1019 | let alpha = 2f64 * self.num_equations as f64 / n as f64; |
1020 | |
1021 | // Do this in a non-destructive manner so it is not confusing to the caller |
1022 | let old_a = self.eqns.a.clone(); |
1023 | for i in 0..n { |
1024 | let i_lo = i.saturating_sub(1); |
1025 | let i_hi = (n - 1).min(i + 1); |
1026 | self.eqns.a[i * n + i_lo] -= alpha; |
1027 | self.eqns.a[i * n + i] += 2f64 * alpha; |
1028 | self.eqns.a[i * n + i_hi] -= alpha; |
1029 | } |
1030 | |
1031 | // Small regularization to give average noise strength |
1032 | let mean = self.total / self.num_equations as f64; |
1033 | for i in 0..n { |
1034 | self.eqns.a[i * n + i] += 1f64 / 8192f64; |
1035 | self.eqns.b[i] += mean / 8192f64; |
1036 | } |
1037 | let result = self.eqns.solve(); |
1038 | self.eqns.a = old_a; |
1039 | result |
1040 | } |
1041 | |
1042 | #[must_use ] |
1043 | pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut { |
1044 | const TOLERANCE: f64 = 0.00625f64; |
1045 | |
1046 | let mut lut = NoiseStrengthLut::new(self.num_bins); |
1047 | for i in 0..self.num_bins { |
1048 | lut.points[i][0] = self.get_center(i); |
1049 | lut.points[i][1] = self.eqns.x[i]; |
1050 | } |
1051 | |
1052 | let mut residual = vec![0.0f64; self.num_bins]; |
1053 | self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins); |
1054 | |
1055 | // Greedily remove points if there are too many or if it doesn't hurt local |
1056 | // approximation (never remove the end points) |
1057 | while lut.points.len() > 2 { |
1058 | let mut min_index = 1usize; |
1059 | for j in 1..(lut.points.len() - 1) { |
1060 | if residual[j] < residual[min_index] { |
1061 | min_index = j; |
1062 | } |
1063 | } |
1064 | let dx = lut.points[min_index + 1][0] - lut.points[min_index - 1][0]; |
1065 | let avg_residual = residual[min_index] / dx; |
1066 | if lut.points.len() <= max_output_points && avg_residual > TOLERANCE { |
1067 | break; |
1068 | } |
1069 | |
1070 | lut.points.remove(min_index); |
1071 | self.update_piecewise_linear_residual( |
1072 | &lut, |
1073 | &mut residual, |
1074 | min_index - 1, |
1075 | min_index + 1, |
1076 | ); |
1077 | } |
1078 | |
1079 | lut |
1080 | } |
1081 | |
1082 | #[must_use ] |
1083 | pub fn get_value(&self, x: f64) -> f64 { |
1084 | let bin = self.get_bin_index(x); |
1085 | let bin_i0 = bin.floor() as usize; |
1086 | let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1); |
1087 | let a = bin - bin_i0 as f64; |
1088 | (1f64 - a).mul_add(self.eqns.x[bin_i0], a * self.eqns.x[bin_i1]) |
1089 | } |
1090 | |
1091 | pub fn clear(&mut self) { |
1092 | self.eqns.clear(); |
1093 | self.num_equations = 0; |
1094 | self.total = 0f64; |
1095 | } |
1096 | |
1097 | #[must_use ] |
1098 | fn get_bin_index(&self, value: f64) -> f64 { |
1099 | let max = 255f64; |
1100 | let val = clamp(value, 0f64, max); |
1101 | (self.num_bins - 1) as f64 * val / max |
1102 | } |
1103 | |
1104 | fn update_piecewise_linear_residual( |
1105 | &self, |
1106 | lut: &NoiseStrengthLut, |
1107 | residual: &mut [f64], |
1108 | start: usize, |
1109 | end: usize, |
1110 | ) { |
1111 | let dx = 255f64 / self.num_bins as f64; |
1112 | #[allow (clippy::needless_range_loop)] |
1113 | for i in start.max(1)..end.min(lut.points.len() - 1) { |
1114 | let lower = 0usize.max(self.get_bin_index(lut.points[i - 1][0]).floor() as usize); |
1115 | let upper = |
1116 | (self.num_bins - 1).min(self.get_bin_index(lut.points[i + 1][0]).ceil() as usize); |
1117 | let mut r = 0f64; |
1118 | for j in lower..=upper { |
1119 | let x = self.get_center(j); |
1120 | if x < lut.points[i - 1][0] || x >= lut.points[i + 1][0] { |
1121 | continue; |
1122 | } |
1123 | |
1124 | let y = self.eqns.x[j]; |
1125 | let a = (x - lut.points[i - 1][0]) / (lut.points[i + 1][0] - lut.points[i - 1][0]); |
1126 | let estimate_y = lut.points[i - 1][1].mul_add(1f64 - a, lut.points[i + 1][1] * a); |
1127 | r += (y - estimate_y).abs(); |
1128 | } |
1129 | residual[i] = r * dx; |
1130 | } |
1131 | } |
1132 | |
1133 | #[must_use ] |
1134 | fn get_center(&self, i: usize) -> f64 { |
1135 | let range = 255f64; |
1136 | let n = self.num_bins; |
1137 | i as f64 / (n - 1) as f64 * range |
1138 | } |
1139 | } |
1140 | |
1141 | impl Add<&StrengthSolver> for StrengthSolver { |
1142 | type Output = StrengthSolver; |
1143 | |
1144 | #[must_use ] |
1145 | fn add(self, addend: &StrengthSolver) -> Self::Output { |
1146 | let mut dest: StrengthSolver = self; |
1147 | dest.eqns += &addend.eqns; |
1148 | dest.num_equations += addend.num_equations; |
1149 | dest.total += addend.total; |
1150 | dest |
1151 | } |
1152 | } |
1153 | |
1154 | impl AddAssign<&StrengthSolver> for StrengthSolver { |
1155 | fn add_assign(&mut self, rhs: &StrengthSolver) { |
1156 | *self = self.clone() + rhs; |
1157 | } |
1158 | } |
1159 | |