| 1 | mod util; |
| 2 | |
| 3 | use std::ops::{Add, AddAssign}; |
| 4 | |
| 5 | use anyhow::anyhow; |
| 6 | use arrayvec::ArrayVec; |
| 7 | use v_frame::{frame::Frame, math::clamp, plane::Plane}; |
| 8 | |
| 9 | use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat}; |
| 10 | use super::{NoiseStatus, BLOCK_SIZE, BLOCK_SIZE_SQUARED}; |
| 11 | use crate::{ |
| 12 | diff::solver::util::normalized_cross_correlation, GrainTableSegment, DEFAULT_GRAIN_SEED, |
| 13 | NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS, NUM_Y_POINTS, |
| 14 | }; |
| 15 | |
| 16 | const LOW_POLY_NUM_PARAMS: usize = 3; |
| 17 | const NOISE_MODEL_LAG: usize = 3; |
| 18 | const BLOCK_NORMALIZATION: f64 = 255.0f64; |
| 19 | |
/// Finds "flat" (low-texture, noise-dominated) blocks in a plane.
///
/// Works by fitting a low-order polynomial plane to each block (via the
/// precomputed `a` / `a_t_a_inv` matrices) and analyzing the gradient
/// statistics of the residual.
#[derive (Debug, Clone, Copy)]
pub(super) struct FlatBlockFinder {
    // Design matrix for the plane fit: one row per pixel, columns [y, x, 1].
    a: [f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED],
    // Inverse of A^T * A, used to solve the least-squares plane fit.
    a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS],
}
| 25 | |
| 26 | impl FlatBlockFinder { |
    /// Precomputes the least-squares machinery for fitting a plane
    /// (`c0*y + c1*x + c2`) to each block: the design matrix `a` and the
    /// inverse of its normal matrix `(A^T A)^-1`.
    #[must_use ]
    pub fn new() -> Self {
        let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS);
        let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS];
        let mut a = [0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED];

        // Coordinates are normalized to roughly [-1, 1) around the block
        // center so the fit is well-conditioned.
        let bs_half = (BLOCK_SIZE / 2) as f64;
        (0..BLOCK_SIZE).for_each(|y| {
            let yd = (y as f64 - bs_half) / bs_half;
            (0..BLOCK_SIZE).for_each(|x| {
                let xd = (x as f64 - bs_half) / bs_half;
                let coords = [yd, xd, 1.0f64];
                let row = y * BLOCK_SIZE + x;
                a[LOW_POLY_NUM_PARAMS * row] = yd;
                a[LOW_POLY_NUM_PARAMS * row + 1] = xd;
                a[LOW_POLY_NUM_PARAMS * row + 2] = 1.0f64;

                // Accumulate A^T * A into the equation system's matrix.
                (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
                    (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
                        eqns.a[LOW_POLY_NUM_PARAMS * i + j] += coords[i] * coords[j];
                    });
                });
            });
        });

        // Lazy inverse using existing equation solver.
        // Solving (A^T A) x = e_i yields column i of the inverse.
        (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
            eqns.b.fill(0.0f64);
            eqns.b[i] = 1.0f64;
            eqns.solve();

            (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
                a_t_a_inv[j * LOW_POLY_NUM_PARAMS + i] = eqns.x[j];
            });
        });

        FlatBlockFinder { a, a_t_a_inv }
    }
| 65 | |
    // The gradient-based features used in this code are based on:
    // A. Kokaram, D. Kelly, H. Denman and A. Crawford, "Measuring noise
    // correlation for improved video denoising," 2012 19th, ICIP.
    // The thresholds are more lenient to allow for correct grain modeling
    // in extreme cases.
    /// Scans `plane` in `BLOCK_SIZE`-sized tiles and returns a per-block
    /// flatness mask together with the number of blocks judged flat.
    ///
    /// Mask values: 255 = block passed the hard gradient/variance
    /// thresholds; 1 = block promoted afterwards because its sigmoid score
    /// lands in the top 10% of all blocks; 0 = not flat.
    #[must_use ]
    #[allow (clippy::too_many_lines)]
    pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) {
        const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64;
        const RATIO_THRESHOLD: f64 = 1.25f64;
        const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64;
        const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64;

        // The following weights are used to combine the above features to give
        // a sigmoid score for flatness. If the input was normalized to [0,100]
        // the magnitude of these values would be close to 1 (e.g., weights
        // corresponding to variance would be a factor of 10000x smaller).
        const VAR_WEIGHT: f64 = -6682f64;
        const RATIO_WEIGHT: f64 = -0.2056f64;
        const TRACE_WEIGHT: f64 = 13087f64;
        const NORM_WEIGHT: f64 = -12434f64;
        const OFFSET: f64 = 2.5694f64;

        let num_blocks_w = (plane.cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let num_blocks_h = (plane.cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let num_blocks = num_blocks_w * num_blocks_h;
        let mut flat_blocks = vec![0u8; num_blocks];
        let mut num_flat = 0;
        let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED];
        let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED];
        let mut scores = vec![IndexAndScore::default(); num_blocks];

        for by in 0..num_blocks_h {
            for bx in 0..num_blocks_w {
                // Compute gradient covariance matrix.
                let mut gxx = 0f64;
                let mut gxy = 0f64;
                let mut gyy = 0f64;
                let mut var = 0f64;
                let mut mean = 0f64;

                // `block_result` receives the block with its best-fit plane
                // removed, so the statistics below describe the residual.
                self.extract_block(
                    plane,
                    bx * BLOCK_SIZE,
                    by * BLOCK_SIZE,
                    &mut plane_result,
                    &mut block_result,
                );
                // Central differences over the block interior; the border is
                // skipped so every neighbor access stays in bounds.
                for yi in 1..(BLOCK_SIZE - 1) {
                    for xi in 1..(BLOCK_SIZE - 1) {
                        // SAFETY: We know the size of `block_result` and that we cannot exceed the bounds of it
                        unsafe {
                            let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi);

                            let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64;
                            let gy =
                                (*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64;
                            gxx += gx * gx;
                            gxy += gx * gy;
                            gyy += gy * gy;

                            let block_val = *result_ptr;
                            mean += block_val;
                            var += block_val * block_val;
                        }
                    }
                }
                // The interior holds (BLOCK_SIZE - 2)^2 samples.
                let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64;
                mean /= block_size_norm_factor;

                // Normalize gradients by block_size.
                gxx /= block_size_norm_factor;
                gxy /= block_size_norm_factor;
                gyy /= block_size_norm_factor;
                // var = E[v^2] - E[v]^2
                var = mean.mul_add(-mean, var / block_size_norm_factor);

                // Eigenvalues of the 2x2 gradient covariance matrix
                // [gxx gxy; gxy gyy], via the quadratic formula.
                let trace = gxx + gyy;
                let det = gxx.mul_add(gyy, -gxy.powi(2));
                let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt();
                let e1 = (trace + e_sub) / 2.0f64;
                let e2 = (trace - e_sub) / 2.0f64;
                // Spectral norm
                let norm = e1;
                let ratio = e1 / e2.max(1.0e-6_f64);
                let is_flat = trace < TRACE_THRESHOLD
                    && ratio < RATIO_THRESHOLD
                    && norm < NORM_THRESHOLD
                    && var > VAR_THRESHOLD;

                let sum_weights = NORM_WEIGHT.mul_add(
                    norm,
                    TRACE_WEIGHT.mul_add(
                        trace,
                        VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)),
                    ),
                );
                // clamp the value to [-25.0, 100.0] to prevent overflow
                let sum_weights = clamp(sum_weights, -25.0f64, 100.0f64);
                // Logistic sigmoid of the combined features.
                let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32;
                // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it
                unsafe {
                    let index = by * num_blocks_w + bx;
                    *flat_blocks.get_unchecked_mut(index) = if is_flat { 255 } else { 0 };
                    *scores.get_unchecked_mut(index) = IndexAndScore {
                        score: if var > VAR_THRESHOLD { score } else { 0f32 },
                        index,
                    };
                }
                if is_flat {
                    num_flat += 1;
                }
            }
        }

        // Promote blocks whose score is in the top 10%: ascending sort, so
        // the threshold sits at the 90th-percentile entry.
        scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN" ));
        // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it
        unsafe {
            let top_nth_percentile = num_blocks * 90 / 100;
            let score_threshold = scores.get_unchecked(top_nth_percentile).score;
            for score in &scores {
                if score.score >= score_threshold {
                    let block_ref = flat_blocks.get_unchecked_mut(score.index);
                    // Blocks promoted only by score are marked with 1 (not 255).
                    if *block_ref == 0 {
                        num_flat += 1;
                    }
                    *block_ref |= 1;
                }
            }
        }

        (flat_blocks, num_flat)
    }
| 198 | |
    /// Extracts the block at (`offset_x`, `offset_y`), fits a plane to it,
    /// and writes the fitted plane into `plane_result` and the residual
    /// (block minus fitted plane) into `block_result`.
    ///
    /// Pixels are normalized to [0, 1] via `BLOCK_NORMALIZATION`; reads past
    /// the frame edge are clamped to the last row/column.
    fn extract_block(
        &self,
        plane: &Plane<u8>,
        offset_x: usize,
        offset_y: usize,
        plane_result: &mut [f64; BLOCK_SIZE_SQUARED],
        block_result: &mut [f64; BLOCK_SIZE_SQUARED],
    ) {
        let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS];
        let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS];
        let plane_origin = plane.data_origin();

        // Copy the (edge-clamped) block out of the plane, normalized to [0, 1].
        for yi in 0..BLOCK_SIZE {
            let y = clamp(offset_y + yi, 0, plane.cfg.height - 1);
            for xi in 0..BLOCK_SIZE {
                let x = clamp(offset_x + xi, 0, plane.cfg.width - 1);
                // SAFETY: We know the bounds of the plane data and `block_result`
                // and do not exceed them.
                unsafe {
                    *block_result.get_unchecked_mut(yi * BLOCK_SIZE + xi) =
                        f64::from(*plane_origin.get_unchecked(y * plane.cfg.stride + x))
                            / BLOCK_NORMALIZATION;
                }
            }
        }

        // a_t_a_inv_b = b^T * A, i.e. A^T b for the least-squares fit
        // (computed as a (1 x N) * (N x 3) product).
        multiply_mat(
            block_result,
            &self.a,
            &mut a_t_a_inv_b,
            1,
            BLOCK_SIZE_SQUARED,
            LOW_POLY_NUM_PARAMS,
        );
        // plane_coords = (A^T A)^-1 * (A^T b): the fitted plane parameters.
        multiply_mat(
            &self.a_t_a_inv,
            &a_t_a_inv_b,
            &mut plane_coords,
            LOW_POLY_NUM_PARAMS,
            LOW_POLY_NUM_PARAMS,
            1,
        );
        // plane_result = A * plane_coords: the plane evaluated at every pixel.
        multiply_mat(
            &self.a,
            &plane_coords,
            plane_result,
            BLOCK_SIZE_SQUARED,
            LOW_POLY_NUM_PARAMS,
            1,
        );

        // Leave the residual (block minus plane fit) in `block_result`.
        for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) {
            *block_res -= *plane_res;
        }
    }
| 254 | } |
| 255 | |
/// A block index paired with its flatness score, used to select the
/// top-scoring blocks after sorting.
#[derive (Debug, Clone, Copy, Default)]
struct IndexAndScore {
    // Index of the block in raster order (by * num_blocks_w + bx).
    pub index: usize,
    // Sigmoid flatness score in [0, 1]; 0 when the variance test failed.
    pub score: f32,
}
| 261 | |
/// Wrapper of data required to represent linear system of eqns and soln.
#[derive (Debug, Clone)]
struct EquationSystem {
    // Row-major n x n coefficient matrix.
    a: Vec<f64>,
    // Right-hand side vector of length n.
    b: Vec<f64>,
    // Solution vector of length n, filled in by `solve`.
    x: Vec<f64>,
    // Dimension of the system.
    n: usize,
}
| 270 | |
| 271 | impl EquationSystem { |
| 272 | #[must_use ] |
| 273 | pub fn new(n: usize) -> Self { |
| 274 | Self { |
| 275 | a: vec![0.0f64; n * n], |
| 276 | b: vec![0.0f64; n], |
| 277 | x: vec![0.0f64; n], |
| 278 | n, |
| 279 | } |
| 280 | } |
| 281 | |
| 282 | pub fn solve(&mut self) -> bool { |
| 283 | let n = self.n; |
| 284 | let mut a = self.a.clone(); |
| 285 | let mut b = self.b.clone(); |
| 286 | |
| 287 | linsolve(n, &mut a, self.n, &mut b, &mut self.x) |
| 288 | } |
| 289 | |
| 290 | pub fn set_chroma_coefficient_fallback_solution(&mut self) { |
| 291 | const TOLERANCE: f64 = 1.0e-6f64; |
| 292 | let last = self.n - 1; |
| 293 | // Set all of the AR coefficients to zero, but try to solve for correlation |
| 294 | // with the luma channel |
| 295 | self.x.fill(0f64); |
| 296 | if self.a[last * self.n + last].abs() > TOLERANCE { |
| 297 | self.x[last] = self.b[last] / self.a[last * self.n + last]; |
| 298 | } |
| 299 | } |
| 300 | |
| 301 | pub fn copy_from(&mut self, other: &Self) { |
| 302 | assert_eq!(self.n, other.n); |
| 303 | self.a.copy_from_slice(&other.a); |
| 304 | self.x.copy_from_slice(&other.x); |
| 305 | self.b.copy_from_slice(&other.b); |
| 306 | } |
| 307 | |
| 308 | pub fn clear(&mut self) { |
| 309 | self.a.fill(0f64); |
| 310 | self.b.fill(0f64); |
| 311 | self.x.fill(0f64); |
| 312 | } |
| 313 | } |
| 314 | |
| 315 | impl Add<&EquationSystem> for EquationSystem { |
| 316 | type Output = EquationSystem; |
| 317 | |
| 318 | #[must_use ] |
| 319 | fn add(self, addend: &EquationSystem) -> Self::Output { |
| 320 | let mut dest: EquationSystem = self.clone(); |
| 321 | let n: usize = self.n; |
| 322 | for i: usize in 0..n { |
| 323 | for j: usize in 0..n { |
| 324 | dest.a[i * n + j] += addend.a[i * n + j]; |
| 325 | } |
| 326 | dest.b[i] += addend.b[i]; |
| 327 | } |
| 328 | dest |
| 329 | } |
| 330 | } |
| 331 | |
| 332 | impl AddAssign<&EquationSystem> for EquationSystem { |
| 333 | fn add_assign(&mut self, rhs: &EquationSystem) { |
| 334 | *self = self.clone() + rhs; |
| 335 | } |
| 336 | } |
| 337 | |
/// Representation of a piecewise linear curve
///
/// Holds n points as (x, y) pairs, that store the curve.
struct NoiseStrengthLut {
    // Curve points as [x, y] pairs, one per bin.
    points: Vec<[f64; 2]>,
}
| 344 | |
| 345 | impl NoiseStrengthLut { |
| 346 | #[must_use ] |
| 347 | pub fn new(num_bins: usize) -> Self { |
| 348 | assert!(num_bins > 0); |
| 349 | Self { |
| 350 | points: vec![[0f64; 2]; num_bins], |
| 351 | } |
| 352 | } |
| 353 | } |
| 354 | |
/// Estimates an auto-regressive (AR) noise model plus a noise-strength curve
/// for each of the three planes, maintained both for the latest frame and
/// accumulated across frames.
#[derive (Debug, Clone)]
pub(super) struct NoiseModel {
    // Accumulated (multi-frame) state per channel: [Y, Cb, Cr].
    combined_state: [NoiseModelState; 3],
    // State estimated from the most recent frame only.
    latest_state: [NoiseModelState; 3],
    // Number of AR coefficients (equals `coords.len()`).
    n: usize,
    // Causal neighborhood offsets [x, y] used by the AR model.
    coords: Vec<[isize; 2]>,
}
| 362 | |
| 363 | impl NoiseModel { |
| 364 | #[must_use ] |
| 365 | pub fn new() -> Self { |
| 366 | let n = Self::num_coeffs(); |
| 367 | let combined_state = [ |
| 368 | NoiseModelState::new(n), |
| 369 | NoiseModelState::new(n + 1), |
| 370 | NoiseModelState::new(n + 1), |
| 371 | ]; |
| 372 | let latest_state = [ |
| 373 | NoiseModelState::new(n), |
| 374 | NoiseModelState::new(n + 1), |
| 375 | NoiseModelState::new(n + 1), |
| 376 | ]; |
| 377 | let mut coords = Vec::new(); |
| 378 | |
| 379 | let neg_lag = -(NOISE_MODEL_LAG as isize); |
| 380 | for y in neg_lag..=0 { |
| 381 | let max_x = if y == 0 { |
| 382 | -1isize |
| 383 | } else { |
| 384 | NOISE_MODEL_LAG as isize |
| 385 | }; |
| 386 | for x in neg_lag..=max_x { |
| 387 | coords.push([x, y]); |
| 388 | } |
| 389 | } |
| 390 | assert!(n == coords.len()); |
| 391 | |
| 392 | Self { |
| 393 | combined_state, |
| 394 | latest_state, |
| 395 | n, |
| 396 | coords, |
| 397 | } |
| 398 | } |
| 399 | |
    /// Updates both the latest and the combined noise state from one
    /// (source, denoised) frame pair, using only the blocks marked in
    /// `flat_blocks`.
    ///
    /// Returns `NoiseStatus::DifferentType` when the luma estimate for this
    /// frame diverges from the accumulated one, and `NoiseStatus::Error`
    /// when there are too few flat blocks or a solver fails on luma.
    pub fn update(
        &mut self,
        source: &Frame<u8>,
        denoised: &Frame<u8>,
        flat_blocks: &[u8],
    ) -> NoiseStatus {
        let num_blocks_w = (source.planes[0].cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let num_blocks_h = (source.planes[0].cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let mut y_model_different = false;

        // Clear the latest equation system
        for i in 0..3 {
            self.latest_state[i].eqns.clear();
            self.latest_state[i].num_observations = 0;
            self.latest_state[i].strength_solver.clear();
        }

        // Check that we have enough flat blocks
        let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count();
        if num_blocks <= 1 {
            return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate" ));
        }

        let frame_dims = (source.planes[0].cfg.width, source.planes[0].cfg.height);
        for channel in 0..3 {
            if source.planes[channel].data.is_empty() {
                // Monochrome source
                break;
            }
            let is_chroma = channel > 0;
            // Chroma channels also observe the luma plane so the model can
            // capture luma/chroma noise correlation.
            let alt_source = (channel > 0).then(|| &source.planes[0]);
            let alt_denoised = (channel > 0).then(|| &denoised.planes[0]);
            self.add_block_observations(
                channel,
                &source.planes[channel],
                &denoised.planes[channel],
                alt_source,
                alt_denoised,
                frame_dims,
                flat_blocks,
                num_blocks_w,
                num_blocks_h,
            );
            // A failed chroma solve falls back to modeling only the luma
            // correlation term; a failed luma solve is fatal.
            if !self.latest_state[channel].ar_equation_system_solve(is_chroma) {
                if is_chroma {
                    self.latest_state[channel]
                        .eqns
                        .set_chroma_coefficient_fallback_solution();
                } else {
                    return NoiseStatus::Error(anyhow!(
                        "Solving latest noise equation system failed on plane {}" ,
                        channel
                    ));
                }
            }
            self.add_noise_std_observations(
                channel,
                &source.planes[channel],
                &denoised.planes[channel],
                alt_source,
                frame_dims,
                flat_blocks,
                num_blocks_w,
                num_blocks_h,
            );
            if !self.latest_state[channel].strength_solver.solve() {
                return NoiseStatus::Error(anyhow!(
                    "Failed to solve strength solver for latest state"
                ));
            }

            // Check noise characteristics and return if error
            if channel == 0
                && self.combined_state[channel].strength_solver.num_equations > 0
                && self.is_different()
            {
                y_model_different = true;
            }

            // When the luma model changed, skip folding this frame into the
            // combined estimate; the caller decides how to proceed.
            if y_model_different {
                continue;
            }

            self.combined_state[channel].num_observations +=
                self.latest_state[channel].num_observations;
            self.combined_state[channel].eqns += &self.latest_state[channel].eqns;
            if !self.combined_state[channel].ar_equation_system_solve(is_chroma) {
                if is_chroma {
                    self.combined_state[channel]
                        .eqns
                        .set_chroma_coefficient_fallback_solution();
                } else {
                    return NoiseStatus::Error(anyhow!(
                        "Solving combined noise equation system failed on plane {}" ,
                        channel
                    ));
                }
            }

            self.combined_state[channel].strength_solver +=
                &self.latest_state[channel].strength_solver;

            if !self.combined_state[channel].strength_solver.solve() {
                return NoiseStatus::Error(anyhow!(
                    "Failed to solve strength solver for combined state"
                ));
            };
        }

        if y_model_different {
            return NoiseStatus::DifferentType;
        }

        NoiseStatus::Ok
    }
| 515 | |
    /// Converts the combined (multi-frame) noise state into an AV1
    /// `GrainTableSegment` covering the timestamp range `[start_ts, end_ts)`.
    #[allow (clippy::too_many_lines)]
    #[must_use ]
    pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment {
        // Both the domain and the range of the scaling functions in the film_grain
        // are normalized to 8-bit (e.g., they are implicitly scaled during grain
        // synthesis).
        let scaling_points_y = self.combined_state[0]
            .strength_solver
            .fit_piecewise(NUM_Y_POINTS)
            .points;
        let scaling_points_cb = self.combined_state[1]
            .strength_solver
            .fit_piecewise(NUM_UV_POINTS)
            .points;
        let scaling_points_cr = self.combined_state[2]
            .strength_solver
            .fit_piecewise(NUM_UV_POINTS)
            .points;

        // Peak scaling value across all three channels drives the shift.
        let mut max_scaling_value: f64 = 1.0e-4f64;
        for p in scaling_points_y
            .iter()
            .chain(scaling_points_cb.iter())
            .chain(scaling_points_cr.iter())
            .map(|p| p[1])
        {
            if p > max_scaling_value {
                max_scaling_value = p;
            }
        }

        // Scaling_shift values are in the range [8,11]
        let max_scaling_value_log2 =
            clamp((max_scaling_value.log2() + 1f64).floor() as u8, 2u8, 5u8);
        let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2));
        // Quantize an (intensity, strength) pair to 8-bit table entries,
        // rounding via the +0.5 offsets.
        let map_scaling_point = |p: [f64; 2]| {
            [
                (p[0] + 0.5f64) as u8,
                clamp(scale_factor.mul_add(p[1], 0.5f64) as i32, 0i32, 255i32) as u8,
            ]
        };

        let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y
            .into_iter()
            .map(map_scaling_point)
            .collect();
        let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb
            .into_iter()
            .map(map_scaling_point)
            .collect();
        let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr
            .into_iter()
            .map(map_scaling_point)
            .collect();

        // Convert the ar_coeffs into 8-bit values
        let n_coeff = self.combined_state[0].eqns.n;
        let mut max_coeff = 1.0e-4f64;
        let mut min_coeff = 1.0e-4f64;
        let mut y_corr = [0f64; 2];
        let mut avg_luma_strength = 0f64;
        for c in 0..3 {
            let eqns = &self.combined_state[c].eqns;
            // Track the AR coefficient range to size the coefficient shift.
            for i in 0..n_coeff {
                if eqns.x[i] > max_coeff {
                    max_coeff = eqns.x[i];
                }
                if eqns.x[i] < min_coeff {
                    min_coeff = eqns.x[i];
                }
            }

            // Since the correlation between luma/chroma was computed in an already
            // scaled space, we adjust it in the un-scaled space.
            let solver = &self.combined_state[c].strength_solver;
            // Compute a weighted average of the strength for the channel.
            let mut average_strength = 0f64;
            let mut total_weight = 0f64;
            for i in 0..solver.eqns.n {
                let mut w = 0f64;
                for j in 0..solver.eqns.n {
                    w += solver.eqns.a[i * solver.eqns.n + j];
                }
                w = w.sqrt();
                average_strength += solver.eqns.x[i] * w;
                total_weight += w;
            }
            if total_weight.abs() < f64::EPSILON {
                average_strength = 1f64;
            } else {
                average_strength /= total_weight;
            }
            if c == 0 {
                avg_luma_strength = average_strength;
            } else {
                // The extra unknown at index n_coeff of a chroma system is
                // the correlation with luma.
                y_corr[c - 1] = avg_luma_strength * eqns.x[n_coeff] / average_strength;
                max_coeff = max_coeff.max(y_corr[c - 1]);
                min_coeff = min_coeff.min(y_corr[c - 1]);
            }
        }

        // Shift value: AR coeffs range (values 6-9)
        // 6: [-2, 2), 7: [-1, 1), 8: [-0.5, 0.5), 9: [-0.25, 0.25)
        let ar_coeff_shift = clamp(
            7i32 - (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32,
            6i32,
            9i32,
        ) as u8;
        let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift);
        let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff);
        let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr);
        let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr);

        GrainTableSegment {
            random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 },
            start_time: start_ts,
            end_time: end_ts,
            ar_coeff_lag: NOISE_MODEL_LAG as u8,
            scaling_points_y,
            scaling_points_cb,
            scaling_points_cr,
            scaling_shift: 5 + (8 - max_scaling_value_log2),
            ar_coeff_shift,
            ar_coeffs_y,
            ar_coeffs_cb,
            ar_coeffs_cr,
            // At the moment, the noise modeling code assumes that the chroma scaling
            // functions are a function of luma.
            cb_mult: 128,
            cb_luma_mult: 192,
            cb_offset: 256,
            cr_mult: 128,
            cr_luma_mult: 192,
            cr_offset: 256,
            chroma_scaling_from_luma: false,
            grain_scale_shift: 0,
            overlap_flag: true,
        }
    }
| 655 | |
| 656 | pub fn save_latest(&mut self) { |
| 657 | for c in 0..3 { |
| 658 | let latest_state = &self.latest_state[c]; |
| 659 | let combined_state = &mut self.combined_state[c]; |
| 660 | combined_state.eqns.copy_from(&latest_state.eqns); |
| 661 | combined_state |
| 662 | .strength_solver |
| 663 | .eqns |
| 664 | .copy_from(&latest_state.strength_solver.eqns); |
| 665 | combined_state.strength_solver.num_equations = |
| 666 | latest_state.strength_solver.num_equations; |
| 667 | combined_state.num_observations = latest_state.num_observations; |
| 668 | combined_state.ar_gain = latest_state.ar_gain; |
| 669 | } |
| 670 | } |
| 671 | |
| 672 | #[must_use ] |
| 673 | const fn num_coeffs() -> usize { |
| 674 | let n = 2 * NOISE_MODEL_LAG + 1; |
| 675 | (n * n) / 2 |
| 676 | } |
| 677 | |
| 678 | #[must_use ] |
| 679 | fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> { |
| 680 | assert!(n_coeff <= NUM_Y_COEFFS); |
| 681 | let mut coeffs = ArrayVec::new(); |
| 682 | let eqns = &self.combined_state[0].eqns; |
| 683 | for i in 0..n_coeff { |
| 684 | coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8); |
| 685 | } |
| 686 | coeffs |
| 687 | } |
| 688 | |
| 689 | #[must_use ] |
| 690 | fn get_ar_coeffs_uv( |
| 691 | &self, |
| 692 | channel: usize, |
| 693 | n_coeff: usize, |
| 694 | scale_ar_coeff: f64, |
| 695 | y_corr: [f64; 2], |
| 696 | ) -> ArrayVec<i8, NUM_UV_COEFFS> { |
| 697 | assert!(n_coeff <= NUM_Y_COEFFS); |
| 698 | let mut coeffs = ArrayVec::new(); |
| 699 | let eqns = &self.combined_state[channel].eqns; |
| 700 | for i in 0..n_coeff { |
| 701 | coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8); |
| 702 | } |
| 703 | coeffs.push(clamp( |
| 704 | (scale_ar_coeff * y_corr[channel - 1]).round() as i32, |
| 705 | -128i32, |
| 706 | 127i32, |
| 707 | ) as i8); |
| 708 | coeffs |
| 709 | } |
| 710 | |
| 711 | // Return true if the noise estimate appears to be different from the combined |
| 712 | // (multi-frame) estimate. The difference is measured by checking whether the |
| 713 | // AR coefficients have diverged (using a threshold on normalized cross |
| 714 | // correlation), or whether the noise strength has changed. |
| 715 | #[must_use ] |
| 716 | fn is_different(&self) -> bool { |
| 717 | const COEFF_THRESHOLD: f64 = 0.9f64; |
| 718 | const STRENGTH_THRESHOLD: f64 = 0.005f64; |
| 719 | |
| 720 | let latest = &self.latest_state[0]; |
| 721 | let combined = &self.combined_state[0]; |
| 722 | let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n); |
| 723 | if corr < COEFF_THRESHOLD { |
| 724 | return true; |
| 725 | } |
| 726 | |
| 727 | let dx = 1.0f64 / latest.strength_solver.num_bins as f64; |
| 728 | let latest_eqns = &latest.strength_solver.eqns; |
| 729 | let combined_eqns = &combined.strength_solver.eqns; |
| 730 | let mut diff = 0.0f64; |
| 731 | let mut total_weight = 0.0f64; |
| 732 | for j in 0..latest_eqns.n { |
| 733 | let mut weight = 0.0f64; |
| 734 | for i in 0..latest_eqns.n { |
| 735 | weight += latest_eqns.a[i * latest_eqns.n + j]; |
| 736 | } |
| 737 | weight = weight.sqrt(); |
| 738 | diff += weight * (latest_eqns.x[j] - combined_eqns.x[j]).abs(); |
| 739 | total_weight += weight; |
| 740 | } |
| 741 | |
| 742 | diff * dx / total_weight > STRENGTH_THRESHOLD |
| 743 | } |
| 744 | |
    /// Accumulates least-squares observations for the AR coefficients of
    /// `channel` into `latest_state[channel].eqns`.
    ///
    /// Only pixels inside flat blocks are used, and the per-block margins
    /// shrink so the lag neighborhood never reads from a non-flat neighbor.
    /// For chroma channels, `alt_source` / `alt_denoised` carry the luma
    /// plane so a luma-correlation column is appended to each row.
    #[allow (clippy::too_many_arguments)]
    fn add_block_observations(
        &mut self,
        channel: usize,
        source: &Plane<u8>,
        denoised: &Plane<u8>,
        alt_source: Option<&Plane<u8>>,
        alt_denoised: Option<&Plane<u8>>,
        frame_dims: (usize, usize),
        flat_blocks: &[u8],
        num_blocks_w: usize,
        num_blocks_h: usize,
    ) {
        let num_coords = self.n;
        let state = &mut self.latest_state[channel];
        let a = &mut state.eqns.a;
        let b = &mut state.eqns.b;
        // Row buffer: `num_coords` AR taps plus one luma-correlation slot.
        let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice();
        let n = state.eqns.n;
        // Block dimensions in this plane's (possibly subsampled) coordinates.
        let block_w = BLOCK_SIZE >> source.cfg.xdec;
        let block_h = BLOCK_SIZE >> source.cfg.ydec;

        let dec = (source.cfg.xdec, source.cfg.ydec);
        let stride = source.cfg.stride;
        let source_origin = source.data_origin();
        let denoised_origin = denoised.data_origin();
        let alt_stride = alt_source.map_or(0, |s| s.cfg.stride);
        let alt_source_origin = alt_source.map(|s| s.data_origin());
        let alt_denoised_origin = alt_denoised.map(|s| s.data_origin());

        for by in 0..num_blocks_h {
            let y_o = by * block_h;
            for bx in 0..num_blocks_w {
                // SAFETY: We know the indexes we provide do not overflow the data bounds
                unsafe {
                    let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx);
                    let x_o = bx * block_w;
                    if *flat_block_ptr == 0 {
                        continue;
                    }
                    // Start at 0 only when the block above/left is also flat,
                    // so the lag window never covers non-flat pixels.
                    let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 {
                        0
                    } else {
                        NOISE_MODEL_LAG
                    };
                    let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 {
                        0
                    } else {
                        NOISE_MODEL_LAG
                    };
                    // Clip against the frame edge and against a non-flat
                    // right-hand neighbor.
                    let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h);
                    let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min(
                        if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 {
                            block_w
                        } else {
                            block_w - NOISE_MODEL_LAG
                        },
                    );
                    for y in y_start..y_end {
                        for x in x_start..x_end {
                            // `val` is the noise sample at this pixel;
                            // `buffer` receives its lag neighborhood (one
                            // regression row).
                            let val = extract_ar_row(
                                &self.coords,
                                num_coords,
                                source_origin,
                                denoised_origin,
                                stride,
                                dec,
                                alt_source_origin,
                                alt_denoised_origin,
                                alt_stride,
                                x + x_o,
                                y + y_o,
                                &mut buffer,
                            );
                            // Accumulate the normal equations: A += r r^T and
                            // b += r * val, normalized to the [0, 1] range.
                            for i in 0..n {
                                for j in 0..n {
                                    *a.get_unchecked_mut(i * n + j) += (*buffer.get_unchecked(i)
                                        * *buffer.get_unchecked(j))
                                        / BLOCK_NORMALIZATION.powi(2);
                                }
                                *b.get_unchecked_mut(i) +=
                                    (*buffer.get_unchecked(i) * val) / BLOCK_NORMALIZATION.powi(2);
                            }
                            state.num_observations += 1;
                        }
                    }
                }
            }
        }
    }
| 835 | |
    /// Adds one noise-strength measurement per usable flat block to
    /// `latest_state[channel].strength_solver`.
    ///
    /// For chroma channels, the portion of the noise correlated with luma is
    /// removed first using the already-solved luma strength curve (the luma
    /// solver for the latest state must have been run before this).
    #[allow (clippy::too_many_arguments)]
    fn add_noise_std_observations(
        &mut self,
        channel: usize,
        source: &Plane<u8>,
        denoised: &Plane<u8>,
        alt_source: Option<&Plane<u8>>,
        frame_dims: (usize, usize),
        flat_blocks: &[u8],
        num_blocks_w: usize,
        num_blocks_h: usize,
    ) {
        let coeffs = &self.latest_state[channel].eqns.x;
        let num_coords = self.n;
        let luma_gain = self.latest_state[0].ar_gain;
        let noise_gain = self.latest_state[channel].ar_gain;
        // Block dimensions in this plane's (possibly subsampled) coordinates.
        let block_w = BLOCK_SIZE >> source.cfg.xdec;
        let block_h = BLOCK_SIZE >> source.cfg.ydec;

        for by in 0..num_blocks_h {
            let y_o = by * block_h;
            for bx in 0..num_blocks_w {
                let x_o = bx * block_w;
                if flat_blocks[by * num_blocks_w + bx] == 0 {
                    continue;
                }
                // Clip the block against the frame edge.
                let num_samples_h = ((frame_dims.1 >> source.cfg.ydec) - by * block_h).min(block_h);
                let num_samples_w = ((frame_dims.0 >> source.cfg.xdec) - bx * block_w).min(block_w);
                // Make sure that we have a reasonable amount of samples to consider the
                // block
                if num_samples_w * num_samples_h > BLOCK_SIZE {
                    // Mean intensity always comes from luma (`alt_source` for
                    // chroma), addressed in full-resolution coordinates.
                    let block_mean = get_block_mean(
                        alt_source.unwrap_or(source),
                        frame_dims,
                        x_o << source.cfg.xdec,
                        y_o << source.cfg.ydec,
                    );
                    let noise_var = get_noise_var(
                        source,
                        denoised,
                        (
                            frame_dims.0 >> source.cfg.xdec,
                            frame_dims.1 >> source.cfg.ydec,
                        ),
                        x_o,
                        y_o,
                        block_w,
                        block_h,
                    );
                    // We want to remove the part of the noise that came from being
                    // correlated with luma. Note that the noise solver for luma must
                    // have already been run.
                    let luma_strength = if channel > 0 {
                        luma_gain * self.latest_state[0].strength_solver.get_value(block_mean)
                    } else {
                        0f64
                    };
                    let corr = if channel > 0 {
                        coeffs[num_coords]
                    } else {
                        0f64
                    };
                    // Chroma noise:
                    // N(0, noise_var) = N(0, uncorr_var) + corr * N(0, luma_strength^2)
                    // The uncorrelated component:
                    // uncorr_var = noise_var - (corr * luma_strength)^2
                    // But don't allow fully correlated noise (hence the max), since the
                    // synthesis cannot model it.
                    let uncorr_std = (noise_var / 16f64)
                        .max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var))
                        .sqrt();
                    // Undo the AR filter gain before recording the strength.
                    let adjusted_strength = uncorr_std / noise_gain;
                    self.latest_state[channel]
                        .strength_solver
                        .add_measurement(block_mean, adjusted_strength);
                }
            }
        }
    }
| 915 | } |
| 916 | |
/// Per-channel (Y/U/V) state accumulated while fitting the noise model.
#[derive (Debug, Clone)]
struct NoiseModelState {
    // Least-squares system whose solution `x` holds the AR filter
    // coefficients (plus, for chroma, a trailing luma-correlation term).
    eqns: EquationSystem,
    // Gain of the fitted AR filter; kept >= 1 by `ar_equation_system_solve`
    // and used to normalize measured noise strength.
    ar_gain: f64,
    // Number of pixel observations folded into `eqns`; normalizes the
    // accumulated sums into variances.
    num_observations: usize,
    // Fits noise strength as a piecewise-linear function of block intensity.
    strength_solver: StrengthSolver,
}
| 924 | |
| 925 | impl NoiseModelState { |
| 926 | #[must_use ] |
| 927 | pub fn new(n: usize) -> Self { |
| 928 | const NUM_BINS: usize = 20; |
| 929 | |
| 930 | Self { |
| 931 | eqns: EquationSystem::new(n), |
| 932 | ar_gain: 1.0f64, |
| 933 | num_observations: 0usize, |
| 934 | strength_solver: StrengthSolver::new(NUM_BINS), |
| 935 | } |
| 936 | } |
| 937 | |
| 938 | pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool { |
| 939 | let ret = self.eqns.solve(); |
| 940 | self.ar_gain = 1.0f64; |
| 941 | if !ret { |
| 942 | return ret; |
| 943 | } |
| 944 | |
| 945 | // Update the AR gain from the equation system as it will be used to fit |
| 946 | // the noise strength as a function of intensity. In the Yule-Walker |
| 947 | // equations, the diagonal should be the variance of the correlated noise. |
| 948 | // In the case of the least squares estimate, there will be some variability |
| 949 | // in the diagonal. So use the mean of the diagonal as the estimate of |
| 950 | // overall variance (this works for least squares or Yule-Walker formulation). |
| 951 | let mut var = 0f64; |
| 952 | let n_adjusted = self.eqns.n - usize::from(is_chroma); |
| 953 | for i in 0..n_adjusted { |
| 954 | var += self.eqns.a[i * self.eqns.n + i] / self.num_observations as f64; |
| 955 | } |
| 956 | var /= n_adjusted as f64; |
| 957 | |
| 958 | // Keep track of E(Y^2) = <b, x> + E(X^2) |
| 959 | // In the case that we are using chroma and have an estimate of correlation |
| 960 | // with luma we adjust that estimate slightly to remove the correlated bits by |
| 961 | // subtracting out the last column of a scaled by our correlation estimate |
| 962 | // from b. E(y^2) = <b - A(:, end)*x(end), x> |
| 963 | let mut sum_covar = 0f64; |
| 964 | for i in 0..n_adjusted { |
| 965 | let mut bi = self.eqns.b[i]; |
| 966 | if is_chroma { |
| 967 | bi -= self.eqns.a[i * self.eqns.n + n_adjusted] * self.eqns.x[n_adjusted]; |
| 968 | } |
| 969 | sum_covar += (bi * self.eqns.x[i]) / self.num_observations as f64; |
| 970 | } |
| 971 | |
| 972 | // Now, get an estimate of the variance of uncorrelated noise signal and use |
| 973 | // it to determine the gain of the AR filter. |
| 974 | let noise_var = (var - sum_covar).max(1e-6f64); |
| 975 | self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt()); |
| 976 | ret |
| 977 | } |
| 978 | } |
| 979 | |
/// Fits noise strength (standard deviation) as a piecewise-linear function
/// of block intensity, using binned least squares over the [0, 255] range.
#[derive (Debug, Clone)]
struct StrengthSolver {
    // Least-squares system over per-bin strength values; `x` holds the
    // fitted strength at each bin center after `solve`.
    eqns: EquationSystem,
    // Number of intensity bins spanning [0, 255].
    num_bins: usize,
    // Count of (mean, strength) measurements added so far.
    num_equations: usize,
    // Running sum of all measured noise strengths; used to regularize the
    // solution toward the average strength.
    total: f64,
}
| 987 | |
impl StrengthSolver {
    /// Creates an empty solver with `num_bins` intensity bins over [0, 255].
    #[must_use ]
    pub fn new(num_bins: usize) -> Self {
        Self {
            eqns: EquationSystem::new(num_bins),
            num_bins,
            num_equations: 0usize,
            total: 0f64,
        }
    }

    /// Adds one (block mean, noise standard deviation) observation.
    ///
    /// The observation is split bilinearly between the two bins that bracket
    /// `block_mean`, with weight `1 - a` on the lower bin and `a` on the
    /// upper one.
    pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) {
        let bin = self.get_bin_index(block_mean);
        let bin_i0 = bin.floor() as usize;
        // Clamp so a mean at the top of the range stays in the last bin.
        let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
        // Fractional position between the two bins.
        let a = bin - bin_i0 as f64;
        let n = self.num_bins;
        let eqns = &mut self.eqns;
        // Accumulate the normal equations (A^T A and A^T b) for a weighted
        // two-bin interpolation constraint.
        eqns.a[bin_i0 * n + bin_i0] += (1f64 - a).powi(2);
        eqns.a[bin_i1 * n + bin_i0] += a * (1f64 - a);
        eqns.a[bin_i1 * n + bin_i1] += a.powi(2);
        eqns.a[bin_i0 * n + bin_i1] += (1f64 - a) * a;
        eqns.b[bin_i0] += (1f64 - a) * noise_std;
        eqns.b[bin_i1] += a * noise_std;
        self.total += noise_std;
        self.num_equations += 1;
    }

    /// Solves the regularized least-squares system; the per-bin strengths
    /// end up in `self.eqns.x`. Returns `false` if the system is singular.
    pub fn solve(&mut self) -> bool {
        // Add second-difference (smoothness) regularization proportional to
        // the number of constraints.
        let n = self.num_bins;
        let alpha = 2f64 * self.num_equations as f64 / n as f64;

        // Do this in a non-destructive manner so it is not confusing to the caller
        let old_a = self.eqns.a.clone();
        for i in 0..n {
            // Clamp neighbor indices at the ends of the bin range.
            let i_lo = i.saturating_sub(1);
            let i_hi = (n - 1).min(i + 1);
            self.eqns.a[i * n + i_lo] -= alpha;
            self.eqns.a[i * n + i] += 2f64 * alpha;
            self.eqns.a[i * n + i_hi] -= alpha;
        }

        // Small regularization pulling each bin toward the average measured
        // noise strength.
        // NOTE(review): if no measurements were added, `mean` is 0/0 = NaN
        // and would poison `b` — callers appear to add measurements before
        // solving; verify.
        let mean = self.total / self.num_equations as f64;
        for i in 0..n {
            self.eqns.a[i * n + i] += 1f64 / 8192f64;
            self.eqns.b[i] += mean / 8192f64;
        }
        let result = self.eqns.solve();
        // Restore A so subsequent measurements accumulate cleanly.
        self.eqns.a = old_a;
        result
    }

    /// Reduces the solved per-bin strengths to a piecewise-linear LUT with at
    /// most `max_output_points` points by greedily removing the points whose
    /// removal adds the least approximation error.
    ///
    /// Call `solve` first; this reads the solution from `self.eqns.x`.
    #[must_use ]
    pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut {
        const TOLERANCE: f64 = 0.00625f64;

        // Start with one LUT point per bin: (bin center, fitted strength).
        let mut lut = NoiseStrengthLut::new(self.num_bins);
        for i in 0..self.num_bins {
            lut.points[i][0] = self.get_center(i);
            lut.points[i][1] = self.eqns.x[i];
        }

        // residual[i] is the error incurred if point i were removed.
        let mut residual = vec![0.0f64; self.num_bins];
        self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins);

        // Greedily remove points if there are too many or if it doesn't hurt local
        // approximation (never remove the end points)
        while lut.points.len() > 2 {
            // Find the interior point with the smallest removal residual.
            let mut min_index = 1usize;
            for j in 1..(lut.points.len() - 1) {
                if residual[j] < residual[min_index] {
                    min_index = j;
                }
            }
            let dx = lut.points[min_index + 1][0] - lut.points[min_index - 1][0];
            let avg_residual = residual[min_index] / dx;
            // Stop once within the point budget and removal would exceed the
            // error tolerance.
            if lut.points.len() <= max_output_points && avg_residual > TOLERANCE {
                break;
            }

            lut.points.remove(min_index);
            // Only the neighbors of the removed point change residual.
            self.update_piecewise_linear_residual(
                &lut,
                &mut residual,
                min_index - 1,
                min_index + 1,
            );
        }

        lut
    }

    /// Returns the fitted noise strength at intensity `x` by linearly
    /// interpolating between the two bracketing bin solutions.
    #[must_use ]
    pub fn get_value(&self, x: f64) -> f64 {
        let bin = self.get_bin_index(x);
        let bin_i0 = bin.floor() as usize;
        let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
        let a = bin - bin_i0 as f64;
        (1f64 - a).mul_add(self.eqns.x[bin_i0], a * self.eqns.x[bin_i1])
    }

    /// Resets all accumulated measurements.
    pub fn clear(&mut self) {
        self.eqns.clear();
        self.num_equations = 0;
        self.total = 0f64;
    }

    /// Maps an intensity value (clamped to [0, 255]) to a fractional bin
    /// index in [0, num_bins - 1].
    #[must_use ]
    fn get_bin_index(&self, value: f64) -> f64 {
        let max = 255f64;
        let val = clamp(value, 0f64, max);
        (self.num_bins - 1) as f64 * val / max
    }

    /// Recomputes, for interior LUT points in `[start, end)`, the total error
    /// incurred if that point were removed and its neighbors linearly
    /// interpolated across the gap. Results go into `residual`.
    fn update_piecewise_linear_residual(
        &self,
        lut: &NoiseStrengthLut,
        residual: &mut [f64],
        start: usize,
        end: usize,
    ) {
        // Width of one bin in intensity units; scales the summed error.
        let dx = 255f64 / self.num_bins as f64;
        #[allow (clippy::needless_range_loop)]
        for i in start.max(1)..end.min(lut.points.len() - 1) {
            // Bins spanned by the segment from point i-1 to point i+1.
            let lower = 0usize.max(self.get_bin_index(lut.points[i - 1][0]).floor() as usize);
            let upper =
                (self.num_bins - 1).min(self.get_bin_index(lut.points[i + 1][0]).ceil() as usize);
            let mut r = 0f64;
            for j in lower..=upper {
                let x = self.get_center(j);
                // Only bins strictly inside the neighboring segment count.
                if x < lut.points[i - 1][0] || x >= lut.points[i + 1][0] {
                    continue;
                }

                // Compare the solved bin strength against the straight line
                // through the two neighboring points.
                let y = self.eqns.x[j];
                let a = (x - lut.points[i - 1][0]) / (lut.points[i + 1][0] - lut.points[i - 1][0]);
                let estimate_y = lut.points[i - 1][1].mul_add(1f64 - a, lut.points[i + 1][1] * a);
                r += (y - estimate_y).abs();
            }
            residual[i] = r * dx;
        }
    }

    /// Returns the intensity at the center of bin `i` (bins span [0, 255]).
    #[must_use ]
    fn get_center(&self, i: usize) -> f64 {
        let range = 255f64;
        let n = self.num_bins;
        i as f64 / (n - 1) as f64 * range
    }
}
| 1140 | |
| 1141 | impl Add<&StrengthSolver> for StrengthSolver { |
| 1142 | type Output = StrengthSolver; |
| 1143 | |
| 1144 | #[must_use ] |
| 1145 | fn add(self, addend: &StrengthSolver) -> Self::Output { |
| 1146 | let mut dest: StrengthSolver = self; |
| 1147 | dest.eqns += &addend.eqns; |
| 1148 | dest.num_equations += addend.num_equations; |
| 1149 | dest.total += addend.total; |
| 1150 | dest |
| 1151 | } |
| 1152 | } |
| 1153 | |
| 1154 | impl AddAssign<&StrengthSolver> for StrengthSolver { |
| 1155 | fn add_assign(&mut self, rhs: &StrengthSolver) { |
| 1156 | *self = self.clone() + rhs; |
| 1157 | } |
| 1158 | } |
| 1159 | |