1mod util;
2
3use std::ops::{Add, AddAssign};
4
5use anyhow::anyhow;
6use arrayvec::ArrayVec;
7use v_frame::{frame::Frame, math::clamp, plane::Plane};
8
9use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat};
10use super::{NoiseStatus, BLOCK_SIZE, BLOCK_SIZE_SQUARED};
11use crate::{
12 diff::solver::util::normalized_cross_correlation, GrainTableSegment, DEFAULT_GRAIN_SEED,
13 NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS, NUM_Y_POINTS,
14};
15
/// Number of parameters of the planar model fitted to every block: one coefficient for the
/// vertical coordinate, one for the horizontal coordinate and one constant term.
const LOW_POLY_NUM_PARAMS: usize = 3;
/// Largest neighbour offset, in samples, used by the auto-regressive noise model.
const NOISE_MODEL_LAG: usize = 3;
/// Divisor applied to 8-bit sample values before they enter the solvers.
const BLOCK_NORMALIZATION: f64 = 255.0;
19
20#[derive(Debug, Clone, Copy)]
21pub(super) struct FlatBlockFinder {
22 a: [f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED],
23 a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS],
24}
25
26impl FlatBlockFinder {
27 #[must_use]
28 pub fn new() -> Self {
29 let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS);
30 let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS];
31 let mut a = [0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED];
32
33 let bs_half = (BLOCK_SIZE / 2) as f64;
34 (0..BLOCK_SIZE).for_each(|y| {
35 let yd = (y as f64 - bs_half) / bs_half;
36 (0..BLOCK_SIZE).for_each(|x| {
37 let xd = (x as f64 - bs_half) / bs_half;
38 let coords = [yd, xd, 1.0f64];
39 let row = y * BLOCK_SIZE + x;
40 a[LOW_POLY_NUM_PARAMS * row] = yd;
41 a[LOW_POLY_NUM_PARAMS * row + 1] = xd;
42 a[LOW_POLY_NUM_PARAMS * row + 2] = 1.0f64;
43
44 (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
45 (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
46 eqns.a[LOW_POLY_NUM_PARAMS * i + j] += coords[i] * coords[j];
47 });
48 });
49 });
50 });
51
52 // Lazy inverse using existing equation solver.
53 (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
54 eqns.b.fill(0.0f64);
55 eqns.b[i] = 1.0f64;
56 eqns.solve();
57
58 (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
59 a_t_a_inv[j * LOW_POLY_NUM_PARAMS + i] = eqns.x[j];
60 });
61 });
62
63 FlatBlockFinder { a, a_t_a_inv }
64 }
65
66 // The gradient-based features used in this code are based on:
67 // A. Kokaram, D. Kelly, H. Denman and A. Crawford, "Measuring noise
68 // correlation for improved video denoising," 2012 19th, ICIP.
69 // The thresholds are more lenient to allow for correct grain modeling
70 // in extreme cases.
71 #[must_use]
72 #[allow(clippy::too_many_lines)]
73 pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) {
74 const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64;
75 const RATIO_THRESHOLD: f64 = 1.25f64;
76 const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64;
77 const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64;
78
79 // The following weights are used to combine the above features to give
80 // a sigmoid score for flatness. If the input was normalized to [0,100]
81 // the magnitude of these values would be close to 1 (e.g., weights
82 // corresponding to variance would be a factor of 10000x smaller).
83 const VAR_WEIGHT: f64 = -6682f64;
84 const RATIO_WEIGHT: f64 = -0.2056f64;
85 const TRACE_WEIGHT: f64 = 13087f64;
86 const NORM_WEIGHT: f64 = -12434f64;
87 const OFFSET: f64 = 2.5694f64;
88
89 let num_blocks_w = (plane.cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
90 let num_blocks_h = (plane.cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
91 let num_blocks = num_blocks_w * num_blocks_h;
92 let mut flat_blocks = vec![0u8; num_blocks];
93 let mut num_flat = 0;
94 let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED];
95 let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED];
96 let mut scores = vec![IndexAndScore::default(); num_blocks];
97
98 for by in 0..num_blocks_h {
99 for bx in 0..num_blocks_w {
100 // Compute gradient covariance matrix.
101 let mut gxx = 0f64;
102 let mut gxy = 0f64;
103 let mut gyy = 0f64;
104 let mut var = 0f64;
105 let mut mean = 0f64;
106
107 self.extract_block(
108 plane,
109 bx * BLOCK_SIZE,
110 by * BLOCK_SIZE,
111 &mut plane_result,
112 &mut block_result,
113 );
114 for yi in 1..(BLOCK_SIZE - 1) {
115 for xi in 1..(BLOCK_SIZE - 1) {
116 // SAFETY: We know the size of `block_result` and that we cannot exceed the bounds of it
117 unsafe {
118 let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi);
119
120 let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64;
121 let gy =
122 (*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64;
123 gxx += gx * gx;
124 gxy += gx * gy;
125 gyy += gy * gy;
126
127 let block_val = *result_ptr;
128 mean += block_val;
129 var += block_val * block_val;
130 }
131 }
132 }
133 let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64;
134 mean /= block_size_norm_factor;
135
136 // Normalize gradients by block_size.
137 gxx /= block_size_norm_factor;
138 gxy /= block_size_norm_factor;
139 gyy /= block_size_norm_factor;
140 var = mean.mul_add(-mean, var / block_size_norm_factor);
141
142 let trace = gxx + gyy;
143 let det = gxx.mul_add(gyy, -gxy.powi(2));
144 let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt();
145 let e1 = (trace + e_sub) / 2.0f64;
146 let e2 = (trace - e_sub) / 2.0f64;
147 // Spectral norm
148 let norm = e1;
149 let ratio = e1 / e2.max(1.0e-6_f64);
150 let is_flat = trace < TRACE_THRESHOLD
151 && ratio < RATIO_THRESHOLD
152 && norm < NORM_THRESHOLD
153 && var > VAR_THRESHOLD;
154
155 let sum_weights = NORM_WEIGHT.mul_add(
156 norm,
157 TRACE_WEIGHT.mul_add(
158 trace,
159 VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)),
160 ),
161 );
162 // clamp the value to [-25.0, 100.0] to prevent overflow
163 let sum_weights = clamp(sum_weights, -25.0f64, 100.0f64);
164 let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32;
165 // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it
166 unsafe {
167 let index = by * num_blocks_w + bx;
168 *flat_blocks.get_unchecked_mut(index) = if is_flat { 255 } else { 0 };
169 *scores.get_unchecked_mut(index) = IndexAndScore {
170 score: if var > VAR_THRESHOLD { score } else { 0f32 },
171 index,
172 };
173 }
174 if is_flat {
175 num_flat += 1;
176 }
177 }
178 }
179
180 scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN"));
181 // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it
182 unsafe {
183 let top_nth_percentile = num_blocks * 90 / 100;
184 let score_threshold = scores.get_unchecked(top_nth_percentile).score;
185 for score in &scores {
186 if score.score >= score_threshold {
187 let block_ref = flat_blocks.get_unchecked_mut(score.index);
188 if *block_ref == 0 {
189 num_flat += 1;
190 }
191 *block_ref |= 1;
192 }
193 }
194 }
195
196 (flat_blocks, num_flat)
197 }
198
199 fn extract_block(
200 &self,
201 plane: &Plane<u8>,
202 offset_x: usize,
203 offset_y: usize,
204 plane_result: &mut [f64; BLOCK_SIZE_SQUARED],
205 block_result: &mut [f64; BLOCK_SIZE_SQUARED],
206 ) {
207 let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS];
208 let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS];
209 let plane_origin = plane.data_origin();
210
211 for yi in 0..BLOCK_SIZE {
212 let y = clamp(offset_y + yi, 0, plane.cfg.height - 1);
213 for xi in 0..BLOCK_SIZE {
214 let x = clamp(offset_x + xi, 0, plane.cfg.width - 1);
215 // SAFETY: We know the bounds of the plane data and `block_result`
216 // and do not exceed them.
217 unsafe {
218 *block_result.get_unchecked_mut(yi * BLOCK_SIZE + xi) =
219 f64::from(*plane_origin.get_unchecked(y * plane.cfg.stride + x))
220 / BLOCK_NORMALIZATION;
221 }
222 }
223 }
224
225 multiply_mat(
226 block_result,
227 &self.a,
228 &mut a_t_a_inv_b,
229 1,
230 BLOCK_SIZE_SQUARED,
231 LOW_POLY_NUM_PARAMS,
232 );
233 multiply_mat(
234 &self.a_t_a_inv,
235 &a_t_a_inv_b,
236 &mut plane_coords,
237 LOW_POLY_NUM_PARAMS,
238 LOW_POLY_NUM_PARAMS,
239 1,
240 );
241 multiply_mat(
242 &self.a,
243 &plane_coords,
244 plane_result,
245 BLOCK_SIZE_SQUARED,
246 LOW_POLY_NUM_PARAMS,
247 1,
248 );
249
250 for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) {
251 *block_res -= *plane_res;
252 }
253 }
254}
255
/// A block's flatness score paired with the block's position in the block grid, so the
/// position is still known after the scores have been sorted.
#[derive(Debug, Clone, Copy, Default)]
struct IndexAndScore {
    /// Row-major index of the block in the block grid.
    pub index: usize,
    /// Flatness score of the block.
    pub score: f32,
}
261
262/// Wrapper of data required to represent linear system of eqns and soln.
263#[derive(Debug, Clone)]
264struct EquationSystem {
265 a: Vec<f64>,
266 b: Vec<f64>,
267 x: Vec<f64>,
268 n: usize,
269}
270
271impl EquationSystem {
272 #[must_use]
273 pub fn new(n: usize) -> Self {
274 Self {
275 a: vec![0.0f64; n * n],
276 b: vec![0.0f64; n],
277 x: vec![0.0f64; n],
278 n,
279 }
280 }
281
282 pub fn solve(&mut self) -> bool {
283 let n = self.n;
284 let mut a = self.a.clone();
285 let mut b = self.b.clone();
286
287 linsolve(n, &mut a, self.n, &mut b, &mut self.x)
288 }
289
290 pub fn set_chroma_coefficient_fallback_solution(&mut self) {
291 const TOLERANCE: f64 = 1.0e-6f64;
292 let last = self.n - 1;
293 // Set all of the AR coefficients to zero, but try to solve for correlation
294 // with the luma channel
295 self.x.fill(0f64);
296 if self.a[last * self.n + last].abs() > TOLERANCE {
297 self.x[last] = self.b[last] / self.a[last * self.n + last];
298 }
299 }
300
301 pub fn copy_from(&mut self, other: &Self) {
302 assert_eq!(self.n, other.n);
303 self.a.copy_from_slice(&other.a);
304 self.x.copy_from_slice(&other.x);
305 self.b.copy_from_slice(&other.b);
306 }
307
308 pub fn clear(&mut self) {
309 self.a.fill(0f64);
310 self.b.fill(0f64);
311 self.x.fill(0f64);
312 }
313}
314
315impl Add<&EquationSystem> for EquationSystem {
316 type Output = EquationSystem;
317
318 #[must_use]
319 fn add(self, addend: &EquationSystem) -> Self::Output {
320 let mut dest: EquationSystem = self.clone();
321 let n: usize = self.n;
322 for i: usize in 0..n {
323 for j: usize in 0..n {
324 dest.a[i * n + j] += addend.a[i * n + j];
325 }
326 dest.b[i] += addend.b[i];
327 }
328 dest
329 }
330}
331
332impl AddAssign<&EquationSystem> for EquationSystem {
333 fn add_assign(&mut self, rhs: &EquationSystem) {
334 *self = self.clone() + rhs;
335 }
336}
337
/// A piecewise linear curve.
///
/// The curve is stored as a list of `[x, y]` points.
struct NoiseStrengthLut {
    points: Vec<[f64; 2]>,
}

impl NoiseStrengthLut {
    /// Creates a curve made of `num_bins` points, all located at the origin.
    ///
    /// # Panics
    ///
    /// Panics if `num_bins` is zero.
    #[must_use]
    pub fn new(num_bins: usize) -> Self {
        assert!(num_bins > 0);
        let origin = [0f64; 2];
        let points = vec![origin; num_bins];
        Self { points }
    }
}
354
355#[derive(Debug, Clone)]
356pub(super) struct NoiseModel {
357 combined_state: [NoiseModelState; 3],
358 latest_state: [NoiseModelState; 3],
359 n: usize,
360 coords: Vec<[isize; 2]>,
361}
362
363impl NoiseModel {
364 #[must_use]
365 pub fn new() -> Self {
366 let n = Self::num_coeffs();
367 let combined_state = [
368 NoiseModelState::new(n),
369 NoiseModelState::new(n + 1),
370 NoiseModelState::new(n + 1),
371 ];
372 let latest_state = [
373 NoiseModelState::new(n),
374 NoiseModelState::new(n + 1),
375 NoiseModelState::new(n + 1),
376 ];
377 let mut coords = Vec::new();
378
379 let neg_lag = -(NOISE_MODEL_LAG as isize);
380 for y in neg_lag..=0 {
381 let max_x = if y == 0 {
382 -1isize
383 } else {
384 NOISE_MODEL_LAG as isize
385 };
386 for x in neg_lag..=max_x {
387 coords.push([x, y]);
388 }
389 }
390 assert!(n == coords.len());
391
392 Self {
393 combined_state,
394 latest_state,
395 n,
396 coords,
397 }
398 }
399
400 pub fn update(
401 &mut self,
402 source: &Frame<u8>,
403 denoised: &Frame<u8>,
404 flat_blocks: &[u8],
405 ) -> NoiseStatus {
406 let num_blocks_w = (source.planes[0].cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
407 let num_blocks_h = (source.planes[0].cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
408 let mut y_model_different = false;
409
410 // Clear the latest equation system
411 for i in 0..3 {
412 self.latest_state[i].eqns.clear();
413 self.latest_state[i].num_observations = 0;
414 self.latest_state[i].strength_solver.clear();
415 }
416
417 // Check that we have enough flat blocks
418 let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count();
419 if num_blocks <= 1 {
420 return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate"));
421 }
422
423 let frame_dims = (source.planes[0].cfg.width, source.planes[0].cfg.height);
424 for channel in 0..3 {
425 if source.planes[channel].data.is_empty() {
426 // Monochrome source
427 break;
428 }
429 let is_chroma = channel > 0;
430 let alt_source = (channel > 0).then(|| &source.planes[0]);
431 let alt_denoised = (channel > 0).then(|| &denoised.planes[0]);
432 self.add_block_observations(
433 channel,
434 &source.planes[channel],
435 &denoised.planes[channel],
436 alt_source,
437 alt_denoised,
438 frame_dims,
439 flat_blocks,
440 num_blocks_w,
441 num_blocks_h,
442 );
443 if !self.latest_state[channel].ar_equation_system_solve(is_chroma) {
444 if is_chroma {
445 self.latest_state[channel]
446 .eqns
447 .set_chroma_coefficient_fallback_solution();
448 } else {
449 return NoiseStatus::Error(anyhow!(
450 "Solving latest noise equation system failed on plane {}",
451 channel
452 ));
453 }
454 }
455 self.add_noise_std_observations(
456 channel,
457 &source.planes[channel],
458 &denoised.planes[channel],
459 alt_source,
460 frame_dims,
461 flat_blocks,
462 num_blocks_w,
463 num_blocks_h,
464 );
465 if !self.latest_state[channel].strength_solver.solve() {
466 return NoiseStatus::Error(anyhow!(
467 "Failed to solve strength solver for latest state"
468 ));
469 }
470
471 // Check noise characteristics and return if error
472 if channel == 0
473 && self.combined_state[channel].strength_solver.num_equations > 0
474 && self.is_different()
475 {
476 y_model_different = true;
477 }
478
479 if y_model_different {
480 continue;
481 }
482
483 self.combined_state[channel].num_observations +=
484 self.latest_state[channel].num_observations;
485 self.combined_state[channel].eqns += &self.latest_state[channel].eqns;
486 if !self.combined_state[channel].ar_equation_system_solve(is_chroma) {
487 if is_chroma {
488 self.combined_state[channel]
489 .eqns
490 .set_chroma_coefficient_fallback_solution();
491 } else {
492 return NoiseStatus::Error(anyhow!(
493 "Solving combined noise equation system failed on plane {}",
494 channel
495 ));
496 }
497 }
498
499 self.combined_state[channel].strength_solver +=
500 &self.latest_state[channel].strength_solver;
501
502 if !self.combined_state[channel].strength_solver.solve() {
503 return NoiseStatus::Error(anyhow!(
504 "Failed to solve strength solver for combined state"
505 ));
506 };
507 }
508
509 if y_model_different {
510 return NoiseStatus::DifferentType;
511 }
512
513 NoiseStatus::Ok
514 }
515
516 #[allow(clippy::too_many_lines)]
517 #[must_use]
518 pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment {
519 // Both the domain and the range of the scaling functions in the film_grain
520 // are normalized to 8-bit (e.g., they are implicitly scaled during grain
521 // synthesis).
522 let scaling_points_y = self.combined_state[0]
523 .strength_solver
524 .fit_piecewise(NUM_Y_POINTS)
525 .points;
526 let scaling_points_cb = self.combined_state[1]
527 .strength_solver
528 .fit_piecewise(NUM_UV_POINTS)
529 .points;
530 let scaling_points_cr = self.combined_state[2]
531 .strength_solver
532 .fit_piecewise(NUM_UV_POINTS)
533 .points;
534
535 let mut max_scaling_value: f64 = 1.0e-4f64;
536 for p in scaling_points_y
537 .iter()
538 .chain(scaling_points_cb.iter())
539 .chain(scaling_points_cr.iter())
540 .map(|p| p[1])
541 {
542 if p > max_scaling_value {
543 max_scaling_value = p;
544 }
545 }
546
547 // Scaling_shift values are in the range [8,11]
548 let max_scaling_value_log2 =
549 clamp((max_scaling_value.log2() + 1f64).floor() as u8, 2u8, 5u8);
550 let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2));
551 let map_scaling_point = |p: [f64; 2]| {
552 [
553 (p[0] + 0.5f64) as u8,
554 clamp(scale_factor.mul_add(p[1], 0.5f64) as i32, 0i32, 255i32) as u8,
555 ]
556 };
557
558 let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y
559 .into_iter()
560 .map(map_scaling_point)
561 .collect();
562 let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb
563 .into_iter()
564 .map(map_scaling_point)
565 .collect();
566 let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr
567 .into_iter()
568 .map(map_scaling_point)
569 .collect();
570
571 // Convert the ar_coeffs into 8-bit values
572 let n_coeff = self.combined_state[0].eqns.n;
573 let mut max_coeff = 1.0e-4f64;
574 let mut min_coeff = 1.0e-4f64;
575 let mut y_corr = [0f64; 2];
576 let mut avg_luma_strength = 0f64;
577 for c in 0..3 {
578 let eqns = &self.combined_state[c].eqns;
579 for i in 0..n_coeff {
580 if eqns.x[i] > max_coeff {
581 max_coeff = eqns.x[i];
582 }
583 if eqns.x[i] < min_coeff {
584 min_coeff = eqns.x[i];
585 }
586 }
587
588 // Since the correlation between luma/chroma was computed in an already
589 // scaled space, we adjust it in the un-scaled space.
590 let solver = &self.combined_state[c].strength_solver;
591 // Compute a weighted average of the strength for the channel.
592 let mut average_strength = 0f64;
593 let mut total_weight = 0f64;
594 for i in 0..solver.eqns.n {
595 let mut w = 0f64;
596 for j in 0..solver.eqns.n {
597 w += solver.eqns.a[i * solver.eqns.n + j];
598 }
599 w = w.sqrt();
600 average_strength += solver.eqns.x[i] * w;
601 total_weight += w;
602 }
603 if total_weight.abs() < f64::EPSILON {
604 average_strength = 1f64;
605 } else {
606 average_strength /= total_weight;
607 }
608 if c == 0 {
609 avg_luma_strength = average_strength;
610 } else {
611 y_corr[c - 1] = avg_luma_strength * eqns.x[n_coeff] / average_strength;
612 max_coeff = max_coeff.max(y_corr[c - 1]);
613 min_coeff = min_coeff.min(y_corr[c - 1]);
614 }
615 }
616
617 // Shift value: AR coeffs range (values 6-9)
618 // 6: [-2, 2), 7: [-1, 1), 8: [-0.5, 0.5), 9: [-0.25, 0.25)
619 let ar_coeff_shift = clamp(
620 7i32 - (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32,
621 6i32,
622 9i32,
623 ) as u8;
624 let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift);
625 let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff);
626 let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr);
627 let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr);
628
629 GrainTableSegment {
630 random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 },
631 start_time: start_ts,
632 end_time: end_ts,
633 ar_coeff_lag: NOISE_MODEL_LAG as u8,
634 scaling_points_y,
635 scaling_points_cb,
636 scaling_points_cr,
637 scaling_shift: 5 + (8 - max_scaling_value_log2),
638 ar_coeff_shift,
639 ar_coeffs_y,
640 ar_coeffs_cb,
641 ar_coeffs_cr,
642 // At the moment, the noise modeling code assumes that the chroma scaling
643 // functions are a function of luma.
644 cb_mult: 128,
645 cb_luma_mult: 192,
646 cb_offset: 256,
647 cr_mult: 128,
648 cr_luma_mult: 192,
649 cr_offset: 256,
650 chroma_scaling_from_luma: false,
651 grain_scale_shift: 0,
652 overlap_flag: true,
653 }
654 }
655
656 pub fn save_latest(&mut self) {
657 for c in 0..3 {
658 let latest_state = &self.latest_state[c];
659 let combined_state = &mut self.combined_state[c];
660 combined_state.eqns.copy_from(&latest_state.eqns);
661 combined_state
662 .strength_solver
663 .eqns
664 .copy_from(&latest_state.strength_solver.eqns);
665 combined_state.strength_solver.num_equations =
666 latest_state.strength_solver.num_equations;
667 combined_state.num_observations = latest_state.num_observations;
668 combined_state.ar_gain = latest_state.ar_gain;
669 }
670 }
671
672 #[must_use]
673 const fn num_coeffs() -> usize {
674 let n = 2 * NOISE_MODEL_LAG + 1;
675 (n * n) / 2
676 }
677
678 #[must_use]
679 fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> {
680 assert!(n_coeff <= NUM_Y_COEFFS);
681 let mut coeffs = ArrayVec::new();
682 let eqns = &self.combined_state[0].eqns;
683 for i in 0..n_coeff {
684 coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8);
685 }
686 coeffs
687 }
688
689 #[must_use]
690 fn get_ar_coeffs_uv(
691 &self,
692 channel: usize,
693 n_coeff: usize,
694 scale_ar_coeff: f64,
695 y_corr: [f64; 2],
696 ) -> ArrayVec<i8, NUM_UV_COEFFS> {
697 assert!(n_coeff <= NUM_Y_COEFFS);
698 let mut coeffs = ArrayVec::new();
699 let eqns = &self.combined_state[channel].eqns;
700 for i in 0..n_coeff {
701 coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8);
702 }
703 coeffs.push(clamp(
704 (scale_ar_coeff * y_corr[channel - 1]).round() as i32,
705 -128i32,
706 127i32,
707 ) as i8);
708 coeffs
709 }
710
711 // Return true if the noise estimate appears to be different from the combined
712 // (multi-frame) estimate. The difference is measured by checking whether the
713 // AR coefficients have diverged (using a threshold on normalized cross
714 // correlation), or whether the noise strength has changed.
715 #[must_use]
716 fn is_different(&self) -> bool {
717 const COEFF_THRESHOLD: f64 = 0.9f64;
718 const STRENGTH_THRESHOLD: f64 = 0.005f64;
719
720 let latest = &self.latest_state[0];
721 let combined = &self.combined_state[0];
722 let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n);
723 if corr < COEFF_THRESHOLD {
724 return true;
725 }
726
727 let dx = 1.0f64 / latest.strength_solver.num_bins as f64;
728 let latest_eqns = &latest.strength_solver.eqns;
729 let combined_eqns = &combined.strength_solver.eqns;
730 let mut diff = 0.0f64;
731 let mut total_weight = 0.0f64;
732 for j in 0..latest_eqns.n {
733 let mut weight = 0.0f64;
734 for i in 0..latest_eqns.n {
735 weight += latest_eqns.a[i * latest_eqns.n + j];
736 }
737 weight = weight.sqrt();
738 diff += weight * (latest_eqns.x[j] - combined_eqns.x[j]).abs();
739 total_weight += weight;
740 }
741
742 diff * dx / total_weight > STRENGTH_THRESHOLD
743 }
744
745 #[allow(clippy::too_many_arguments)]
746 fn add_block_observations(
747 &mut self,
748 channel: usize,
749 source: &Plane<u8>,
750 denoised: &Plane<u8>,
751 alt_source: Option<&Plane<u8>>,
752 alt_denoised: Option<&Plane<u8>>,
753 frame_dims: (usize, usize),
754 flat_blocks: &[u8],
755 num_blocks_w: usize,
756 num_blocks_h: usize,
757 ) {
758 let num_coords = self.n;
759 let state = &mut self.latest_state[channel];
760 let a = &mut state.eqns.a;
761 let b = &mut state.eqns.b;
762 let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice();
763 let n = state.eqns.n;
764 let block_w = BLOCK_SIZE >> source.cfg.xdec;
765 let block_h = BLOCK_SIZE >> source.cfg.ydec;
766
767 let dec = (source.cfg.xdec, source.cfg.ydec);
768 let stride = source.cfg.stride;
769 let source_origin = source.data_origin();
770 let denoised_origin = denoised.data_origin();
771 let alt_stride = alt_source.map_or(0, |s| s.cfg.stride);
772 let alt_source_origin = alt_source.map(|s| s.data_origin());
773 let alt_denoised_origin = alt_denoised.map(|s| s.data_origin());
774
775 for by in 0..num_blocks_h {
776 let y_o = by * block_h;
777 for bx in 0..num_blocks_w {
778 // SAFETY: We know the indexes we provide do not overflow the data bounds
779 unsafe {
780 let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx);
781 let x_o = bx * block_w;
782 if *flat_block_ptr == 0 {
783 continue;
784 }
785 let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 {
786 0
787 } else {
788 NOISE_MODEL_LAG
789 };
790 let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 {
791 0
792 } else {
793 NOISE_MODEL_LAG
794 };
795 let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h);
796 let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min(
797 if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 {
798 block_w
799 } else {
800 block_w - NOISE_MODEL_LAG
801 },
802 );
803 for y in y_start..y_end {
804 for x in x_start..x_end {
805 let val = extract_ar_row(
806 &self.coords,
807 num_coords,
808 source_origin,
809 denoised_origin,
810 stride,
811 dec,
812 alt_source_origin,
813 alt_denoised_origin,
814 alt_stride,
815 x + x_o,
816 y + y_o,
817 &mut buffer,
818 );
819 for i in 0..n {
820 for j in 0..n {
821 *a.get_unchecked_mut(i * n + j) += (*buffer.get_unchecked(i)
822 * *buffer.get_unchecked(j))
823 / BLOCK_NORMALIZATION.powi(2);
824 }
825 *b.get_unchecked_mut(i) +=
826 (*buffer.get_unchecked(i) * val) / BLOCK_NORMALIZATION.powi(2);
827 }
828 state.num_observations += 1;
829 }
830 }
831 }
832 }
833 }
834 }
835
836 #[allow(clippy::too_many_arguments)]
837 fn add_noise_std_observations(
838 &mut self,
839 channel: usize,
840 source: &Plane<u8>,
841 denoised: &Plane<u8>,
842 alt_source: Option<&Plane<u8>>,
843 frame_dims: (usize, usize),
844 flat_blocks: &[u8],
845 num_blocks_w: usize,
846 num_blocks_h: usize,
847 ) {
848 let coeffs = &self.latest_state[channel].eqns.x;
849 let num_coords = self.n;
850 let luma_gain = self.latest_state[0].ar_gain;
851 let noise_gain = self.latest_state[channel].ar_gain;
852 let block_w = BLOCK_SIZE >> source.cfg.xdec;
853 let block_h = BLOCK_SIZE >> source.cfg.ydec;
854
855 for by in 0..num_blocks_h {
856 let y_o = by * block_h;
857 for bx in 0..num_blocks_w {
858 let x_o = bx * block_w;
859 if flat_blocks[by * num_blocks_w + bx] == 0 {
860 continue;
861 }
862 let num_samples_h = ((frame_dims.1 >> source.cfg.ydec) - by * block_h).min(block_h);
863 let num_samples_w = ((frame_dims.0 >> source.cfg.xdec) - bx * block_w).min(block_w);
864 // Make sure that we have a reasonable amount of samples to consider the
865 // block
866 if num_samples_w * num_samples_h > BLOCK_SIZE {
867 let block_mean = get_block_mean(
868 alt_source.unwrap_or(source),
869 frame_dims,
870 x_o << source.cfg.xdec,
871 y_o << source.cfg.ydec,
872 );
873 let noise_var = get_noise_var(
874 source,
875 denoised,
876 (
877 frame_dims.0 >> source.cfg.xdec,
878 frame_dims.1 >> source.cfg.ydec,
879 ),
880 x_o,
881 y_o,
882 block_w,
883 block_h,
884 );
885 // We want to remove the part of the noise that came from being
886 // correlated with luma. Note that the noise solver for luma must
887 // have already been run.
888 let luma_strength = if channel > 0 {
889 luma_gain * self.latest_state[0].strength_solver.get_value(block_mean)
890 } else {
891 0f64
892 };
893 let corr = if channel > 0 {
894 coeffs[num_coords]
895 } else {
896 0f64
897 };
898 // Chroma noise:
899 // N(0, noise_var) = N(0, uncorr_var) + corr * N(0, luma_strength^2)
900 // The uncorrelated component:
901 // uncorr_var = noise_var - (corr * luma_strength)^2
902 // But don't allow fully correlated noise (hence the max), since the
903 // synthesis cannot model it.
904 let uncorr_std = (noise_var / 16f64)
905 .max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var))
906 .sqrt();
907 let adjusted_strength = uncorr_std / noise_gain;
908 self.latest_state[channel]
909 .strength_solver
910 .add_measurement(block_mean, adjusted_strength);
911 }
912 }
913 }
914 }
915}
916
917#[derive(Debug, Clone)]
918struct NoiseModelState {
919 eqns: EquationSystem,
920 ar_gain: f64,
921 num_observations: usize,
922 strength_solver: StrengthSolver,
923}
924
925impl NoiseModelState {
926 #[must_use]
927 pub fn new(n: usize) -> Self {
928 const NUM_BINS: usize = 20;
929
930 Self {
931 eqns: EquationSystem::new(n),
932 ar_gain: 1.0f64,
933 num_observations: 0usize,
934 strength_solver: StrengthSolver::new(NUM_BINS),
935 }
936 }
937
938 pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool {
939 let ret = self.eqns.solve();
940 self.ar_gain = 1.0f64;
941 if !ret {
942 return ret;
943 }
944
945 // Update the AR gain from the equation system as it will be used to fit
946 // the noise strength as a function of intensity. In the Yule-Walker
947 // equations, the diagonal should be the variance of the correlated noise.
948 // In the case of the least squares estimate, there will be some variability
949 // in the diagonal. So use the mean of the diagonal as the estimate of
950 // overall variance (this works for least squares or Yule-Walker formulation).
951 let mut var = 0f64;
952 let n_adjusted = self.eqns.n - usize::from(is_chroma);
953 for i in 0..n_adjusted {
954 var += self.eqns.a[i * self.eqns.n + i] / self.num_observations as f64;
955 }
956 var /= n_adjusted as f64;
957
958 // Keep track of E(Y^2) = <b, x> + E(X^2)
959 // In the case that we are using chroma and have an estimate of correlation
960 // with luma we adjust that estimate slightly to remove the correlated bits by
961 // subtracting out the last column of a scaled by our correlation estimate
962 // from b. E(y^2) = <b - A(:, end)*x(end), x>
963 let mut sum_covar = 0f64;
964 for i in 0..n_adjusted {
965 let mut bi = self.eqns.b[i];
966 if is_chroma {
967 bi -= self.eqns.a[i * self.eqns.n + n_adjusted] * self.eqns.x[n_adjusted];
968 }
969 sum_covar += (bi * self.eqns.x[i]) / self.num_observations as f64;
970 }
971
972 // Now, get an estimate of the variance of uncorrelated noise signal and use
973 // it to determine the gain of the AR filter.
974 let noise_var = (var - sum_covar).max(1e-6f64);
975 self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt());
976 ret
977 }
978}
979
980#[derive(Debug, Clone)]
981struct StrengthSolver {
982 eqns: EquationSystem,
983 num_bins: usize,
984 num_equations: usize,
985 total: f64,
986}
987
988impl StrengthSolver {
989 #[must_use]
990 pub fn new(num_bins: usize) -> Self {
991 Self {
992 eqns: EquationSystem::new(num_bins),
993 num_bins,
994 num_equations: 0usize,
995 total: 0f64,
996 }
997 }
998
999 pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) {
1000 let bin = self.get_bin_index(block_mean);
1001 let bin_i0 = bin.floor() as usize;
1002 let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
1003 let a = bin - bin_i0 as f64;
1004 let n = self.num_bins;
1005 let eqns = &mut self.eqns;
1006 eqns.a[bin_i0 * n + bin_i0] += (1f64 - a).powi(2);
1007 eqns.a[bin_i1 * n + bin_i0] += a * (1f64 - a);
1008 eqns.a[bin_i1 * n + bin_i1] += a.powi(2);
1009 eqns.a[bin_i0 * n + bin_i1] += (1f64 - a) * a;
1010 eqns.b[bin_i0] += (1f64 - a) * noise_std;
1011 eqns.b[bin_i1] += a * noise_std;
1012 self.total += noise_std;
1013 self.num_equations += 1;
1014 }
1015
1016 pub fn solve(&mut self) -> bool {
1017 // Add regularization proportional to the number of constraints
1018 let n = self.num_bins;
1019 let alpha = 2f64 * self.num_equations as f64 / n as f64;
1020
1021 // Do this in a non-destructive manner so it is not confusing to the caller
1022 let old_a = self.eqns.a.clone();
1023 for i in 0..n {
1024 let i_lo = i.saturating_sub(1);
1025 let i_hi = (n - 1).min(i + 1);
1026 self.eqns.a[i * n + i_lo] -= alpha;
1027 self.eqns.a[i * n + i] += 2f64 * alpha;
1028 self.eqns.a[i * n + i_hi] -= alpha;
1029 }
1030
1031 // Small regularization to give average noise strength
1032 let mean = self.total / self.num_equations as f64;
1033 for i in 0..n {
1034 self.eqns.a[i * n + i] += 1f64 / 8192f64;
1035 self.eqns.b[i] += mean / 8192f64;
1036 }
1037 let result = self.eqns.solve();
1038 self.eqns.a = old_a;
1039 result
1040 }
1041
1042 #[must_use]
1043 pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut {
1044 const TOLERANCE: f64 = 0.00625f64;
1045
1046 let mut lut = NoiseStrengthLut::new(self.num_bins);
1047 for i in 0..self.num_bins {
1048 lut.points[i][0] = self.get_center(i);
1049 lut.points[i][1] = self.eqns.x[i];
1050 }
1051
1052 let mut residual = vec![0.0f64; self.num_bins];
1053 self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins);
1054
1055 // Greedily remove points if there are too many or if it doesn't hurt local
1056 // approximation (never remove the end points)
1057 while lut.points.len() > 2 {
1058 let mut min_index = 1usize;
1059 for j in 1..(lut.points.len() - 1) {
1060 if residual[j] < residual[min_index] {
1061 min_index = j;
1062 }
1063 }
1064 let dx = lut.points[min_index + 1][0] - lut.points[min_index - 1][0];
1065 let avg_residual = residual[min_index] / dx;
1066 if lut.points.len() <= max_output_points && avg_residual > TOLERANCE {
1067 break;
1068 }
1069
1070 lut.points.remove(min_index);
1071 self.update_piecewise_linear_residual(
1072 &lut,
1073 &mut residual,
1074 min_index - 1,
1075 min_index + 1,
1076 );
1077 }
1078
1079 lut
1080 }
1081
1082 #[must_use]
1083 pub fn get_value(&self, x: f64) -> f64 {
1084 let bin = self.get_bin_index(x);
1085 let bin_i0 = bin.floor() as usize;
1086 let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
1087 let a = bin - bin_i0 as f64;
1088 (1f64 - a).mul_add(self.eqns.x[bin_i0], a * self.eqns.x[bin_i1])
1089 }
1090
1091 pub fn clear(&mut self) {
1092 self.eqns.clear();
1093 self.num_equations = 0;
1094 self.total = 0f64;
1095 }
1096
1097 #[must_use]
1098 fn get_bin_index(&self, value: f64) -> f64 {
1099 let max = 255f64;
1100 let val = clamp(value, 0f64, max);
1101 (self.num_bins - 1) as f64 * val / max
1102 }
1103
1104 fn update_piecewise_linear_residual(
1105 &self,
1106 lut: &NoiseStrengthLut,
1107 residual: &mut [f64],
1108 start: usize,
1109 end: usize,
1110 ) {
1111 let dx = 255f64 / self.num_bins as f64;
1112 #[allow(clippy::needless_range_loop)]
1113 for i in start.max(1)..end.min(lut.points.len() - 1) {
1114 let lower = 0usize.max(self.get_bin_index(lut.points[i - 1][0]).floor() as usize);
1115 let upper =
1116 (self.num_bins - 1).min(self.get_bin_index(lut.points[i + 1][0]).ceil() as usize);
1117 let mut r = 0f64;
1118 for j in lower..=upper {
1119 let x = self.get_center(j);
1120 if x < lut.points[i - 1][0] || x >= lut.points[i + 1][0] {
1121 continue;
1122 }
1123
1124 let y = self.eqns.x[j];
1125 let a = (x - lut.points[i - 1][0]) / (lut.points[i + 1][0] - lut.points[i - 1][0]);
1126 let estimate_y = lut.points[i - 1][1].mul_add(1f64 - a, lut.points[i + 1][1] * a);
1127 r += (y - estimate_y).abs();
1128 }
1129 residual[i] = r * dx;
1130 }
1131 }
1132
1133 #[must_use]
1134 fn get_center(&self, i: usize) -> f64 {
1135 let range = 255f64;
1136 let n = self.num_bins;
1137 i as f64 / (n - 1) as f64 * range
1138 }
1139}
1140
1141impl Add<&StrengthSolver> for StrengthSolver {
1142 type Output = StrengthSolver;
1143
1144 #[must_use]
1145 fn add(self, addend: &StrengthSolver) -> Self::Output {
1146 let mut dest: StrengthSolver = self;
1147 dest.eqns += &addend.eqns;
1148 dest.num_equations += addend.num_equations;
1149 dest.total += addend.total;
1150 dest
1151 }
1152}
1153
1154impl AddAssign<&StrengthSolver> for StrengthSolver {
1155 fn add_assign(&mut self, rhs: &StrengthSolver) {
1156 *self = self.clone() + rhs;
1157 }
1158}
1159