1// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10use super::*;
11use crate::predict::PredictionMode;
12use crate::predict::PredictionMode::*;
13use crate::transform::TxType::*;
14use std::mem::MaybeUninit;
15
/// Largest transform width/height in samples (the largest entry of
/// `max_txsize_rect_lookup` is `TX_64X64`).
pub const MAX_TX_SIZE: usize = 64;

/// Largest transform dimension whose coefficients are coded; it also sizes
/// the padded level buffer (`TX_PAD_2D`).
pub const MAX_CODED_TX_SIZE: usize = 32;
/// Number of coefficients in the largest coded (32x32) transform.
pub const MAX_CODED_TX_SQUARE: usize = MAX_CODED_TX_SIZE * MAX_CODED_TX_SIZE;

pub const TX_SIZE_SQR_CONTEXTS: usize = 4; // Coded tx_size <= 32x32, so is the # of CDF contexts from tx sizes

/// Number of transform-set kinds; length of the per-set tables below, which
/// are indexed by `TxSet as usize`.
pub const TX_SETS: usize = 6;
/// Number of sets usable by intra blocks (non-negative entries of
/// `tx_set_index_intra`).
pub const TX_SETS_INTRA: usize = 3;
/// Number of sets usable by inter blocks (non-negative entries of
/// `tx_set_index_inter`).
pub const TX_SETS_INTER: usize = 4;

/// Number of luma intra prediction modes (length of
/// `intra_mode_to_tx_type_context`).
pub const INTRA_MODES: usize = 13;
/// Number of chroma intra prediction modes: the luma modes plus CFL (length
/// of `uv2y`).
pub const UV_INTRA_MODES: usize = 14;

/// `txfm_split` is only signalled for inter transforms at a recursion depth
/// below this value (see `write_tx_size_inter`).
const MAX_VARTX_DEPTH: usize = 2;

/// Upper bound checked against the `txfm_split` category computed in
/// `txfm_partition_context`; presumably also the number of
/// `txfm_partition_cdf` contexts (TODO confirm against the CDF table).
pub const TXFM_PARTITION_CONTEXTS: usize =
  (TxSize::TX_SIZES - TxSize::TX_8X8 as usize) * 6 - 3;

// Number of transform types in each set type
// (each entry equals the number of 1s in the matching row of `av1_tx_used`).
pub static num_tx_set: [usize; TX_SETS] = [1, 2, 5, 7, 12, 16];
/// `av1_tx_used[set][tx_type]` is 1 when `tx_type` (indexed by
/// `TxType as usize`) is a member of `set`, 0 otherwise.
pub static av1_tx_used: [[usize; TX_TYPES]; TX_SETS] = [
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
];

// Maps set types above to the indices used for intra
// (-1: never selected for intra; 0: the single-type set, nothing is coded;
// 1/2: which `intra_tx_{1,2}_cdf` family `write_tx_type` uses).
static tx_set_index_intra: [i8; TX_SETS] = [0, -1, 2, 1, -1, -1];
// Maps set types above to the indices used for inter
// (-1: never selected for inter; 0: the single-type set, nothing is coded;
// 1/2/3: which `inter_tx_{1,2,3}_cdf` family `write_tx_type` uses).
static tx_set_index_inter: [i8; TX_SETS] = [0, 3, -1, -1, 2, 1];

/// `av1_tx_ind[set][tx_type]` is the symbol value written for `tx_type`
/// within `set`; only meaningful where `av1_tx_used[set][tx_type]` is 1.
pub static av1_tx_ind: [[usize; TX_TYPES]; TX_SETS] = [
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 5, 6, 4, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
  [3, 4, 5, 8, 6, 7, 9, 10, 11, 0, 1, 2, 0, 0, 0, 0],
  [7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
];
59
/// Largest transform size usable by each block size, indexed by
/// `BlockSize as usize`. Dimensions above 64 samples are capped at 64, so the
/// 64x128, 128x64 and 128x128 blocks all map to `TX_64X64`.
pub static max_txsize_rect_lookup: [TxSize; BlockSize::BLOCK_SIZES_ALL] = [
  TX_4X4,   // 4x4
  TX_4X8,   // 4x8
  TX_8X4,   // 8x4
  TX_8X8,   // 8x8
  TX_8X16,  // 8x16
  TX_16X8,  // 16x8
  TX_16X16, // 16x16
  TX_16X32, // 16x32
  TX_32X16, // 32x16
  TX_32X32, // 32x32
  TX_32X64, // 32x64
  TX_64X32, // 64x32
  TX_64X64, // 64x64
  TX_64X64, // 64x128
  TX_64X64, // 128x64
  TX_64X64, // 128x128
  TX_4X16,  // 4x16
  TX_16X4,  // 16x4
  TX_8X32,  // 8x32
  TX_32X8,  // 32x8
  TX_16X64, // 16x64
  TX_64X16, // 64x16
];

/// Transform size produced by one level of splitting, indexed by
/// `TxSize as usize`. Used to walk down from a block's largest transform
/// (intra depth computation, inter `txfm_split` recursion). `TX_4X4` maps to
/// itself.
pub static sub_tx_size_map: [TxSize; TxSize::TX_SIZES_ALL] = [
  TX_4X4,   // TX_4X4
  TX_4X4,   // TX_8X8
  TX_8X8,   // TX_16X16
  TX_16X16, // TX_32X32
  TX_32X32, // TX_64X64
  TX_4X4,   // TX_4X8
  TX_4X4,   // TX_8X4
  TX_8X8,   // TX_8X16
  TX_8X8,   // TX_16X8
  TX_16X16, // TX_16X32
  TX_16X16, // TX_32X16
  TX_32X32, // TX_32X64
  TX_32X32, // TX_64X32
  TX_4X8,   // TX_4X16
  TX_8X4,   // TX_16X4
  TX_8X16,  // TX_8X32
  TX_16X8,  // TX_32X8
  TX_16X32, // TX_16X64
  TX_32X16, // TX_64X16
];
106
107#[inline]
108pub fn has_chroma(
109 bo: TileBlockOffset, bsize: BlockSize, subsampling_x: usize,
110 subsampling_y: usize, chroma_sampling: ChromaSampling,
111) -> bool {
112 if chroma_sampling == ChromaSampling::Cs400 {
113 return false;
114 };
115
116 let bw: usize = bsize.width_mi();
117 let bh: usize = bsize.height_mi();
118
119 ((bo.0.x & 0x01) == 1 || (bw & 0x01) == 0 || subsampling_x == 0)
120 && ((bo.0.y & 0x01) == 1 || (bh & 0x01) == 0 || subsampling_y == 0)
121}
122
123pub fn get_tx_set(
124 tx_size: TxSize, is_inter: bool, use_reduced_set: bool,
125) -> TxSet {
126 let tx_size_sqr_up: TxSize = tx_size.sqr_up();
127 let tx_size_sqr: TxSize = tx_size.sqr();
128
129 if tx_size_sqr_up.block_size() > BlockSize::BLOCK_32X32 {
130 return TxSet::TX_SET_DCTONLY;
131 }
132
133 if is_inter {
134 if use_reduced_set || tx_size_sqr_up == TxSize::TX_32X32 {
135 TxSet::TX_SET_INTER_3
136 } else if tx_size_sqr == TxSize::TX_16X16 {
137 TxSet::TX_SET_INTER_2
138 } else {
139 TxSet::TX_SET_INTER_1
140 }
141 } else if tx_size_sqr_up == TxSize::TX_32X32 {
142 TxSet::TX_SET_DCTONLY
143 } else if use_reduced_set || tx_size_sqr == TxSize::TX_16X16 {
144 TxSet::TX_SET_INTRA_2
145 } else {
146 TxSet::TX_SET_INTRA_1
147 }
148}
149
150pub fn get_tx_set_index(
151 tx_size: TxSize, is_inter: bool, use_reduced_set: bool,
152) -> i8 {
153 let set_type: TxSet = get_tx_set(tx_size, is_inter, use_reduced_set);
154
155 if is_inter {
156 tx_set_index_inter[set_type as usize]
157 } else {
158 tx_set_index_intra[set_type as usize]
159 }
160}
161
/// Transform type used as a context for each luma intra prediction mode,
/// indexed by `PredictionMode as usize` over the `INTRA_MODES` luma modes
/// (order given by the trailing comments).
static intra_mode_to_tx_type_context: [TxType; INTRA_MODES] = [
  DCT_DCT,   // DC
  ADST_DCT,  // V
  DCT_ADST,  // H
  DCT_DCT,   // D45
  ADST_ADST, // D135
  ADST_DCT,  // D113
  DCT_ADST,  // D157
  DCT_ADST,  // D203
  ADST_DCT,  // D67
  ADST_ADST, // SMOOTH
  ADST_DCT,  // SMOOTH_V
  DCT_ADST,  // SMOOTH_H
  ADST_ADST, // PAETH
];

/// Maps each chroma intra mode (indexed by `PredictionMode as usize`) to the
/// corresponding luma mode. CFL has no luma counterpart and maps to
/// `DC_PRED`.
static uv2y: [PredictionMode; UV_INTRA_MODES] = [
  DC_PRED,       // UV_DC_PRED
  V_PRED,        // UV_V_PRED
  H_PRED,        // UV_H_PRED
  D45_PRED,      // UV_D45_PRED
  D135_PRED,     // UV_D135_PRED
  D113_PRED,     // UV_D113_PRED
  D157_PRED,     // UV_D157_PRED
  D203_PRED,     // UV_D203_PRED
  D67_PRED,      // UV_D67_PRED
  SMOOTH_PRED,   // UV_SMOOTH_PRED
  SMOOTH_V_PRED, // UV_SMOOTH_V_PRED
  SMOOTH_H_PRED, // UV_SMOOTH_H_PRED
  PAETH_PRED,    // UV_PAETH_PRED
  DC_PRED,       // CFL_PRED
];
194
195pub fn uv_intra_mode_to_tx_type_context(pred: PredictionMode) -> TxType {
196 intra_mode_to_tx_type_context[uv2y[pred as usize] as usize]
197}
198
// Level Map
//
// Constants for the coefficient ("level map") coding contexts.

/// Number of transform-block skip (all-zero) contexts; presumably the size
/// of the corresponding CDF table (TODO confirm).
pub const TXB_SKIP_CONTEXTS: usize = 13;

/// Number of end-of-block extra-bit contexts (presumably; TODO confirm).
pub const EOB_COEF_CONTEXTS: usize = 9;

/// Contexts used by 2D transform classes; the 1D-class contexts start right
/// after them (`NZ_MAP_CTX_0 == SIG_COEF_CONTEXTS_2D`).
const SIG_COEF_CONTEXTS_2D: usize = 26;
/// Contexts used by the 1D (horizontal/vertical) transform classes.
const SIG_COEF_CONTEXTS_1D: usize = 16;
/// Contexts for the last (EOB) coefficient; `get_nz_map_ctx` returns 0..=3
/// for it.
pub const SIG_COEF_CONTEXTS_EOB: usize = 4;
/// Total significance contexts (2D followed by 1D).
pub const SIG_COEF_CONTEXTS: usize =
  SIG_COEF_CONTEXTS_2D + SIG_COEF_CONTEXTS_1D;

const COEFF_BASE_CONTEXTS: usize = SIG_COEF_CONTEXTS;
/// Number of DC-sign contexts.
pub const DC_SIGN_CONTEXTS: usize = 3;

const BR_TMP_OFFSET: usize = 12;
const BR_REF_CAT: usize = 4;
/// Number of base-range contexts; `get_br_ctx` returns values in
/// `0..LEVEL_CONTEXTS` (at most 6 + 14 = 20).
pub const LEVEL_CONTEXTS: usize = 21;

pub const NUM_BASE_LEVELS: usize = 2;

/// Alphabet size of one base-range symbol.
pub const BR_CDF_SIZE: usize = 4;
/// Largest increment that can be expressed by base-range symbols: up to four
/// symbols, each adding at most `BR_CDF_SIZE - 1`.
pub const COEFF_BASE_RANGE: usize = 4 * (BR_CDF_SIZE - 1);

pub const COEFF_CONTEXT_BITS: usize = 6;
pub const COEFF_CONTEXT_MASK: usize = (1 << COEFF_CONTEXT_BITS) - 1;
const MAX_BASE_BR_RANGE: usize = COEFF_BASE_RANGE + NUM_BASE_LEVELS + 1;

const BASE_CONTEXT_POSITION_NUM: usize = 12;

// Pad 4 extra columns to remove horizontal availability check.
// NOTE: in this file the levels buffer is stored transposed (column-major):
// each column occupies `(1 << bhl) + TX_PAD_HOR` entries (see `get_br_ctx`),
// so this padding actually follows every column of the transform.
pub const TX_PAD_HOR_LOG2: usize = 2;
pub const TX_PAD_HOR: usize = 4;
// Pad 6 extra rows (2 on top and 4 on bottom) to remove vertical availability
// check.
pub const TX_PAD_TOP: usize = 2;
pub const TX_PAD_BOTTOM: usize = 4;
pub const TX_PAD_VER: usize = TX_PAD_TOP + TX_PAD_BOTTOM;
// Pad 16 extra bytes to avoid reading overflow in SIMD optimization.
const TX_PAD_END: usize = 16;
/// Size of a padded level buffer large enough for the largest coded
/// transform (`MAX_CODED_TX_SIZE` in each dimension).
pub const TX_PAD_2D: usize = (MAX_CODED_TX_SIZE + TX_PAD_HOR)
  * (MAX_CODED_TX_SIZE + TX_PAD_VER)
  + TX_PAD_END;

/// Number of `TxClass` variants.
const TX_CLASSES: usize = 3;
243
/// Transform class, which selects the neighbourhood used when deriving
/// coefficient contexts. Per `tx_type_to_class`, the `V_*` types are
/// `TX_CLASS_VERT`, the `H_*` types are `TX_CLASS_HORIZ`, and every other
/// type (including `IDTX`) is `TX_CLASS_2D`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum TxClass {
  TX_CLASS_2D = 0,
  TX_CLASS_HORIZ = 1,
  TX_CLASS_VERT = 2,
}
250
/// Segmentation features. Discriminants index `seg_feature_bits` and
/// `seg_feature_is_signed`; `SEG_LVL_MAX` is the feature count, not a
/// feature.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum SegLvl {
  SEG_LVL_ALT_Q = 0,      /* Use alternate Quantizer .... */
  SEG_LVL_ALT_LF_Y_V = 1, /* Use alternate loop filter value on y plane vertical */
  SEG_LVL_ALT_LF_Y_H = 2, /* Use alternate loop filter value on y plane horizontal */
  SEG_LVL_ALT_LF_U = 3,   /* Use alternate loop filter value on u plane */
  SEG_LVL_ALT_LF_V = 4,   /* Use alternate loop filter value on v plane */
  SEG_LVL_REF_FRAME = 5,  /* Optional Segment reference frame */
  SEG_LVL_SKIP = 6,       /* Optional Segment (0,0) + skip mode */
  SEG_LVL_GLOBALMV = 7,
  SEG_LVL_MAX = 8,
}

/// Per-feature bit count, indexed by `SegLvl as usize`; presumably the
/// number of bits used to code the feature's value (0 = flag only).
pub const seg_feature_bits: [u32; SegLvl::SEG_LVL_MAX as usize] =
  [8, 6, 6, 6, 6, 3, 0, 0];

/// Whether each feature's value is signed, indexed by `SegLvl as usize`.
pub const seg_feature_is_signed: [bool; SegLvl::SEG_LVL_MAX as usize] =
  [true, true, true, true, true, false, false, false];
269
270use crate::context::TxClass::*;
271
/// Transform class of each transform type, indexed by `TxType as usize`.
pub static tx_type_to_class: [TxClass; TX_TYPES] = [
  TX_CLASS_2D,    // DCT_DCT
  TX_CLASS_2D,    // ADST_DCT
  TX_CLASS_2D,    // DCT_ADST
  TX_CLASS_2D,    // ADST_ADST
  TX_CLASS_2D,    // FLIPADST_DCT
  TX_CLASS_2D,    // DCT_FLIPADST
  TX_CLASS_2D,    // FLIPADST_FLIPADST
  TX_CLASS_2D,    // ADST_FLIPADST
  TX_CLASS_2D,    // FLIPADST_ADST
  TX_CLASS_2D,    // IDTX
  TX_CLASS_VERT,  // V_DCT
  TX_CLASS_HORIZ, // H_DCT
  TX_CLASS_VERT,  // V_ADST
  TX_CLASS_HORIZ, // H_ADST
  TX_CLASS_VERT,  // V_FLIPADST
  TX_CLASS_HORIZ, // H_FLIPADST
];

/// EOB group token for `eob < 33`, indexed directly by `eob`
/// (see `get_eob_pos_token`).
pub static eob_to_pos_small: [u8; 33] = [
  0, 1, 2, // 0-2
  3, 3, // 3-4
  4, 4, 4, 4, // 5-8
  5, 5, 5, 5, 5, 5, 5, 5, // 9-16
  6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // 17-32
];

/// EOB group token for `eob >= 33`, indexed by `min((eob - 1) >> 5, 16)`.
/// Entry 0 is never reached for such `eob` values.
pub static eob_to_pos_large: [u8; 17] = [
  6, // place holder
  7, // 33-64
  8, 8, // 65-128
  9, 9, 9, 9, // 129-256
  10, 10, 10, 10, 10, 10, 10, 10, // 257-512
  11, // 513-
];

/// First `eob` value of each EOB group token.
pub static k_eob_group_start: [u16; 12] =
  [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];
/// Extra bits per EOB group: group `t` spans `1 << k_eob_offset_bits[t]`
/// values starting at `k_eob_group_start[t]`.
pub static k_eob_offset_bits: [u16; 12] = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
311
/// The ctx offset table used when the transform class is `TX_CLASS_2D`.
///
/// Indexed as `[tx_size][min(row, 4)][min(col, 4)]` (see
/// `get_nz_map_ctx_from_stats`), i.e. row and column are clamped to 4.
/// Values:
/// - 0 at DC;
/// - 1 / 6 / 21 when `row + col` is < 2 / < 4 / larger;
/// - 11 in the first two rows of tall transforms;
/// - 16 in the first two columns of wide transforms.
///
/// Rows or columns that cannot occur for a transform (e.g. the fifth column
/// of 4-wide sizes) are 0.
#[rustfmt::skip]
pub static av1_nz_map_ctx_offset: [[[i8; 5]; 5]; TxSize::TX_SIZES_ALL] = [
  // TX_4X4
  [
    [ 0,  1,  6,  6,  0],
    [ 1,  6,  6, 21,  0],
    [ 6,  6, 21, 21,  0],
    [ 6, 21, 21, 21,  0],
    [ 0,  0,  0,  0,  0]
  ],
  // TX_8X8
  [
    [ 0,  1,  6,  6, 21],
    [ 1,  6,  6, 21, 21],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_16X16
  [
    [ 0,  1,  6,  6, 21],
    [ 1,  6,  6, 21, 21],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_32X32
  [
    [ 0,  1,  6,  6, 21],
    [ 1,  6,  6, 21, 21],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_64X64
  [
    [ 0,  1,  6,  6, 21],
    [ 1,  6,  6, 21, 21],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_4X8
  [
    [ 0, 11, 11, 11,  0],
    [11, 11, 11, 11,  0],
    [ 6,  6, 21, 21,  0],
    [ 6, 21, 21, 21,  0],
    [21, 21, 21, 21,  0]
  ],
  // TX_8X4
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [ 0,  0,  0,  0,  0]
  ],
  // TX_8X16
  [
    [ 0, 11, 11, 11, 11],
    [11, 11, 11, 11, 11],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_16X8
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21]
  ],
  // TX_16X32
  [
    [ 0, 11, 11, 11, 11],
    [11, 11, 11, 11, 11],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_32X16
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21]
  ],
  // TX_32X64
  [
    [ 0, 11, 11, 11, 11],
    [11, 11, 11, 11, 11],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_64X32
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21]
  ],
  // TX_4X16
  [
    [ 0, 11, 11, 11,  0],
    [11, 11, 11, 11,  0],
    [ 6,  6, 21, 21,  0],
    [ 6, 21, 21, 21,  0],
    [21, 21, 21, 21,  0]
  ],
  // TX_16X4
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [ 0,  0,  0,  0,  0]
  ],
  // TX_8X32
  [
    [ 0, 11, 11, 11, 11],
    [11, 11, 11, 11, 11],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_32X8
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21]
  ],
  // TX_16X64
  [
    [ 0, 11, 11, 11, 11],
    [11, 11, 11, 11, 11],
    [ 6,  6, 21, 21, 21],
    [ 6, 21, 21, 21, 21],
    [21, 21, 21, 21, 21]
  ],
  // TX_64X16
  [
    [ 0, 16,  6,  6, 21],
    [16, 16,  6, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21],
    [16, 16, 21, 21, 21]
  ]
];
470
471const NZ_MAP_CTX_0: usize = SIG_COEF_CONTEXTS_2D;
472const NZ_MAP_CTX_5: usize = NZ_MAP_CTX_0 + 5;
473const NZ_MAP_CTX_10: usize = NZ_MAP_CTX_0 + 10;
474
475pub static nz_map_ctx_offset_1d: [usize; 32] = [
476 NZ_MAP_CTX_0,
477 NZ_MAP_CTX_5,
478 NZ_MAP_CTX_10,
479 NZ_MAP_CTX_10,
480 NZ_MAP_CTX_10,
481 NZ_MAP_CTX_10,
482 NZ_MAP_CTX_10,
483 NZ_MAP_CTX_10,
484 NZ_MAP_CTX_10,
485 NZ_MAP_CTX_10,
486 NZ_MAP_CTX_10,
487 NZ_MAP_CTX_10,
488 NZ_MAP_CTX_10,
489 NZ_MAP_CTX_10,
490 NZ_MAP_CTX_10,
491 NZ_MAP_CTX_10,
492 NZ_MAP_CTX_10,
493 NZ_MAP_CTX_10,
494 NZ_MAP_CTX_10,
495 NZ_MAP_CTX_10,
496 NZ_MAP_CTX_10,
497 NZ_MAP_CTX_10,
498 NZ_MAP_CTX_10,
499 NZ_MAP_CTX_10,
500 NZ_MAP_CTX_10,
501 NZ_MAP_CTX_10,
502 NZ_MAP_CTX_10,
503 NZ_MAP_CTX_10,
504 NZ_MAP_CTX_10,
505 NZ_MAP_CTX_10,
506 NZ_MAP_CTX_10,
507 NZ_MAP_CTX_10,
508];
509
/// Number of neighbour positions per transform class in
/// `mag_ref_offset_with_txclass`.
const CONTEXT_MAG_POSITION_NUM: usize = 3;

// Neighbour offsets per transform class. The outer index presumably follows
// `TxClass` (2D, HORIZ, VERT). Each pair appears to be a `[row, col]` offset:
// the third neighbours (1,1) / (0,2) / (2,0) match those `get_br_ctx` reads
// for 2D / horizontal / vertical classes. NOTE(review): not referenced by
// the code visible here.
static mag_ref_offset_with_txclass: [[[usize; 2]; CONTEXT_MAG_POSITION_NUM];
  3] = [
  [[0, 1], [1, 0], [1, 1]],
  [[0, 1], [1, 0], [0, 2]],
  [[0, 1], [1, 0], [2, 0]],
];
518
519// End of Level Map
520
/// Pair of contexts computed for a transform block.
pub struct TXB_CTX {
  /// Context for the transform-block skip (all-zero) symbol; presumably in
  /// `0..TXB_SKIP_CONTEXTS` (TODO confirm at the producer).
  pub txb_skip_ctx: usize,
  /// Context for the DC sign symbol; presumably in `0..DC_SIGN_CONTEXTS`.
  pub dc_sign_ctx: usize,
}
525
526impl<'a> ContextWriter<'a> {
527 /// # Panics
528 ///
529 /// - If an invalid combination of `tx_type` and `tx_size` is passed
530 pub fn write_tx_type<W: Writer>(
531 &mut self, w: &mut W, tx_size: TxSize, tx_type: TxType,
532 y_mode: PredictionMode, is_inter: bool, use_reduced_tx_set: bool,
533 ) {
534 let square_tx_size = tx_size.sqr();
535 let tx_set = get_tx_set(tx_size, is_inter, use_reduced_tx_set);
536 let num_tx_types = num_tx_set[tx_set as usize];
537
538 if num_tx_types > 1 {
539 let tx_set_index =
540 get_tx_set_index(tx_size, is_inter, use_reduced_tx_set);
541 assert!(tx_set_index > 0);
542 assert!(av1_tx_used[tx_set as usize][tx_type as usize] != 0);
543
544 if is_inter {
545 let s = av1_tx_ind[tx_set as usize][tx_type as usize] as u32;
546 if tx_set_index == 1 {
547 let cdf = &self.fc.inter_tx_1_cdf[square_tx_size as usize];
548 symbol_with_update!(self, w, s, cdf);
549 } else if tx_set_index == 2 {
550 let cdf = &self.fc.inter_tx_2_cdf[square_tx_size as usize];
551 symbol_with_update!(self, w, s, cdf);
552 } else {
553 let cdf = &self.fc.inter_tx_3_cdf[square_tx_size as usize];
554 symbol_with_update!(self, w, s, cdf);
555 }
556 } else {
557 let intra_dir = y_mode;
558 // TODO: Once use_filter_intra is enabled,
559 // intra_dir =
560 // fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode];
561
562 let s = av1_tx_ind[tx_set as usize][tx_type as usize] as u32;
563 if tx_set_index == 1 {
564 let cdf = &self.fc.intra_tx_1_cdf[square_tx_size as usize]
565 [intra_dir as usize];
566 symbol_with_update!(self, w, s, cdf);
567 } else {
568 let cdf = &self.fc.intra_tx_2_cdf[square_tx_size as usize]
569 [intra_dir as usize];
570 symbol_with_update!(self, w, s, cdf);
571 }
572 }
573 }
574 }
575
576 fn get_tx_size_context(
577 &self, bo: TileBlockOffset, bsize: BlockSize,
578 ) -> usize {
579 let max_tx_size = max_txsize_rect_lookup[bsize as usize];
580 let max_tx_wide = max_tx_size.width() as u8;
581 let max_tx_high = max_tx_size.height() as u8;
582 let has_above = bo.0.y > 0;
583 let has_left = bo.0.x > 0;
584 let mut above = self.bc.above_tx_context[bo.0.x] >= max_tx_wide;
585 let mut left = self.bc.left_tx_context[bo.y_in_sb()] >= max_tx_high;
586
587 if has_above {
588 let above_blk = self.bc.blocks.above_of(bo);
589 if above_blk.is_inter() {
590 above = (above_blk.n4_w << MI_SIZE_LOG2) >= max_tx_wide;
591 };
592 }
593 if has_left {
594 let left_blk = self.bc.blocks.left_of(bo);
595 if left_blk.is_inter() {
596 left = (left_blk.n4_h << MI_SIZE_LOG2) >= max_tx_high;
597 };
598 }
599 if has_above && has_left {
600 return above as usize + left as usize;
601 };
602 if has_above {
603 return above as usize;
604 };
605 if has_left {
606 return left as usize;
607 };
608 0
609 }
610
611 pub fn write_tx_size_intra<W: Writer>(
612 &mut self, w: &mut W, bo: TileBlockOffset, bsize: BlockSize,
613 tx_size: TxSize,
614 ) {
615 fn tx_size_to_depth(tx_size: TxSize, bsize: BlockSize) -> usize {
616 let mut ctx_size = max_txsize_rect_lookup[bsize as usize];
617 let mut depth: usize = 0;
618 while tx_size != ctx_size {
619 depth += 1;
620 ctx_size = sub_tx_size_map[ctx_size as usize];
621 debug_assert!(depth <= MAX_TX_DEPTH);
622 }
623 depth
624 }
625 fn bsize_to_max_depth(bsize: BlockSize) -> usize {
626 let mut tx_size: TxSize = max_txsize_rect_lookup[bsize as usize];
627 let mut depth = 0;
628 while depth < MAX_TX_DEPTH && tx_size != TX_4X4 {
629 depth += 1;
630 tx_size = sub_tx_size_map[tx_size as usize];
631 debug_assert!(depth <= MAX_TX_DEPTH);
632 }
633 depth
634 }
635 fn bsize_to_tx_size_cat(bsize: BlockSize) -> usize {
636 let mut tx_size: TxSize = max_txsize_rect_lookup[bsize as usize];
637 debug_assert!(tx_size != TX_4X4);
638 let mut depth = 0;
639 while tx_size != TX_4X4 {
640 depth += 1;
641 tx_size = sub_tx_size_map[tx_size as usize];
642 }
643 debug_assert!(depth <= MAX_TX_CATS);
644
645 depth - 1
646 }
647
648 debug_assert!(!self.bc.blocks[bo].is_inter());
649 debug_assert!(bsize > BlockSize::BLOCK_4X4);
650
651 let tx_size_ctx = self.get_tx_size_context(bo, bsize);
652 let depth = tx_size_to_depth(tx_size, bsize);
653
654 let max_depths = bsize_to_max_depth(bsize);
655 let tx_size_cat = bsize_to_tx_size_cat(bsize);
656
657 debug_assert!(depth <= max_depths);
658 debug_assert!(!tx_size.is_rect() || bsize.is_rect_tx_allowed());
659
660 if tx_size_cat > 0 {
661 let cdf = &self.fc.tx_size_cdf[tx_size_cat - 1][tx_size_ctx];
662 symbol_with_update!(self, w, depth as u32, cdf);
663 } else {
664 let cdf = &self.fc.tx_size_8x8_cdf[tx_size_ctx];
665 symbol_with_update!(self, w, depth as u32, cdf);
666 }
667 }
668
669 // Based on https://aomediacodec.github.io/av1-spec/#cdf-selection-process
670 // Used to decide the cdf (context) for txfm_split
671 fn get_above_tx_width(
672 &self, bo: TileBlockOffset, _bsize: BlockSize, _tx_size: TxSize,
673 first_tx: bool,
674 ) -> usize {
675 let has_above = bo.0.y > 0;
676 if first_tx {
677 if !has_above {
678 return 64;
679 }
680 let above_blk = self.bc.blocks.above_of(bo);
681 if above_blk.skip && above_blk.is_inter() {
682 return above_blk.bsize.width();
683 }
684 }
685 self.bc.above_tx_context[bo.0.x] as usize
686 }
687
688 fn get_left_tx_height(
689 &self, bo: TileBlockOffset, _bsize: BlockSize, _tx_size: TxSize,
690 first_tx: bool,
691 ) -> usize {
692 let has_left = bo.0.x > 0;
693 if first_tx {
694 if !has_left {
695 return 64;
696 }
697 let left_blk = self.bc.blocks.left_of(bo);
698 if left_blk.skip && left_blk.is_inter() {
699 return left_blk.bsize.height();
700 }
701 }
702 self.bc.left_tx_context[bo.y_in_sb()] as usize
703 }
704
705 fn txfm_partition_context(
706 &self, bo: TileBlockOffset, bsize: BlockSize, tx_size: TxSize, tbx: usize,
707 tby: usize,
708 ) -> usize {
709 debug_assert!(tx_size > TX_4X4);
710 debug_assert!(bsize > BlockSize::BLOCK_4X4);
711
712 // TODO: from 2nd level partition, must know whether the tx block is the topmost(or leftmost) within a partition
713 let above = (self.get_above_tx_width(bo, bsize, tx_size, tby == 0)
714 < tx_size.width()) as usize;
715 let left = (self.get_left_tx_height(bo, bsize, tx_size, tbx == 0)
716 < tx_size.height()) as usize;
717
718 let max_tx_size: TxSize = bsize.tx_size().sqr_up();
719 let category: usize = (tx_size.sqr_up() != max_tx_size) as usize
720 + (TxSize::TX_SIZES - 1 - max_tx_size as usize) * 2;
721
722 debug_assert!(category < TXFM_PARTITION_CONTEXTS);
723
724 category * 3 + above + left
725 }
726
727 pub fn write_tx_size_inter<W: Writer>(
728 &mut self, w: &mut W, bo: TileBlockOffset, bsize: BlockSize,
729 tx_size: TxSize, txfm_split: bool, tbx: usize, tby: usize, depth: usize,
730 ) {
731 if bo.0.x >= self.bc.blocks.cols() || bo.0.y >= self.bc.blocks.rows() {
732 return;
733 }
734 debug_assert!(self.bc.blocks[bo].is_inter());
735 debug_assert!(bsize > BlockSize::BLOCK_4X4);
736 debug_assert!(!tx_size.is_rect() || bsize.is_rect_tx_allowed());
737
738 if tx_size != TX_4X4 && depth < MAX_VARTX_DEPTH {
739 let ctx = self.txfm_partition_context(bo, bsize, tx_size, tbx, tby);
740 let cdf = &self.fc.txfm_partition_cdf[ctx];
741 symbol_with_update!(self, w, txfm_split as u32, cdf);
742 } else {
743 debug_assert!(!txfm_split);
744 }
745
746 if !txfm_split {
747 self.bc.update_tx_size_context(bo, tx_size.block_size(), tx_size, false);
748 } else {
749 // if txfm_split == true, split one level only
750 let split_tx_size = sub_tx_size_map[tx_size as usize];
751 let bw = bsize.width_mi() / split_tx_size.width_mi();
752 let bh = bsize.height_mi() / split_tx_size.height_mi();
753
754 for by in 0..bh {
755 for bx in 0..bw {
756 let tx_bo = TileBlockOffset(BlockOffset {
757 x: bo.0.x + bx * split_tx_size.width_mi(),
758 y: bo.0.y + by * split_tx_size.height_mi(),
759 });
760 self.write_tx_size_inter(
761 w,
762 tx_bo,
763 bsize,
764 split_tx_size,
765 false,
766 bx,
767 by,
768 depth + 1,
769 );
770 }
771 }
772 }
773 }
774
775 #[inline]
776 pub const fn get_txsize_entropy_ctx(tx_size: TxSize) -> usize {
777 (tx_size.sqr() as usize + tx_size.sqr_up() as usize + 1) >> 1
778 }
779
780 pub fn txb_init_levels<T: Coefficient>(
781 &self, coeffs: &[T], height: usize, levels: &mut [u8],
782 levels_stride: usize,
783 ) {
784 // Coefficients and levels are transposed from how they work in the spec
785 for (coeffs_col, levels_col) in
786 coeffs.chunks_exact(height).zip(levels.chunks_exact_mut(levels_stride))
787 {
788 for (coeff, level) in coeffs_col.iter().zip(levels_col) {
789 *level = coeff.abs().min(T::cast_from(127)).as_();
790 }
791 }
792 }
793
794 // Since the coefficients and levels are transposed in relation to how they
795 // work in the spec, use the log of block height in our calculations instead
796 // of block width.
797 #[inline]
798 pub const fn get_txb_bhl(tx_size: TxSize) -> usize {
799 av1_get_coded_tx_size(tx_size).height_log2()
800 }
801
802 /// Returns `(eob_pt, eob_extra)`
803 ///
804 /// # Panics
805 ///
806 /// - If `eob` is prior to the start of the group
807 #[inline]
808 pub fn get_eob_pos_token(eob: u16) -> (u32, u32) {
809 let t = if eob < 33 {
810 eob_to_pos_small[usize::from(eob)] as u32
811 } else {
812 let e = usize::from(cmp::min((eob - 1) >> 5, 16));
813 eob_to_pos_large[e] as u32
814 };
815 assert!(eob as i32 >= k_eob_group_start[t as usize] as i32);
816 let extra = eob as u32 - k_eob_group_start[t as usize] as u32;
817
818 (t, extra)
819 }
820
821 pub fn get_nz_mag(levels: &[u8], bhl: usize, tx_class: TxClass) -> usize {
822 // Levels are transposed from how they work in the spec
823
824 // May version.
825 // Note: AOMMIN(level, 3) is useless for decoder since level < 3.
826 let mut mag = cmp::min(3, levels[1]); // { 1, 0 }
827 mag += cmp::min(3, levels[(1 << bhl) + TX_PAD_HOR]); // { 0, 1 }
828
829 if tx_class == TX_CLASS_2D {
830 mag += cmp::min(3, levels[(1 << bhl) + TX_PAD_HOR + 1]); // { 1, 1 }
831 mag += cmp::min(3, levels[2]); // { 2, 0 }
832 mag += cmp::min(3, levels[(2 << bhl) + (2 << TX_PAD_HOR_LOG2)]); // { 0, 2 }
833 } else if tx_class == TX_CLASS_VERT {
834 mag += cmp::min(3, levels[2]); // { 2, 0 }
835 mag += cmp::min(3, levels[3]); // { 3, 0 }
836 mag += cmp::min(3, levels[4]); // { 4, 0 }
837 } else {
838 mag += cmp::min(3, levels[(2 << bhl) + (2 << TX_PAD_HOR_LOG2)]); // { 0, 2 }
839 mag += cmp::min(3, levels[(3 << bhl) + (3 << TX_PAD_HOR_LOG2)]); // { 0, 3 }
840 mag += cmp::min(3, levels[(4 << bhl) + (4 << TX_PAD_HOR_LOG2)]); // { 0, 4 }
841 }
842
843 mag as usize
844 }
845
846 fn get_nz_map_ctx_from_stats(
847 stats: usize,
848 coeff_idx: usize, // raster order
849 bhl: usize,
850 tx_size: TxSize,
851 tx_class: TxClass,
852 ) -> usize {
853 if (tx_class as u32 | coeff_idx as u32) == 0 {
854 return 0;
855 };
856
857 // Coefficients are transposed from how they work in the spec
858 let col: usize = coeff_idx >> bhl;
859 let row: usize = coeff_idx - (col << bhl);
860
861 let ctx = ((stats + 1) >> 1).min(4);
862
863 ctx
864 + match tx_class {
865 TX_CLASS_2D => {
866 // This is the algorithm to generate table av1_nz_map_ctx_offset[].
867 // const int width = tx_size_wide[tx_size];
868 // const int height = tx_size_high[tx_size];
869 // if (width < height) {
870 // if (row < 2) return 11 + ctx;
871 // } else if (width > height) {
872 // if (col < 2) return 16 + ctx;
873 // }
874 // if (row + col < 2) return ctx + 1;
875 // if (row + col < 4) return 5 + ctx + 1;
876 // return 21 + ctx;
877 av1_nz_map_ctx_offset[tx_size as usize][cmp::min(row, 4)]
878 [cmp::min(col, 4)] as usize
879 }
880 TX_CLASS_HORIZ => nz_map_ctx_offset_1d[col],
881 TX_CLASS_VERT => nz_map_ctx_offset_1d[row],
882 }
883 }
884
885 fn get_nz_map_ctx(
886 levels: &[u8], coeff_idx: usize, bhl: usize, area: usize, scan_idx: usize,
887 is_eob: bool, tx_size: TxSize, tx_class: TxClass,
888 ) -> usize {
889 if is_eob {
890 if scan_idx == 0 {
891 return 0;
892 }
893 if scan_idx <= area / 8 {
894 return 1;
895 }
896 if scan_idx <= area / 4 {
897 return 2;
898 }
899 return 3;
900 }
901
902 // Levels are transposed from how they work in the spec
903 let padded_idx = coeff_idx + ((coeff_idx >> bhl) << TX_PAD_HOR_LOG2);
904 let stats = Self::get_nz_mag(&levels[padded_idx..], bhl, tx_class);
905
906 Self::get_nz_map_ctx_from_stats(stats, coeff_idx, bhl, tx_size, tx_class)
907 }
908
909 /// `coeff_contexts_no_scan` is not in the scan order.
910 /// Value for `pos = scan[i]` is at `coeff[i]`, not at `coeff[pos]`.
911 pub fn get_nz_map_contexts<'c>(
912 &self, levels: &mut [u8], scan: &[u16], eob: u16, tx_size: TxSize,
913 tx_class: TxClass, coeff_contexts_no_scan: &'c mut [MaybeUninit<i8>],
914 ) -> &'c mut [i8] {
915 let bhl = Self::get_txb_bhl(tx_size);
916 let area = av1_get_coded_tx_size(tx_size).area();
917
918 let scan = &scan[..usize::from(eob)];
919 let coeffs = &mut coeff_contexts_no_scan[..usize::from(eob)];
920 for (i, (coeff, pos)) in
921 coeffs.iter_mut().zip(scan.iter().copied()).enumerate()
922 {
923 coeff.write(Self::get_nz_map_ctx(
924 levels,
925 pos as usize,
926 bhl,
927 area,
928 i,
929 i == usize::from(eob) - 1,
930 tx_size,
931 tx_class,
932 ) as i8);
933 }
934 // SAFETY: every element has been initialized
935 unsafe { slice_assume_init_mut(coeffs) }
936 }
937
938 pub fn get_br_ctx(
939 levels: &[u8],
940 coeff_idx: usize, // raster order
941 bhl: usize,
942 tx_class: TxClass,
943 ) -> usize {
944 // Coefficients and levels are transposed from how they work in the spec
945 let col: usize = coeff_idx >> bhl;
946 let row: usize = coeff_idx - (col << bhl);
947 let stride: usize = (1 << bhl) + TX_PAD_HOR;
948 let pos: usize = col * stride + row;
949 let mut mag: usize = (levels[pos + 1] + levels[pos + stride]) as usize;
950
951 match tx_class {
952 TX_CLASS_2D => {
953 mag += levels[pos + stride + 1] as usize;
954 mag = cmp::min((mag + 1) >> 1, 6);
955 if coeff_idx == 0 {
956 return mag;
957 }
958 if (row < 2) && (col < 2) {
959 return mag + 7;
960 }
961 }
962 TX_CLASS_HORIZ => {
963 mag += levels[pos + (stride << 1)] as usize;
964 mag = cmp::min((mag + 1) >> 1, 6);
965 if coeff_idx == 0 {
966 return mag;
967 }
968 if col == 0 {
969 return mag + 7;
970 }
971 }
972 TX_CLASS_VERT => {
973 mag += levels[pos + 2] as usize;
974 mag = cmp::min((mag + 1) >> 1, 6);
975 if coeff_idx == 0 {
976 return mag;
977 }
978 if row == 0 {
979 return mag + 7;
980 }
981 }
982 }
983
984 mag + 14
985 }
986}
987