| 1 | // Copyright (c) 2018-2023, The rav1e contributors. All rights reserved |
| 2 | // |
| 3 | // This source code is subject to the terms of the BSD 2 Clause License and |
| 4 | // the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
| 5 | // was not distributed with this source code in the LICENSE file, you can |
| 6 | // obtain it at www.aomedia.org/license/software. If the Alliance for Open |
| 7 | // Media Patent License 1.0 was not distributed with this source code in the |
| 8 | // PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
| 9 | |
| 10 | use crate::context::*; |
| 11 | use crate::header::PRIMARY_REF_NONE; |
| 12 | use crate::partition::BlockSize; |
| 13 | use crate::rdo::spatiotemporal_scale; |
| 14 | use crate::rdo::DistortionScale; |
| 15 | use crate::tiling::TileStateMut; |
| 16 | use crate::util::Pixel; |
| 17 | use crate::FrameInvariants; |
| 18 | use crate::FrameState; |
| 19 | |
/// Number of segment slots handled by this module: it bounds the per-segment
/// loops, caps the per-segment scratch vectors, and the threshold table passed
/// to `segment_idx_from_distortion` holds one entry fewer than this.
pub const MAX_SEGMENTS: usize = 8;
| 21 | |
| 22 | #[profiling::function ] |
| 23 | pub fn segmentation_optimize<T: Pixel>( |
| 24 | fi: &FrameInvariants<T>, fs: &mut FrameState<T>, |
| 25 | ) { |
| 26 | assert!(fi.enable_segmentation); |
| 27 | fs.segmentation.enabled = true; |
| 28 | |
| 29 | if fs.segmentation.enabled { |
| 30 | fs.segmentation.update_map = true; |
| 31 | |
| 32 | // We don't change the values between frames. |
| 33 | fs.segmentation.update_data = fi.primary_ref_frame == PRIMARY_REF_NONE; |
| 34 | |
| 35 | // Avoid going into lossless mode by never bringing qidx below 1. |
| 36 | // Because base_q_idx changes more frequently than the segmentation |
| 37 | // data, it is still possible for a segment to enter lossless, so |
| 38 | // enforcement elsewhere is needed. |
| 39 | let offset_lower_limit = 1 - fi.base_q_idx as i16; |
| 40 | |
| 41 | if !fs.segmentation.update_data { |
| 42 | let mut min_segment = MAX_SEGMENTS; |
| 43 | for i in 0..MAX_SEGMENTS { |
| 44 | if fs.segmentation.features[i][SegLvl::SEG_LVL_ALT_Q as usize] |
| 45 | && fs.segmentation.data[i][SegLvl::SEG_LVL_ALT_Q as usize] |
| 46 | >= offset_lower_limit |
| 47 | { |
| 48 | min_segment = i; |
| 49 | break; |
| 50 | } |
| 51 | } |
| 52 | assert_ne!(min_segment, MAX_SEGMENTS); |
| 53 | fs.segmentation.min_segment = min_segment as u8; |
| 54 | fs.segmentation.update_threshold(fi.base_q_idx, fi.config.bit_depth); |
| 55 | return; |
| 56 | } |
| 57 | |
| 58 | segmentation_optimize_inner(fi, fs, offset_lower_limit); |
| 59 | |
| 60 | /* Figure out parameters */ |
| 61 | fs.segmentation.preskip = false; |
| 62 | fs.segmentation.last_active_segid = 0; |
| 63 | for i in 0..MAX_SEGMENTS { |
| 64 | for j in 0..SegLvl::SEG_LVL_MAX as usize { |
| 65 | if fs.segmentation.features[i][j] { |
| 66 | fs.segmentation.last_active_segid = i as u8; |
| 67 | if j >= SegLvl::SEG_LVL_REF_FRAME as usize { |
| 68 | fs.segmentation.preskip = true; |
| 69 | } |
| 70 | } |
| 71 | } |
| 72 | } |
| 73 | } |
| 74 | } |
| 75 | |
| 76 | // Select target quantizers for each segment by fitting to log(scale). |
| 77 | fn segmentation_optimize_inner<T: Pixel>( |
| 78 | fi: &FrameInvariants<T>, fs: &mut FrameState<T>, offset_lower_limit: i16, |
| 79 | ) { |
| 80 | use crate::quantize::{ac_q, select_ac_qi}; |
| 81 | use crate::util::kmeans; |
| 82 | use arrayvec::ArrayVec; |
| 83 | |
| 84 | // Minimize the total distance from a small set of values to all scales. |
| 85 | // Find k-means of log(spatiotemporal scale), k in 3..=8 |
| 86 | let c: ([_; 8], [_; 7], [_; 6], [_; 5], [_; 4], [_; 3]) = { |
| 87 | let spatiotemporal_scores = |
| 88 | &fi.coded_frame_data.as_ref().unwrap().spatiotemporal_scores; |
| 89 | let mut log2_scale_q11 = Vec::with_capacity(spatiotemporal_scores.len()); |
| 90 | log2_scale_q11.extend(spatiotemporal_scores.iter().map(|&s| s.blog16())); |
| 91 | log2_scale_q11.sort_unstable(); |
| 92 | let l = &log2_scale_q11; |
| 93 | (kmeans(l), kmeans(l), kmeans(l), kmeans(l), kmeans(l), kmeans(l)) |
| 94 | }; |
| 95 | |
| 96 | // Find variance in spacing between successive log(scale) |
| 97 | let var = |c: &[i16]| { |
| 98 | let delta = ArrayVec::<_, MAX_SEGMENTS>::from_iter( |
| 99 | c.iter().skip(1).zip(c).map(|(&a, &b)| b as i64 - a as i64), |
| 100 | ); |
| 101 | let mean = delta.iter().sum::<i64>() / delta.len() as i64; |
| 102 | delta.iter().map(|&d| (d - mean).pow(2)).sum::<i64>() as u64 |
| 103 | }; |
| 104 | let variance = |
| 105 | [var(&c.0), var(&c.1), var(&c.2), var(&c.3), var(&c.4), var(&c.5)]; |
| 106 | |
| 107 | // Choose the k value with minimal variance in spacing |
| 108 | let min_variance = *variance.iter().min().unwrap(); |
| 109 | let position = variance.iter().rposition(|&v| v == min_variance).unwrap(); |
| 110 | |
| 111 | // For the selected centroids, derive a target quantizer: |
| 112 | // scale Q'^2 = Q^2 |
| 113 | // See `distortion_scale_for` for more information. |
| 114 | let compute_delta = |centroids: &[i16]| { |
| 115 | use crate::util::{bexp64, blog64}; |
| 116 | let log2_base_ac_q_q57 = |
| 117 | blog64(ac_q(fi.base_q_idx, 0, fi.config.bit_depth).get().into()); |
| 118 | centroids |
| 119 | .iter() |
| 120 | .rev() |
| 121 | // Rewrite in log form and exponentiate: |
| 122 | // scale Q'^2 = Q^2 |
| 123 | // Q' = Q / sqrt(scale) |
| 124 | // log(Q') = log(Q) - 0.5 log(scale) |
| 125 | .map(|&log2_scale_q11| { |
| 126 | bexp64(log2_base_ac_q_q57 - ((log2_scale_q11 as i64) << (57 - 11 - 1))) |
| 127 | }) |
| 128 | // Find the index of the nearest quantizer to the target, |
| 129 | // and take the delta from the base quantizer index. |
| 130 | .map(|q| { |
| 131 | // Avoid going into lossless mode by never bringing qidx below 1. |
| 132 | select_ac_qi(q, fi.config.bit_depth).max(1) as i16 |
| 133 | - fi.base_q_idx as i16 |
| 134 | }) |
| 135 | .collect::<ArrayVec<_, MAX_SEGMENTS>>() |
| 136 | }; |
| 137 | |
| 138 | // Compute segment deltas for best value of k |
| 139 | let seg_delta = match position { |
| 140 | 0 => compute_delta(&c.0), |
| 141 | 1 => compute_delta(&c.1), |
| 142 | 2 => compute_delta(&c.2), |
| 143 | 3 => compute_delta(&c.3), |
| 144 | 4 => compute_delta(&c.4), |
| 145 | _ => compute_delta(&c.5), |
| 146 | }; |
| 147 | |
| 148 | // Update the segmentation data |
| 149 | fs.segmentation.min_segment = 0; |
| 150 | fs.segmentation.max_segment = seg_delta.len() as u8 - 1; |
| 151 | for (&delta, (features, data)) in seg_delta |
| 152 | .iter() |
| 153 | .zip(fs.segmentation.features.iter_mut().zip(&mut fs.segmentation.data)) |
| 154 | { |
| 155 | features[SegLvl::SEG_LVL_ALT_Q as usize] = true; |
| 156 | data[SegLvl::SEG_LVL_ALT_Q as usize] = delta.max(offset_lower_limit); |
| 157 | } |
| 158 | |
| 159 | fs.segmentation.update_threshold(fi.base_q_idx, fi.config.bit_depth); |
| 160 | } |
| 161 | |
| 162 | #[profiling::function ] |
| 163 | pub fn select_segment<T: Pixel>( |
| 164 | fi: &FrameInvariants<T>, ts: &TileStateMut<'_, T>, tile_bo: TileBlockOffset, |
| 165 | bsize: BlockSize, skip: bool, |
| 166 | ) -> std::ops::RangeInclusive<u8> { |
| 167 | // If skip is true or segmentation is turned off, sidx is not coded. |
| 168 | if skip || !fi.enable_segmentation { |
| 169 | return 0..=0; |
| 170 | } |
| 171 | |
| 172 | use crate::api::SegmentationLevel; |
| 173 | if fi.config.speed_settings.segmentation == SegmentationLevel::Full { |
| 174 | return ts.segmentation.min_segment..=ts.segmentation.max_segment; |
| 175 | } |
| 176 | |
| 177 | let frame_bo = ts.to_frame_block_offset(tile_bo); |
| 178 | let scale = spatiotemporal_scale(fi, frame_bo, bsize); |
| 179 | |
| 180 | let sidx = segment_idx_from_distortion(&ts.segmentation.threshold, scale); |
| 181 | |
| 182 | // Avoid going into lossless mode by never bringing qidx below 1. |
| 183 | let sidx = sidx.max(ts.segmentation.min_segment); |
| 184 | |
| 185 | if fi.config.speed_settings.segmentation == SegmentationLevel::Complex { |
| 186 | return sidx..=ts.segmentation.max_segment.min(sidx.saturating_add(1)); |
| 187 | } |
| 188 | |
| 189 | sidx..=sidx |
| 190 | } |
| 191 | |
| 192 | fn segment_idx_from_distortion( |
| 193 | threshold: &[DistortionScale; MAX_SEGMENTS - 1], s: DistortionScale, |
| 194 | ) -> u8 { |
| 195 | threshold.partition_point(|&t: DistortionScale| s.0 < t.0) as u8 |
| 196 | } |
| 197 | |