1// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10#![allow(non_camel_case_types)]
11#![allow(dead_code)]
12
13#[macro_use]
14pub mod forward_shared;
15
16pub use self::forward::forward_transform;
17pub use self::inverse::inverse_transform_add;
18
19use crate::context::MI_SIZE_LOG2;
20use crate::partition::{BlockSize, BlockSize::*};
21use crate::util::*;
22
23use TxSize::*;
24
25pub mod forward;
26pub mod inverse;
27
28pub static RAV1E_TX_TYPES: &[TxType] = &[
29 TxType::DCT_DCT,
30 TxType::ADST_DCT,
31 TxType::DCT_ADST,
32 TxType::ADST_ADST,
33 // TODO: Add a speed setting for FLIPADST
34 // TxType::FLIPADST_DCT,
35 // TxType::DCT_FLIPADST,
36 // TxType::FLIPADST_FLIPADST,
37 // TxType::ADST_FLIPADST,
38 // TxType::FLIPADST_ADST,
39 TxType::IDTX,
40 TxType::V_DCT,
41 TxType::H_DCT,
42 //TxType::V_FLIPADST,
43 //TxType::H_FLIPADST,
44];
45
/// Fixed-point `sqrt(2)` constants with `SQRT2_BITS` fractional bits.
///
/// These are `const` rather than `static` so they can be used in constant
/// contexts (e.g. `const fn` helpers or other `const` items).
pub mod consts {
  /// Number of fractional bits used by `SQRT2` and `INV_SQRT2`.
  pub const SQRT2_BITS: usize = 12;
  /// `round(2^12 * sqrt(2))`.
  pub const SQRT2: i32 = 5793;
  /// `round(2^12 / sqrt(2))`.
  pub const INV_SQRT2: i32 = 2896;
}
51
/// Number of regular transform types: `DCT_DCT` through `H_FLIPADST`
/// (discriminants 0 to 15).
pub const TX_TYPES: usize = 16;
/// Number of transform types including the lossless `WHT_WHT` (discriminant 16).
pub const TX_TYPES_PLUS_LL: usize = TX_TYPES + 1;
54
/// Two-dimensional transform type: a vertical and a horizontal 1-D kernel.
///
/// The explicit discriminants line up with the 17-entry `VTX_TAB`/`HTX_TAB`
/// tables in this module.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum TxType {
  /// DCT in both horizontal and vertical.
  DCT_DCT = 0,
  /// ADST in vertical, DCT in horizontal.
  ADST_DCT = 1,
  /// DCT in vertical, ADST in horizontal.
  DCT_ADST = 2,
  /// ADST in both directions.
  ADST_ADST = 3,
  /// Flipped ADST in vertical, DCT in horizontal.
  FLIPADST_DCT = 4,
  /// DCT in vertical, flipped ADST in horizontal.
  DCT_FLIPADST = 5,
  /// Flipped ADST in both directions.
  FLIPADST_FLIPADST = 6,
  /// ADST in vertical, flipped ADST in horizontal.
  ADST_FLIPADST = 7,
  /// Flipped ADST in vertical, ADST in horizontal.
  FLIPADST_ADST = 8,
  /// Identity in both directions.
  IDTX = 9,
  /// DCT in vertical, identity in horizontal.
  V_DCT = 10,
  /// Identity in vertical, DCT in horizontal.
  H_DCT = 11,
  /// ADST in vertical, identity in horizontal.
  V_ADST = 12,
  /// Identity in vertical, ADST in horizontal.
  H_ADST = 13,
  /// Flipped ADST in vertical, identity in horizontal.
  V_FLIPADST = 14,
  /// Identity in vertical, flipped ADST in horizontal.
  H_FLIPADST = 15,
  /// Walsh-Hadamard in both directions (the lossless type).
  WHT_WHT = 16,
}
75
76impl TxType {
77 /// Compute transform type for inter chroma.
78 ///
79 /// <https://aomediacodec.github.io/av1-spec/#compute-transform-type-function>
80 #[inline]
81 pub fn uv_inter(self, uv_tx_size: TxSize) -> Self {
82 use TxType::*;
83 if uv_tx_size.sqr_up() == TX_32X32 {
84 match self {
85 IDTX => IDTX,
86 _ => DCT_DCT,
87 }
88 } else if uv_tx_size.sqr() == TX_16X16 {
89 match self {
90 V_ADST | H_ADST | V_FLIPADST | H_FLIPADST => DCT_DCT,
91 _ => self,
92 }
93 } else {
94 self
95 }
96 }
97}
98
/// Transform Size
///
/// Variants are named `TX_<width>X<height>` in pixels. The five square sizes
/// come first, then the 1:2 / 2:1 rectangles, then the 1:4 / 4:1 rectangles.
/// The derived ordering follows these discriminants.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum TxSize {
  // Square sizes.
  TX_4X4 = 0,
  TX_8X8 = 1,
  TX_16X16 = 2,
  TX_32X32 = 3,
  TX_64X64 = 4,

  // 1:2 and 2:1 rectangles.
  TX_4X8 = 5,
  TX_8X4 = 6,
  TX_8X16 = 7,
  TX_16X8 = 8,
  TX_16X32 = 9,
  TX_32X16 = 10,
  TX_32X64 = 11,
  TX_64X32 = 12,

  // 1:4 and 4:1 rectangles.
  TX_4X16 = 13,
  TX_16X4 = 14,
  TX_8X32 = 15,
  TX_32X8 = 16,
  TX_16X64 = 17,
  TX_64X16 = 18,
}
124
125impl TxSize {
126 /// Number of square transform sizes
127 pub const TX_SIZES: usize = 5;
128
129 /// Number of transform sizes (including non-square sizes)
130 pub const TX_SIZES_ALL: usize = 14 + 5;
131
132 #[inline]
133 pub const fn width(self) -> usize {
134 1 << self.width_log2()
135 }
136
137 #[inline]
138 pub const fn width_log2(self) -> usize {
139 match self {
140 TX_4X4 | TX_4X8 | TX_4X16 => 2,
141 TX_8X8 | TX_8X4 | TX_8X16 | TX_8X32 => 3,
142 TX_16X16 | TX_16X8 | TX_16X32 | TX_16X4 | TX_16X64 => 4,
143 TX_32X32 | TX_32X16 | TX_32X64 | TX_32X8 => 5,
144 TX_64X64 | TX_64X32 | TX_64X16 => 6,
145 }
146 }
147
148 #[inline]
149 pub const fn width_index(self) -> usize {
150 self.width_log2() - TX_4X4.width_log2()
151 }
152
153 #[inline]
154 pub const fn height(self) -> usize {
155 1 << self.height_log2()
156 }
157
158 #[inline]
159 pub const fn height_log2(self) -> usize {
160 match self {
161 TX_4X4 | TX_8X4 | TX_16X4 => 2,
162 TX_8X8 | TX_4X8 | TX_16X8 | TX_32X8 => 3,
163 TX_16X16 | TX_8X16 | TX_32X16 | TX_4X16 | TX_64X16 => 4,
164 TX_32X32 | TX_16X32 | TX_64X32 | TX_8X32 => 5,
165 TX_64X64 | TX_32X64 | TX_16X64 => 6,
166 }
167 }
168
169 #[inline]
170 pub const fn height_index(self) -> usize {
171 self.height_log2() - TX_4X4.height_log2()
172 }
173
174 #[inline]
175 pub const fn width_mi(self) -> usize {
176 self.width() >> MI_SIZE_LOG2
177 }
178
179 #[inline]
180 pub const fn area(self) -> usize {
181 1 << self.area_log2()
182 }
183
184 #[inline]
185 pub const fn area_log2(self) -> usize {
186 self.width_log2() + self.height_log2()
187 }
188
189 #[inline]
190 pub const fn height_mi(self) -> usize {
191 self.height() >> MI_SIZE_LOG2
192 }
193
194 #[inline]
195 pub const fn block_size(self) -> BlockSize {
196 match self {
197 TX_4X4 => BLOCK_4X4,
198 TX_8X8 => BLOCK_8X8,
199 TX_16X16 => BLOCK_16X16,
200 TX_32X32 => BLOCK_32X32,
201 TX_64X64 => BLOCK_64X64,
202 TX_4X8 => BLOCK_4X8,
203 TX_8X4 => BLOCK_8X4,
204 TX_8X16 => BLOCK_8X16,
205 TX_16X8 => BLOCK_16X8,
206 TX_16X32 => BLOCK_16X32,
207 TX_32X16 => BLOCK_32X16,
208 TX_32X64 => BLOCK_32X64,
209 TX_64X32 => BLOCK_64X32,
210 TX_4X16 => BLOCK_4X16,
211 TX_16X4 => BLOCK_16X4,
212 TX_8X32 => BLOCK_8X32,
213 TX_32X8 => BLOCK_32X8,
214 TX_16X64 => BLOCK_16X64,
215 TX_64X16 => BLOCK_64X16,
216 }
217 }
218
219 #[inline]
220 pub const fn sqr(self) -> TxSize {
221 match self {
222 TX_4X4 | TX_4X8 | TX_8X4 | TX_4X16 | TX_16X4 => TX_4X4,
223 TX_8X8 | TX_8X16 | TX_16X8 | TX_8X32 | TX_32X8 => TX_8X8,
224 TX_16X16 | TX_16X32 | TX_32X16 | TX_16X64 | TX_64X16 => TX_16X16,
225 TX_32X32 | TX_32X64 | TX_64X32 => TX_32X32,
226 TX_64X64 => TX_64X64,
227 }
228 }
229
230 #[inline]
231 pub const fn sqr_up(self) -> TxSize {
232 match self {
233 TX_4X4 => TX_4X4,
234 TX_8X8 | TX_4X8 | TX_8X4 => TX_8X8,
235 TX_16X16 | TX_8X16 | TX_16X8 | TX_4X16 | TX_16X4 => TX_16X16,
236 TX_32X32 | TX_16X32 | TX_32X16 | TX_8X32 | TX_32X8 => TX_32X32,
237 TX_64X64 | TX_32X64 | TX_64X32 | TX_16X64 | TX_64X16 => TX_64X64,
238 }
239 }
240
241 #[inline]
242 pub fn by_dims(w: usize, h: usize) -> TxSize {
243 match (w, h) {
244 (4, 4) => TX_4X4,
245 (8, 8) => TX_8X8,
246 (16, 16) => TX_16X16,
247 (32, 32) => TX_32X32,
248 (64, 64) => TX_64X64,
249 (4, 8) => TX_4X8,
250 (8, 4) => TX_8X4,
251 (8, 16) => TX_8X16,
252 (16, 8) => TX_16X8,
253 (16, 32) => TX_16X32,
254 (32, 16) => TX_32X16,
255 (32, 64) => TX_32X64,
256 (64, 32) => TX_64X32,
257 (4, 16) => TX_4X16,
258 (16, 4) => TX_16X4,
259 (8, 32) => TX_8X32,
260 (32, 8) => TX_32X8,
261 (16, 64) => TX_16X64,
262 (64, 16) => TX_64X16,
263 _ => unreachable!(),
264 }
265 }
266
267 #[inline]
268 pub const fn is_rect(self) -> bool {
269 self.width_log2() != self.height_log2()
270 }
271}
272
/// Transform set: the group of `TxType`s a block may choose from.
///
/// Variants are declared from the smallest set (1 type) to the largest
/// (16 types), and the derived ordering follows that declaration order.
/// `Debug` and `Ord` are derived for consistency with `TxType` and `TxSize`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TxSet {
  /// DCT only.
  TX_SET_DCTONLY,
  /// DCT + Identity only (`TX_SET_DCT_IDTX`).
  TX_SET_INTER_3,
  /// Discrete trig transforms w/o flip (4) + Identity (1)
  /// (`TX_SET_DTT4_IDTX`).
  TX_SET_INTRA_2,
  /// Discrete trig transforms w/o flip (4) + Identity (1) + 1D Hor/Ver DCT (2)
  /// (`TX_SET_DTT4_IDTX_1DDCT`).
  TX_SET_INTRA_1,
  /// Discrete trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver DCT (2)
  /// (`TX_SET_DTT9_IDTX_1DDCT`).
  TX_SET_INTER_2,
  /// Discrete trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver (6)
  /// (`TX_SET_ALL16`).
  TX_SET_INTER_1,
}
288
/// Utility function that returns the log of the ratio of the col and row sizes.
///
/// Computes `bitlen(col) - bitlen(row)`, where `bitlen(x)` is the number of
/// significant bits of `x`. For power-of-two dimensions this is exactly
/// `log2(col / row)`, e.g. `get_rect_tx_log_ratio(16, 4) == 2`.
///
/// This is a `const fn`, like the `TxSize` dimension helpers, so it can be
/// evaluated in constant contexts.
#[inline]
pub const fn get_rect_tx_log_ratio(col: usize, row: usize) -> i8 {
  debug_assert!(col > 0 && row > 0);
  // bitlen(x) == usize::BITS - x.leading_zeros(). The BITS terms cancel, so
  // the difference of bit lengths is the reversed difference of leading
  // zeros. Each count is <= 64, so the `as i8` casts are lossless.
  row.leading_zeros() as i8 - col.leading_zeros() as i8
}
295
/// Performs half a butterfly: `(w0 * in0 + w1 * in1)` rounded and shifted
/// right by `bit` (no rounding or shift when `bit == 0`).
///
/// All arithmetic wraps on `i32` overflow. Wrapping arithmetic is exact
/// modulo 2^32. So as long as `w0 * in0 + w1 * in1 + (1 << (bit - 1))` fits in
/// an `i32`, the result is correct even if a product or a partial sum
/// overflowed on the way there.
#[inline]
const fn half_btf(w0: i32, in0: i32, w1: i32, in1: i32, bit: usize) -> i32 {
  // Ensure defined behaviour (no debug-mode overflow panic) when either
  // product, or their sum, overflows but the rounded total does not.
  let result = w0.wrapping_mul(in0).wrapping_add(w1.wrapping_mul(in1));
  // Implement a version of round_shift with wrapping
  if bit == 0 {
    result
  } else {
    result.wrapping_add(1 << (bit - 1)) >> bit
  }
}
309
/// Clamps `value` to the range of a signed integer type that is `bit` bits
/// wide, i.e. `[-(1 << (bit - 1)), (1 << (bit - 1)) - 1]`.
///
/// A width of 32 bits or more can represent every `i32`, so `value` is
/// returned unchanged in that case.
///
/// # Panics
///
/// In debug builds, if `bit` is 0.
#[inline]
fn clamp_value(value: i32, bit: usize) -> i32 {
  debug_assert!(bit > 0, "clamp_value: bit width must be non-zero");
  if bit >= 32 {
    // Every i32 fits. Computing the bounds in i64 and casting with `as i32`
    // would truncate them for bit > 32 (e.g. bit = 33 gives min 0, max -1).
    return value;
  }
  let max_value = (1i32 << (bit - 1)) - 1;
  let min_value = -(1i32 << (bit - 1));
  Ord::clamp(value, min_value, max_value)
}
317
318pub fn av1_round_shift_array(arr: &mut [i32], size: usize, bit: i8) {
319 if bit == 0 {
320 return;
321 }
322 if bit > 0 {
323 let bit: usize = bit as usize;
324 arr.iter_mut().take(size).for_each(|i: &mut i32| {
325 *i = round_shift(*i, bit);
326 })
327 } else {
328 arr.iter_mut().take(size).for_each(|i: &mut i32| {
329 *i <<= -bit;
330 })
331 }
332}
333
/// One-dimensional transform kernel. Each `TxType` pairs a vertical kernel
/// with a horizontal one.
#[derive(Clone, Copy, Debug)]
enum TxType1D {
  /// Discrete cosine transform.
  DCT,
  /// Asymmetric discrete sine transform.
  ADST,
  /// Flipped ADST.
  FLIPADST,
  /// Identity transform.
  IDTX,
  /// Walsh-Hadamard transform.
  WHT,
}
342
343const fn get_1d_tx_types(tx_type: TxType) -> (TxType1D, TxType1D) {
344 match tx_type {
345 TxType::DCT_DCT => (TxType1D::DCT, TxType1D::DCT),
346 TxType::ADST_DCT => (TxType1D::ADST, TxType1D::DCT),
347 TxType::DCT_ADST => (TxType1D::DCT, TxType1D::ADST),
348 TxType::ADST_ADST => (TxType1D::ADST, TxType1D::ADST),
349 TxType::FLIPADST_DCT => (TxType1D::FLIPADST, TxType1D::DCT),
350 TxType::DCT_FLIPADST => (TxType1D::DCT, TxType1D::FLIPADST),
351 TxType::FLIPADST_FLIPADST => (TxType1D::FLIPADST, TxType1D::FLIPADST),
352 TxType::ADST_FLIPADST => (TxType1D::ADST, TxType1D::FLIPADST),
353 TxType::FLIPADST_ADST => (TxType1D::FLIPADST, TxType1D::ADST),
354 TxType::IDTX => (TxType1D::IDTX, TxType1D::IDTX),
355 TxType::V_DCT => (TxType1D::DCT, TxType1D::IDTX),
356 TxType::H_DCT => (TxType1D::IDTX, TxType1D::DCT),
357 TxType::V_ADST => (TxType1D::ADST, TxType1D::IDTX),
358 TxType::H_ADST => (TxType1D::IDTX, TxType1D::ADST),
359 TxType::V_FLIPADST => (TxType1D::FLIPADST, TxType1D::IDTX),
360 TxType::H_FLIPADST => (TxType1D::IDTX, TxType1D::FLIPADST),
361 TxType::WHT_WHT => (TxType1D::WHT, TxType1D::WHT),
362 }
363}
364
365const VTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
366 TxType1D::DCT,
367 TxType1D::ADST,
368 TxType1D::DCT,
369 TxType1D::ADST,
370 TxType1D::FLIPADST,
371 TxType1D::DCT,
372 TxType1D::FLIPADST,
373 TxType1D::ADST,
374 TxType1D::FLIPADST,
375 TxType1D::IDTX,
376 TxType1D::DCT,
377 TxType1D::IDTX,
378 TxType1D::ADST,
379 TxType1D::IDTX,
380 TxType1D::FLIPADST,
381 TxType1D::IDTX,
382 TxType1D::WHT,
383];
384
385const HTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
386 TxType1D::DCT,
387 TxType1D::DCT,
388 TxType1D::ADST,
389 TxType1D::ADST,
390 TxType1D::DCT,
391 TxType1D::FLIPADST,
392 TxType1D::FLIPADST,
393 TxType1D::FLIPADST,
394 TxType1D::ADST,
395 TxType1D::IDTX,
396 TxType1D::IDTX,
397 TxType1D::DCT,
398 TxType1D::IDTX,
399 TxType1D::ADST,
400 TxType1D::IDTX,
401 TxType1D::FLIPADST,
402 TxType1D::WHT,
403];
404
405#[inline]
406pub const fn valid_av1_transform(tx_size: TxSize, tx_type: TxType) -> bool {
407 let size_sq: TxSize = tx_size.sqr_up();
408 use TxSize::*;
409 use TxType::*;
410 match (size_sq, tx_type) {
411 (TX_64X64, DCT_DCT) => true,
412 (TX_64X64, _) => false,
413 (TX_32X32, DCT_DCT) => true,
414 (TX_32X32, IDTX) => true,
415 (TX_32X32, _) => false,
416 (_, _) => true,
417 }
418}
419
/// Lists the transform types usable with `tx_size` (test/bench builds only).
///
/// * longest side 64: `DCT_DCT` only;
/// * longest side 32: `DCT_DCT` and `IDTX`;
/// * `TX_4X4`: all sixteen regular types plus `WHT_WHT`;
/// * everything else: the sixteen regular types.
#[cfg(any(test, feature = "bench"))]
pub fn get_valid_txfm_types(tx_size: TxSize) -> &'static [TxType] {
  use TxType::*;
  // The sixteen regular types in discriminant order, followed by the
  // 4x4-only WHT_WHT, so dropping the last entry gives the regular set.
  const ALL_TYPES: &[TxType] = &[
    DCT_DCT,
    ADST_DCT,
    DCT_ADST,
    ADST_ADST,
    FLIPADST_DCT,
    DCT_FLIPADST,
    FLIPADST_FLIPADST,
    ADST_FLIPADST,
    FLIPADST_ADST,
    IDTX,
    V_DCT,
    H_DCT,
    V_ADST,
    H_ADST,
    V_FLIPADST,
    H_FLIPADST,
    WHT_WHT,
  ];
  match tx_size.sqr_up() {
    TxSize::TX_64X64 => &[DCT_DCT],
    TxSize::TX_32X32 => &[DCT_DCT, IDTX],
    TxSize::TX_4X4 => ALL_TYPES,
    _ => &ALL_TYPES[..ALL_TYPES.len() - 1],
  }
}
469
#[cfg(test)]
mod test {
  use super::TxType::*;
  use super::*;
  use crate::context::av1_get_coded_tx_size;
  use crate::cpu_features::CpuFeatureLevel;
  use crate::frame::*;
  use rand::random;
  use std::mem::MaybeUninit;

  /// Runs a forward transform on a random residual (`src - dst`), then adds
  /// its inverse back onto `dst`. Asserts that every reconstructed pixel is
  /// within `tolerance` of `src`.
  fn test_roundtrip<T: Pixel>(
    tx_size: TxSize, tx_type: TxType, tolerance: i16,
  ) {
    let cpu = CpuFeatureLevel::default();

    let coeff_area: usize = av1_get_coded_tx_size(tx_size).area();
    let mut src_storage = [T::cast_from(0); 64 * 64];
    let src = &mut src_storage[..tx_size.area()];
    let mut dst = Plane::from_slice(
      &[T::zero(); 64 * 64][..tx_size.area()],
      tx_size.width(),
    );
    let mut res_storage = [0i16; 64 * 64];
    let res = &mut res_storage[..tx_size.area()];
    let mut freq_storage = [MaybeUninit::uninit(); 64 * 64];
    let freq = &mut freq_storage[..tx_size.area()];
    for ((r, s), d) in
      res.iter_mut().zip(src.iter_mut()).zip(dst.data.iter_mut())
    {
      *s = T::cast_from(random::<u8>());
      *d = T::cast_from(random::<u8>());
      *r = i16::cast_from(*s) - i16::cast_from(*d);
    }
    forward_transform(res, freq, tx_size.width(), tx_size, tx_type, 8, cpu);
    // SAFETY: forward_transform initialized freq
    let freq = unsafe { slice_assume_init_mut(freq) };
    inverse_transform_add(
      freq,
      &mut dst.as_region_mut(),
      coeff_area.try_into().unwrap(),
      tx_size,
      tx_type,
      8,
      cpu,
    );

    for (i, (s, d)) in src.iter().zip(dst.data.iter()).enumerate() {
      let diff = i16::abs(i16::cast_from(*s) - i16::cast_from(*d));
      assert!(
        diff <= tolerance,
        "{:?} {:?}: pixel {} differs by {} (tolerance {})",
        tx_size,
        tx_type,
        i,
        diff,
        tolerance
      );
    }
  }

  #[test]
  fn log_tx_ratios() {
    let combinations = [
      (TxSize::TX_4X4, 0),
      (TxSize::TX_8X8, 0),
      (TxSize::TX_16X16, 0),
      (TxSize::TX_32X32, 0),
      (TxSize::TX_64X64, 0),
      (TxSize::TX_4X8, -1),
      (TxSize::TX_8X4, 1),
      (TxSize::TX_8X16, -1),
      (TxSize::TX_16X8, 1),
      (TxSize::TX_16X32, -1),
      (TxSize::TX_32X16, 1),
      (TxSize::TX_32X64, -1),
      (TxSize::TX_64X32, 1),
      (TxSize::TX_4X16, -2),
      (TxSize::TX_16X4, 2),
      (TxSize::TX_8X32, -2),
      (TxSize::TX_32X8, 2),
      (TxSize::TX_16X64, -2),
      (TxSize::TX_64X16, 2),
    ];

    for &(tx_size, expected) in combinations.iter() {
      let (w, h) = (tx_size.width(), tx_size.height());
      assert_eq!(
        get_rect_tx_log_ratio(w, h),
        expected,
        "log ratio of {}x{}",
        w,
        h
      );
    }
  }

  fn roundtrips<T: Pixel>() {
    let combinations = [
      (TX_4X4, WHT_WHT, 0),
      (TX_4X4, DCT_DCT, 0),
      (TX_4X4, ADST_DCT, 0),
      (TX_4X4, DCT_ADST, 0),
      (TX_4X4, ADST_ADST, 0),
      (TX_4X4, FLIPADST_DCT, 0),
      (TX_4X4, DCT_FLIPADST, 0),
      (TX_4X4, IDTX, 0),
      (TX_4X4, V_DCT, 0),
      (TX_4X4, H_DCT, 0),
      (TX_4X4, V_ADST, 0),
      (TX_4X4, H_ADST, 0),
      (TX_8X8, DCT_DCT, 1),
      (TX_8X8, ADST_DCT, 1),
      (TX_8X8, DCT_ADST, 1),
      (TX_8X8, ADST_ADST, 1),
      (TX_8X8, FLIPADST_DCT, 1),
      (TX_8X8, DCT_FLIPADST, 1),
      (TX_8X8, IDTX, 0),
      (TX_8X8, V_DCT, 0),
      (TX_8X8, H_DCT, 0),
      (TX_8X8, V_ADST, 0),
      (TX_8X8, H_ADST, 1),
      (TX_16X16, DCT_DCT, 1),
      (TX_16X16, ADST_DCT, 1),
      (TX_16X16, DCT_ADST, 1),
      (TX_16X16, ADST_ADST, 1),
      (TX_16X16, FLIPADST_DCT, 1),
      (TX_16X16, DCT_FLIPADST, 1),
      (TX_16X16, IDTX, 0),
      (TX_16X16, V_DCT, 1),
      (TX_16X16, H_DCT, 1),
      // 32x transforms only use DCT_DCT and IDTX
      (TX_32X32, DCT_DCT, 2),
      (TX_32X32, IDTX, 0),
      // 64x transforms only use DCT_DCT (not exercised here)
      //(TX_64X64, DCT_DCT, 0),
      (TX_4X8, DCT_DCT, 1),
      (TX_8X4, DCT_DCT, 1),
      (TX_4X16, DCT_DCT, 1),
      (TX_16X4, DCT_DCT, 1),
      (TX_8X16, DCT_DCT, 1),
      (TX_16X8, DCT_DCT, 1),
      (TX_8X32, DCT_DCT, 2),
      (TX_32X8, DCT_DCT, 2),
      (TX_16X32, DCT_DCT, 2),
      (TX_32X16, DCT_DCT, 2),
    ];
    for &(tx_size, tx_type, tolerance) in combinations.iter() {
      println!("Testing combination {:?}, {:?}", tx_size, tx_type);
      test_roundtrip::<T>(tx_size, tx_type, tolerance);
    }
  }

  #[test]
  fn roundtrips_u8() {
    roundtrips::<u8>();
  }

  #[test]
  fn roundtrips_u16() {
    roundtrips::<u16>();
  }
}
623