1// Copyright (c) 2018-2023, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10use crate::context::*;
11use crate::header::PRIMARY_REF_NONE;
12use crate::partition::BlockSize;
13use crate::rdo::spatiotemporal_scale;
14use crate::rdo::DistortionScale;
15use crate::tiling::TileStateMut;
16use crate::util::Pixel;
17use crate::FrameInvariants;
18use crate::FrameState;
19
/// Number of segments handled by segmentation; segment ids in this module
/// range over `0..MAX_SEGMENTS`, and there are `MAX_SEGMENTS - 1` distortion
/// thresholds separating them.
pub const MAX_SEGMENTS: usize = 8;
21
22#[profiling::function]
23pub fn segmentation_optimize<T: Pixel>(
24 fi: &FrameInvariants<T>, fs: &mut FrameState<T>,
25) {
26 assert!(fi.enable_segmentation);
27 fs.segmentation.enabled = true;
28
29 if fs.segmentation.enabled {
30 fs.segmentation.update_map = true;
31
32 // We don't change the values between frames.
33 fs.segmentation.update_data = fi.primary_ref_frame == PRIMARY_REF_NONE;
34
35 // Avoid going into lossless mode by never bringing qidx below 1.
36 // Because base_q_idx changes more frequently than the segmentation
37 // data, it is still possible for a segment to enter lossless, so
38 // enforcement elsewhere is needed.
39 let offset_lower_limit = 1 - fi.base_q_idx as i16;
40
41 if !fs.segmentation.update_data {
42 let mut min_segment = MAX_SEGMENTS;
43 for i in 0..MAX_SEGMENTS {
44 if fs.segmentation.features[i][SegLvl::SEG_LVL_ALT_Q as usize]
45 && fs.segmentation.data[i][SegLvl::SEG_LVL_ALT_Q as usize]
46 >= offset_lower_limit
47 {
48 min_segment = i;
49 break;
50 }
51 }
52 assert_ne!(min_segment, MAX_SEGMENTS);
53 fs.segmentation.min_segment = min_segment as u8;
54 fs.segmentation.update_threshold(fi.base_q_idx, fi.config.bit_depth);
55 return;
56 }
57
58 segmentation_optimize_inner(fi, fs, offset_lower_limit);
59
60 /* Figure out parameters */
61 fs.segmentation.preskip = false;
62 fs.segmentation.last_active_segid = 0;
63 for i in 0..MAX_SEGMENTS {
64 for j in 0..SegLvl::SEG_LVL_MAX as usize {
65 if fs.segmentation.features[i][j] {
66 fs.segmentation.last_active_segid = i as u8;
67 if j >= SegLvl::SEG_LVL_REF_FRAME as usize {
68 fs.segmentation.preskip = true;
69 }
70 }
71 }
72 }
73 }
74}
75
76// Select target quantizers for each segment by fitting to log(scale).
77fn segmentation_optimize_inner<T: Pixel>(
78 fi: &FrameInvariants<T>, fs: &mut FrameState<T>, offset_lower_limit: i16,
79) {
80 use crate::quantize::{ac_q, select_ac_qi};
81 use crate::util::kmeans;
82 use arrayvec::ArrayVec;
83
84 // Minimize the total distance from a small set of values to all scales.
85 // Find k-means of log(spatiotemporal scale), k in 3..=8
86 let c: ([_; 8], [_; 7], [_; 6], [_; 5], [_; 4], [_; 3]) = {
87 let spatiotemporal_scores =
88 &fi.coded_frame_data.as_ref().unwrap().spatiotemporal_scores;
89 let mut log2_scale_q11 = Vec::with_capacity(spatiotemporal_scores.len());
90 log2_scale_q11.extend(spatiotemporal_scores.iter().map(|&s| s.blog16()));
91 log2_scale_q11.sort_unstable();
92 let l = &log2_scale_q11;
93 (kmeans(l), kmeans(l), kmeans(l), kmeans(l), kmeans(l), kmeans(l))
94 };
95
96 // Find variance in spacing between successive log(scale)
97 let var = |c: &[i16]| {
98 let delta = ArrayVec::<_, MAX_SEGMENTS>::from_iter(
99 c.iter().skip(1).zip(c).map(|(&a, &b)| b as i64 - a as i64),
100 );
101 let mean = delta.iter().sum::<i64>() / delta.len() as i64;
102 delta.iter().map(|&d| (d - mean).pow(2)).sum::<i64>() as u64
103 };
104 let variance =
105 [var(&c.0), var(&c.1), var(&c.2), var(&c.3), var(&c.4), var(&c.5)];
106
107 // Choose the k value with minimal variance in spacing
108 let min_variance = *variance.iter().min().unwrap();
109 let position = variance.iter().rposition(|&v| v == min_variance).unwrap();
110
111 // For the selected centroids, derive a target quantizer:
112 // scale Q'^2 = Q^2
113 // See `distortion_scale_for` for more information.
114 let compute_delta = |centroids: &[i16]| {
115 use crate::util::{bexp64, blog64};
116 let log2_base_ac_q_q57 =
117 blog64(ac_q(fi.base_q_idx, 0, fi.config.bit_depth).get().into());
118 centroids
119 .iter()
120 .rev()
121 // Rewrite in log form and exponentiate:
122 // scale Q'^2 = Q^2
123 // Q' = Q / sqrt(scale)
124 // log(Q') = log(Q) - 0.5 log(scale)
125 .map(|&log2_scale_q11| {
126 bexp64(log2_base_ac_q_q57 - ((log2_scale_q11 as i64) << (57 - 11 - 1)))
127 })
128 // Find the index of the nearest quantizer to the target,
129 // and take the delta from the base quantizer index.
130 .map(|q| {
131 // Avoid going into lossless mode by never bringing qidx below 1.
132 select_ac_qi(q, fi.config.bit_depth).max(1) as i16
133 - fi.base_q_idx as i16
134 })
135 .collect::<ArrayVec<_, MAX_SEGMENTS>>()
136 };
137
138 // Compute segment deltas for best value of k
139 let seg_delta = match position {
140 0 => compute_delta(&c.0),
141 1 => compute_delta(&c.1),
142 2 => compute_delta(&c.2),
143 3 => compute_delta(&c.3),
144 4 => compute_delta(&c.4),
145 _ => compute_delta(&c.5),
146 };
147
148 // Update the segmentation data
149 fs.segmentation.min_segment = 0;
150 fs.segmentation.max_segment = seg_delta.len() as u8 - 1;
151 for (&delta, (features, data)) in seg_delta
152 .iter()
153 .zip(fs.segmentation.features.iter_mut().zip(&mut fs.segmentation.data))
154 {
155 features[SegLvl::SEG_LVL_ALT_Q as usize] = true;
156 data[SegLvl::SEG_LVL_ALT_Q as usize] = delta.max(offset_lower_limit);
157 }
158
159 fs.segmentation.update_threshold(fi.base_q_idx, fi.config.bit_depth);
160}
161
162#[profiling::function]
163pub fn select_segment<T: Pixel>(
164 fi: &FrameInvariants<T>, ts: &TileStateMut<'_, T>, tile_bo: TileBlockOffset,
165 bsize: BlockSize, skip: bool,
166) -> std::ops::RangeInclusive<u8> {
167 // If skip is true or segmentation is turned off, sidx is not coded.
168 if skip || !fi.enable_segmentation {
169 return 0..=0;
170 }
171
172 use crate::api::SegmentationLevel;
173 if fi.config.speed_settings.segmentation == SegmentationLevel::Full {
174 return ts.segmentation.min_segment..=ts.segmentation.max_segment;
175 }
176
177 let frame_bo = ts.to_frame_block_offset(tile_bo);
178 let scale = spatiotemporal_scale(fi, frame_bo, bsize);
179
180 let sidx = segment_idx_from_distortion(&ts.segmentation.threshold, scale);
181
182 // Avoid going into lossless mode by never bringing qidx below 1.
183 let sidx = sidx.max(ts.segmentation.min_segment);
184
185 if fi.config.speed_settings.segmentation == SegmentationLevel::Complex {
186 return sidx..=ts.segmentation.max_segment.min(sidx.saturating_add(1));
187 }
188
189 sidx..=sidx
190}
191
192fn segment_idx_from_distortion(
193 threshold: &[DistortionScale; MAX_SEGMENTS - 1], s: DistortionScale,
194) -> u8 {
195 threshold.partition_point(|&t: DistortionScale| s.0 < t.0) as u8
196}
197