1// Copyright (c) 2001-2016, Alliance for Open Media. All rights reserved
2// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
3//
4// This source code is subject to the terms of the BSD 2 Clause License and
5// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6// was not distributed with this source code in the LICENSE file, you can
7// obtain it at www.aomedia.org/license/software. If the Alliance for Open
8// Media Patent License 1.0 was not distributed with this source code in the
9// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10
11#![allow(non_camel_case_types)]
12
// Select the backend that provides `update_cdf` (called unqualified by the
// generic `Writer` impl below): the NASM-backed x86_64 module when it is
// built, otherwise the portable Rust fallback in the `rust` module.
cfg_if::cfg_if! {
  if #[cfg(nasm_x86_64)] {
    pub use crate::asm::x86::ec::*;
  } else {
    pub use self::rust::*;
  }
}
20
21use crate::context::{CDFContext, CDFContextLog, CDFOffset};
22use bitstream_io::{BigEndian, BitWrite, BitWriter};
23use std::io;
24
/// Number of fractional bits of precision used by `tell_frac()` and by the
/// `symbol_bits`/`count_*` cost estimates, which report costs in units of
/// `1 / (1 << OD_BITRES)` bit.
pub const OD_BITRES: u8 = 3;
/// Number of low-order bits dropped from Q15 CDF values before they are
/// scaled by the range. The last CDF entry stores the adaptation count,
/// which must fit in this many bits.
const EC_PROB_SHIFT: u32 = 6;
/// Minimum probability mass reserved per remaining symbol when computing a
/// symbol's sub-interval (see `lr_compute`).
const EC_MIN_PROB: u32 = 4;
/// Integer type of the encoder's `low` window.
type ec_window = u32;
29
/// Public trait interface to a bitstream `Writer`. A `Writer` can be used
/// to count bits for cost analysis without actually storing anything
/// (using a new `WriterCounter` as a `Writer`), to record tokens for later
/// writing (using a new `WriterRecorder` as a `Writer`), or to write actual
/// final bits out using a range encoder (using a new `WriterEncoder` as a
/// `Writer`). A `WriterRecorder`'s contents can be replayed into a
/// `WriterEncoder`.
pub trait Writer {
  /// Write a symbol `s`, using the passed in cdf reference; leaves `cdf` unchanged
  fn symbol<const CDF_LEN: usize>(&mut self, s: u32, cdf: &[u16; CDF_LEN]);
  /// Return the approximate cost, in `1 / (1 << OD_BITRES)` bit units, of
  /// writing symbol `s` using the passed in cdf reference; leaves `cdf`
  /// unchanged
  fn symbol_bits(&self, s: u32, cdf: &[u16]) -> u32;
  /// Write a symbol `s` using the cdf found at offset `cdf` in `fc`, then
  /// update that cdf. The cdf is obtained through `log`, presumably so the
  /// update can later be undone.
  fn symbol_with_update<const CDF_LEN: usize>(
    &mut self, s: u32, cdf: CDFOffset<CDF_LEN>, log: &mut CDFContextLog,
    fc: &mut CDFContext,
  );
  /// Write a bool `val`, where `f` is the Q15 probability that `val` is true
  fn bool(&mut self, val: bool, f: u16);
  /// Write a single bit with flat probability
  fn bit(&mut self, bit: u16);
  /// Write the low `bits` bits of `s`, most significant first, with flat
  /// probability
  fn literal(&mut self, bits: u8, s: u32);
  /// Write passed `level` as an Exp-Golomb code
  fn write_golomb(&mut self, level: u32);
  /// Write a value `v` in `[0, n-1]` quasi-uniformly
  fn write_quniform(&mut self, n: u32, v: u32);
  /// Return the cost, in `1 / (1 << OD_BITRES)` bit units, of writing a
  /// value `v` in `[0, n-1]` quasi-uniformly
  fn count_quniform(&self, n: u32, v: u32) -> u32;
  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite subexponential
  fn write_subexp(&mut self, n: u32, k: u8, v: u32);
  /// Return the cost, in `1 / (1 << OD_BITRES)` bit units, of writing
  /// symbol `v` in `[0, n-1]` with parameter `k` as finite subexponential
  fn count_subexp(&self, n: u32, k: u8, v: u32) -> u32;
  /// Write symbol `v` in `[0, mx-1]` with parameter `k` as finite
  /// subexponential based on a reference `r` also in `[0, mx-1]`.
  fn write_unsigned_subexp_with_ref(&mut self, v: u32, mx: u32, k: u8, r: u32);
  /// Return the cost, in `1 / (1 << OD_BITRES)` bit units, of writing
  /// symbol `v` in `[0, mx-1]` with parameter `k` as finite subexponential
  /// based on a reference `r` also in `[0, mx-1]`.
  fn count_unsigned_subexp_with_ref(
    &self, v: u32, mx: u32, k: u8, r: u32,
  ) -> u32;
  /// Write symbol `v` in `[low, high-1]` with parameter `k` as finite
  /// subexponential based on a reference `r` also in `[low, high-1]`.
  fn write_signed_subexp_with_ref(
    &mut self, v: i32, low: i32, high: i32, k: u8, r: i32,
  );
  /// Return the cost, in `1 / (1 << OD_BITRES)` bit units, of writing
  /// symbol `v` in `[low, high-1]` with parameter `k` as finite
  /// subexponential based on a reference `r` also in `[low, high-1]`.
  fn count_signed_subexp_with_ref(
    &self, v: i32, low: i32, high: i32, k: u8, r: i32,
  ) -> u32;
  /// Return current length of range-coded bitstream in integer bits
  fn tell(&mut self) -> u32;
  /// Return current length of range-coded bitstream in fractional bits,
  /// with `OD_BITRES` binary digits of fractional precision
  fn tell_frac(&mut self) -> u32;
  /// Save current point in coding/recording to a checkpoint
  fn checkpoint(&mut self) -> WriterCheckpoint;
  /// Restore saved position in coding/recording from a checkpoint
  fn rollback(&mut self, _: &WriterCheckpoint);
  /// Add additional bits from rate estimators without coding a real symbol.
  /// The units presumably match `tell_frac()` (`1 / (1 << OD_BITRES)` bit).
  fn add_bits_frac(&mut self, bits_frac: u32);
}
99
/// `StorageBackend` is an internal trait used to tie a specific `Writer`
/// implementation's storage to the generic `Writer`. It would be
/// private, but Rust is deprecating 'private trait in a public
/// interface' support.
pub trait StorageBackend {
  /// Store partially-computed range code into given storage backend.
  /// `fl`/`fh` are the symbol's (inverse) CDF bounds and `nms` is the
  /// number of CDF entries from the symbol's index to the end of the CDF.
  fn store(&mut self, fl: u16, fh: u16, nms: u16);
  /// Return bit-length of encoded stream to date
  fn stream_bits(&mut self) -> usize;
  /// Backend implementation of checkpoint to pass through Writer interface
  fn checkpoint(&mut self) -> WriterCheckpoint;
  /// Backend implementation of rollback to pass through Writer interface
  fn rollback(&mut self, _: &WriterCheckpoint);
}
114
/// Range-coder state shared by every `Writer` implementation; the
/// implementation-specific storage lives in `s`.
#[derive(Debug, Clone)]
pub struct WriterBase<S> {
  /// The number of values in the current range. Kept normalized so that
  /// its top bit is set (`lr_compute` asserts `32768 <= rng`).
  rng: u16,
  /// The number of bits of data in the current value. Starts at -9;
  /// `tell()` adds the offset back.
  cnt: i16,
  #[cfg(feature = "desync_finder")]
  /// Debug enable flag
  debug: bool,
  /// Extra offset added to tell() and tell_frac() to approximate costs
  /// of actually coding a symbol
  fake_bits_frac: u32,
  /// Use-specific storage
  s: S,
}
130
/// Storage for a counting `Writer`: records nothing but the bit total.
#[derive(Debug, Clone)]
pub struct WriterCounter {
  /// Bits that would be shifted out to date
  bits: usize,
}
136
/// Storage for a recording `Writer`: keeps every `(fl, fh, nms)` token so
/// it can later be replayed into another backend.
#[derive(Debug, Clone)]
pub struct WriterRecorder {
  /// Storage for tokens, in the order they were written
  storage: Vec<(u16, u16, u16)>,
  /// Bits that would be shifted out to date
  bits: usize,
}
144
/// Storage for an encoding `Writer` that produces the real bitstream.
#[derive(Debug, Clone)]
pub struct WriterEncoder {
  /// A buffer for output bytes with their associated carry flags.
  /// Carries are resolved in `done()`.
  precarry: Vec<u16>,
  /// The low end of the current range.
  low: ec_window,
}
152
/// Snapshot of a `Writer`'s coding position, produced by
/// `Writer::checkpoint` and consumed by `Writer::rollback`.
///
/// Derives `Debug` like the other writer state types so that checkpoints
/// can be inspected in diagnostics and test failures.
#[derive(Debug, Clone)]
pub struct WriterCheckpoint {
  /// Stream length coded/recorded to date, in the unit used by the Writer,
  /// which may be bytes or bits. This depends on the assumption
  /// that a Writer will only ever restore its own Checkpoint.
  stream_size: usize,
  /// To be defined by backend: 0 for the Counter, the number of stored
  /// tokens for the Recorder, and the saved `low` value for the Encoder.
  backend_var: usize,
  /// Saved number of values in the current range.
  rng: u16,
  /// Saved number of bits of data in the current value.
  cnt: i16,
}
166
167/// Constructor for a counting Writer
168impl WriterCounter {
169 #[inline]
170 pub const fn new() -> WriterBase<WriterCounter> {
171 WriterBase::new(storage:WriterCounter { bits: 0 })
172 }
173}
174
175/// Constructor for a recording Writer
176impl WriterRecorder {
177 #[inline]
178 pub const fn new() -> WriterBase<WriterRecorder> {
179 WriterBase::new(storage:WriterRecorder { storage: Vec::new(), bits: 0 })
180 }
181}
182
183/// Constructor for a encoding Writer
184impl WriterEncoder {
185 #[inline]
186 pub const fn new() -> WriterBase<WriterEncoder> {
187 WriterBase::new(storage:WriterEncoder { precarry: Vec::new(), low: 0 })
188 }
189}
190
191/// The Counter stores nothing we write to it, it merely counts the
192/// bit usage like in an Encoder for cost analysis.
193impl StorageBackend for WriterBase<WriterCounter> {
194 #[inline]
195 fn store(&mut self, fl: u16, fh: u16, nms: u16) {
196 let (_l, r) = self.lr_compute(fl, fh, nms);
197 let d = r.leading_zeros() as usize;
198
199 self.s.bits += d;
200 self.rng = r << d;
201 }
202 #[inline]
203 fn stream_bits(&mut self) -> usize {
204 self.s.bits
205 }
206 #[inline]
207 fn checkpoint(&mut self) -> WriterCheckpoint {
208 WriterCheckpoint {
209 stream_size: self.s.bits,
210 backend_var: 0,
211 rng: self.rng,
212 // We do not use `cnt` within Counter, but setting it here allows the compiler
213 // to do a 32-bit merged load/store.
214 cnt: self.cnt,
215 }
216 }
217 #[inline]
218 fn rollback(&mut self, checkpoint: &WriterCheckpoint) {
219 self.rng = checkpoint.rng;
220 self.s.bits = checkpoint.stream_size;
221 }
222}
223
224/// The Recorder does not produce a range-coded bitstream, but it
225/// still tracks the range coding progress like in an Encoder, as it
226/// neds to be able to report bit costs for RDO decisions. It stores a
227/// pair of mostly-computed range coding values per token recorded.
228impl StorageBackend for WriterBase<WriterRecorder> {
229 #[inline]
230 fn store(&mut self, fl: u16, fh: u16, nms: u16) {
231 let (_l, r) = self.lr_compute(fl, fh, nms);
232 let d = r.leading_zeros() as usize;
233
234 self.s.bits += d;
235 self.rng = r << d;
236 self.s.storage.push((fl, fh, nms));
237 }
238 #[inline]
239 fn stream_bits(&mut self) -> usize {
240 self.s.bits
241 }
242 #[inline]
243 fn checkpoint(&mut self) -> WriterCheckpoint {
244 WriterCheckpoint {
245 stream_size: self.s.bits,
246 backend_var: self.s.storage.len(),
247 rng: self.rng,
248 cnt: self.cnt,
249 }
250 }
251 #[inline]
252 fn rollback(&mut self, checkpoint: &WriterCheckpoint) {
253 self.rng = checkpoint.rng;
254 self.cnt = checkpoint.cnt;
255 self.s.bits = checkpoint.stream_size;
256 self.s.storage.truncate(checkpoint.backend_var);
257 }
258}
259
/// An Encoder produces an actual range-coded bitstream from passed in
/// tokens. It does not retain any information about the coded
/// tokens, only the resulting bitstream, and so it cannot be replayed
/// (only checkpointed and rolled back).
impl StorageBackend for WriterBase<WriterEncoder> {
  /// Code one token: advance `low`, shift completed bytes into
  /// `precarry`, and renormalize `rng` back to 16 significant bits.
  fn store(&mut self, fl: u16, fh: u16, nms: u16) {
    let (l, r) = self.lr_compute(fl, fh, nms);
    // Move the low end of the interval up to the coded symbol's start.
    let mut low = l + self.s.low;
    let mut c = self.cnt;
    // Shift needed to renormalize `r` to 16 significant bits.
    let d = r.leading_zeros() as usize;
    let mut s = c + (d as i16);

    // Once `cnt + d` is non-negative, whole bytes are moved out of `low`
    // into `precarry`: two bytes when it reached 8, otherwise one. Bytes
    // are stored as u16 so a carry above bit 7 can be resolved in `done()`.
    if s >= 0 {
      c += 16;
      let mut m = (1 << c) - 1;
      if s >= 8 {
        self.s.precarry.push((low >> c) as u16);
        low &= m;
        c -= 8;
        m >>= 8;
      }
      self.s.precarry.push((low >> c) as u16);
      s = c + (d as i16) - 24;
      low &= m;
    }
    self.s.low = low << d;
    self.rng = r << d;
    self.cnt = s;
  }
  /// Bits of whole bytes flushed so far (excludes bits still in `low`).
  #[inline]
  fn stream_bits(&mut self) -> usize {
    self.s.precarry.len() * 8
  }
  /// Unlike Counter/Recorder, `stream_size` is in bytes here and
  /// `backend_var` holds the saved `low` window.
  #[inline]
  fn checkpoint(&mut self) -> WriterCheckpoint {
    WriterCheckpoint {
      stream_size: self.s.precarry.len(),
      backend_var: self.s.low as usize,
      rng: self.rng,
      cnt: self.cnt,
    }
  }
  fn rollback(&mut self, checkpoint: &WriterCheckpoint) {
    self.rng = checkpoint.rng;
    self.cnt = checkpoint.cnt;
    self.s.low = checkpoint.backend_var as ec_window;
    self.s.precarry.truncate(checkpoint.stream_size);
  }
}
309
/// A few local helper functions needed by the Writer that are not
/// part of the public interface.
impl<S> WriterBase<S> {
  /// Internal constructor called by the subtypes that implement the
  /// actual encoder and Recorder.
  ///
  /// Starts with `rng = 0x8000` and `cnt = -9`; `tell()` compensates for
  /// the -9 offset.
  #[inline]
  #[cfg(not(feature = "desync_finder"))]
  const fn new(storage: S) -> Self {
    WriterBase { rng: 0x8000, cnt: -9, fake_bits_frac: 0, s: storage }
  }

  /// Internal constructor used when the `desync_finder` feature is on.
  /// Same initial state as the default constructor. Symbol-origin
  /// backtraces are additionally enabled when the `RAV1E_DEBUG`
  /// environment variable is set.
  #[inline]
  #[cfg(feature = "desync_finder")]
  fn new(storage: S) -> Self {
    WriterBase {
      rng: 0x8000,
      cnt: -9,
      debug: std::env::var_os("RAV1E_DEBUG").is_some(),
      fake_bits_frac: 0,
      s: storage,
    }
  }

  /// Compute low and range values from token cdf values and local state.
  ///
  /// Returns `(low_offset, range)`:
  /// - `low_offset` is the amount to add to the low end of the interval.
  /// - `range` is the not-yet-renormalized width of the symbol's
  ///   sub-interval.
  ///
  /// `fl >= 32768` identifies the first symbol, whose upper edge is the
  /// whole current range.
  const fn lr_compute(&self, fl: u16, fh: u16, nms: u16) -> (ec_window, u16) {
    let r = self.rng as u32;
    debug_assert!(32768 <= r);
    // Scale each bound by the top 8 bits of the range (with the low
    // EC_PROB_SHIFT bits of the probability dropped). EC_MIN_PROB is added
    // per remaining symbol, presumably to keep every interval non-empty.
    let mut u = (((r >> 8) * (fl as u32 >> EC_PROB_SHIFT))
      >> (7 - EC_PROB_SHIFT))
      + EC_MIN_PROB * nms as u32;
    if fl >= 32768 {
      u = r;
    }
    let v = (((r >> 8) * (fh as u32 >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT))
      + EC_MIN_PROB * (nms - 1) as u32;
    (r - u, (u - v) as u16)
  }

  /// Given the current total integer number of bits used and the current value of
  /// rng, computes the fraction number of bits used to `OD_BITRES` precision.
  /// This is used by `od_ec_enc_tell_frac()` and `od_ec_dec_tell_frac()`.
  /// `nbits_total`: The number of whole bits currently used, i.e., the value
  /// returned by `od_ec_enc_tell()` or `od_ec_dec_tell()`.
  /// `rng`: The current value of rng from either the encoder or decoder state.
  /// Return: The number of bits scaled by `2**OD_BITRES`.
  /// This will always be slightly larger than the exact value (e.g., all
  /// rounding error is in the positive direction).
  fn frac_compute(nbits_total: u32, mut rng: u32) -> u32 {
    // To handle the non-integral number of bits still left in the encoder/decoder
    // state, we compute the worst-case number of bits of val that must be
    // encoded to ensure that the value is inside the range for any possible
    // subsequent bits.
    // The computation here is independent of val itself (the decoder does not
    // even track that value), even though the real number of bits used after
    // od_ec_enc_done() may be 1 smaller if rng is a power of two and the
    // corresponding trailing bits of val are all zeros.
    // If we did try to track that special case, then coding a value with a
    // probability of 1/(1 << n) might sometimes appear to use more than n bits.
    // This may help explain the surprising result that a newly initialized
    // encoder or decoder claims to have used 1 bit.
    let nbits = nbits_total << OD_BITRES;
    let mut l = 0;
    // Each iteration squares rng (Q15) and peels off one more fractional
    // bit of log2(rng) into `l`.
    for _ in 0..OD_BITRES {
      rng = (rng * rng) >> 15;
      let b = rng >> 16;
      l = (l << 1) | b;
      rng >>= b;
    }
    nbits - l
  }

  /// Recenter a non-negative value `v` around reference `r`.
  ///
  /// Values at or above `r` (up to `2 * r`) map to even indices and values
  /// below `r` map to odd indices, so values close to `r` get small
  /// results. Values above `2 * r` pass through unchanged.
  const fn recenter(r: u32, v: u32) -> u32 {
    if v > (r << 1) {
      v
    } else if v >= r {
      (v - r) << 1
    } else {
      ((r - v) << 1) - 1
    }
  }

  /// Print the symbol `s` together with the name of a caller found by
  /// walking `backtrace::trace` three frames deep. That frame is intended
  /// to be the code that wrote the symbol; TODO confirm the depth is right.
  #[cfg(feature = "desync_finder")]
  fn print_backtrace(&self, s: u32) {
    let mut depth = 3;
    backtrace::trace(|frame| {
      let ip = frame.ip();

      depth -= 1;

      if depth == 0 {
        backtrace::resolve(ip, |symbol| {
          if let Some(name) = symbol.name() {
            println!("Writing symbol {} from {}", s, name);
          }
        });
        false
      } else {
        true
      }
    });
  }
}
412
413/// Replay implementation specific to the Recorder
414impl WriterBase<WriterRecorder> {
415 /// Replays the partially-computed range tokens out of the Recorder's
416 /// storage and into the passed in Writer, which may be an Encoder
417 /// or another Recorder. Clears the Recorder after replay.
418 pub fn replay(&mut self, dest: &mut dyn StorageBackend) {
419 for &(fl: u16, fh: u16, nms: u16) in &self.s.storage {
420 dest.store(fl, fh, nms);
421 }
422 self.rng = 0x8000;
423 self.cnt = -9;
424 self.s.storage.truncate(len:0);
425 self.s.bits = 0;
426 }
427}
428
/// Done implementation specific to the Encoder
impl WriterBase<WriterEncoder> {
  /// Indicates that there are no more symbols to encode. Flushes
  /// remaining state into coding and returns a vector containing the
  /// final bitstream.
  ///
  /// The flushed bytes are appended to `precarry`, which is not cleared
  /// afterwards.
  pub fn done(&mut self) -> Vec<u8> {
    // We output the minimum number of bits that ensures that the symbols encoded
    // thus far will be decoded correctly regardless of the bits that follow.
    let l = self.s.low;
    let mut c = self.cnt;
    let mut s = 10;
    let m = 0x3FFF;
    // Round `low` up to a multiple of `m + 1` (0x4000) and set bit 14.
    let mut e = ((l + m) & !m) | (m + 1);

    s += c;

    // Flush the remaining pending bits of `e`, one byte at a time.
    if s > 0 {
      let mut n = (1 << (c + 16)) - 1;

      loop {
        self.s.precarry.push((e >> (c + 16)) as u16);
        e &= n;
        s -= 8;
        c -= 8;
        n >>= 8;

        if s <= 0 {
          break;
        }
      }
    }

    // Resolve carries: walk the buffer from the end. Each entry's bits
    // above the low byte are carried into the preceding byte.
    let mut c = 0;
    let mut offs = self.s.precarry.len();
    // One output byte per precarry entry; sized once up front.
    let mut out = vec![0_u8; offs];
    while offs > 0 {
      offs -= 1;
      c += self.s.precarry[offs];
      out[offs] = c as u8;
      c >>= 8;
    }

    out
  }
}
475
476/// Generic/shared implementation for `Writer`s with `StorageBackend`s
477/// (ie, `Encoder`s and `Recorder`s)
478impl<S> Writer for WriterBase<S>
479where
480 WriterBase<S>: StorageBackend,
481{
482 /// Encode a single binary value.
483 /// `val`: The value to encode (0 or 1).
484 /// `f`: The probability that the val is one, scaled by 32768.
485 fn bool(&mut self, val: bool, f: u16) {
486 debug_assert!(0 < f);
487 debug_assert!(f < 32768);
488 self.symbol(u32::from(val), &[f, 0]);
489 }
490 /// Encode a single boolean value.
491 ///
492 /// - `val`: The value to encode (`false` or `true`).
493 /// - `f`: The probability that the `val` is `true`, scaled by `32768`.
494 fn bit(&mut self, bit: u16) {
495 self.bool(bit == 1, 16384);
496 }
497 // fake add bits
498 fn add_bits_frac(&mut self, bits_frac: u32) {
499 self.fake_bits_frac += bits_frac
500 }
501 /// Encode a literal bitstring, bit by bit in MSB order, with flat
502 /// probability.
503 ///
504 /// - 'bits': Length of bitstring
505 /// - 's': Bit string to encode
506 fn literal(&mut self, bits: u8, s: u32) {
507 for bit in (0..bits).rev() {
508 self.bit((1 & (s >> bit)) as u16);
509 }
510 }
511 /// Encodes a symbol given a cumulative distribution function (CDF) table in Q15.
512 ///
513 /// - `s`: The index of the symbol to encode.
514 /// - `cdf`: The CDF, such that symbol s falls in the range
515 /// `[s > 0 ? cdf[s - 1] : 0, cdf[s])`.
516 /// The values must be monotonically non-decreasing, and the last value
517 /// must be greater than 32704. There should be at most 16 values.
518 /// The lower 6 bits of the last value hold the count.
519 #[inline(always)]
520 fn symbol<const CDF_LEN: usize>(&mut self, s: u32, cdf: &[u16; CDF_LEN]) {
521 debug_assert!(cdf[cdf.len() - 1] < (1 << EC_PROB_SHIFT));
522 let s = s as usize;
523 debug_assert!(s < cdf.len());
524 // The above is stricter than the following overflow check: s <= cdf.len()
525 let nms = cdf.len() - s;
526 let fl = if s > 0 {
527 // SAFETY: We asserted that s is less than the length of the cdf
528 unsafe { *cdf.get_unchecked(s - 1) }
529 } else {
530 32768
531 };
532 // SAFETY: We asserted that s is less than the length of the cdf
533 let fh = unsafe { *cdf.get_unchecked(s) };
534 debug_assert!((fh >> EC_PROB_SHIFT) <= (fl >> EC_PROB_SHIFT));
535 debug_assert!(fl <= 32768);
536 self.store(fl, fh, nms as u16);
537 }
538 /// Encodes a symbol given a cumulative distribution function (CDF)
539 /// table in Q15, then updates the CDF probabilities to reflect we've
540 /// written one more symbol 's'.
541 ///
542 /// - `s`: The index of the symbol to encode.
543 /// - `cdf`: The CDF, such that symbol s falls in the range
544 /// `[s > 0 ? cdf[s - 1] : 0, cdf[s])`.
545 /// The values must be monotonically non-decreasing, and the last value
546 /// must be greater 32704. There should be at most 16 values.
547 /// The lower 6 bits of the last value hold the count.
548 fn symbol_with_update<const CDF_LEN: usize>(
549 &mut self, s: u32, cdf: CDFOffset<CDF_LEN>, log: &mut CDFContextLog,
550 fc: &mut CDFContext,
551 ) {
552 #[cfg(feature = "desync_finder")]
553 {
554 if self.debug {
555 self.print_backtrace(s);
556 }
557 }
558 let cdf = log.push(fc, cdf);
559 self.symbol(s, cdf);
560
561 update_cdf(cdf, s);
562 }
563 /// Returns approximate cost for a symbol given a cumulative
564 /// distribution function (CDF) table and current write state.
565 ///
566 /// - `s`: The index of the symbol to encode.
567 /// - `cdf`: The CDF, such that symbol s falls in the range
568 /// `[s > 0 ? cdf[s - 1] : 0, cdf[s])`.
569 /// The values must be monotonically non-decreasing, and the last value
570 /// must be greater than 32704. There should be at most 16 values.
571 /// The lower 6 bits of the last value hold the count.
572 fn symbol_bits(&self, s: u32, cdf: &[u16]) -> u32 {
573 let mut bits = 0;
574 debug_assert!(cdf[cdf.len() - 1] < (1 << EC_PROB_SHIFT));
575 debug_assert!(32768 <= self.rng);
576 let rng = (self.rng >> 8) as u32;
577 let fh = cdf[s as usize] as u32 >> EC_PROB_SHIFT;
578 let r: u32 = if s > 0 {
579 let fl = cdf[s as usize - 1] as u32 >> EC_PROB_SHIFT;
580 ((rng * fl) >> (7 - EC_PROB_SHIFT)) - ((rng * fh) >> (7 - EC_PROB_SHIFT))
581 + EC_MIN_PROB
582 } else {
583 let nms1 = cdf.len() as u32 - s - 1;
584 self.rng as u32
585 - ((rng * fh) >> (7 - EC_PROB_SHIFT))
586 - nms1 * EC_MIN_PROB
587 };
588
589 // The 9 here counteracts the offset of -9 baked into cnt. Don't include a termination bit.
590 let pre = Self::frac_compute((self.cnt + 9) as u32, self.rng as u32);
591 let d = r.leading_zeros() - 16;
592 let mut c = self.cnt;
593 let mut sh = c + (d as i16);
594 if sh >= 0 {
595 c += 16;
596 if sh >= 8 {
597 bits += 8;
598 c -= 8;
599 }
600 bits += 8;
601 sh = c + (d as i16) - 24;
602 }
603 // The 9 here counteracts the offset of -9 baked into cnt. Don't include a termination bit.
604 Self::frac_compute((bits + sh + 9) as u32, r << d) - pre
605 }
606 /// Encode a golomb to the bitstream.
607 ///
608 /// - 'level': passed in value to encode
609 fn write_golomb(&mut self, level: u32) {
610 let x = level + 1;
611 let length = 32 - x.leading_zeros();
612
613 for _ in 0..length - 1 {
614 self.bit(0);
615 }
616
617 for i in (0..length).rev() {
618 self.bit(((x >> i) & 0x01) as u16);
619 }
620 }
621 /// Write a value `v` in `[0, n-1]` quasi-uniformly
622 /// - `n`: size of interval
623 /// - `v`: value to encode
624 fn write_quniform(&mut self, n: u32, v: u32) {
625 if n > 1 {
626 let l = 32 - n.leading_zeros() as u8;
627 let m = (1 << l) - n;
628 if v < m {
629 self.literal(l - 1, v);
630 } else {
631 self.literal(l - 1, m + ((v - m) >> 1));
632 self.literal(1, (v - m) & 1);
633 }
634 }
635 }
636 /// Returns `QOD_BITRES` bits for a value `v` in `[0, n-1]` quasi-uniformly
637 /// - `n`: size of interval
638 /// - `v`: value to encode
639 fn count_quniform(&self, n: u32, v: u32) -> u32 {
640 let mut bits = 0;
641 if n > 1 {
642 let l = 32 - n.leading_zeros();
643 let m = (1 << l) - n;
644 bits += (l - 1) << OD_BITRES;
645 if v >= m {
646 bits += 1 << OD_BITRES;
647 }
648 }
649 bits
650 }
651 /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite subexponential
652 ///
653 /// - `n`: size of interval
654 /// - `k`: "parameter"
655 /// - `v`: value to encode
656 fn write_subexp(&mut self, n: u32, k: u8, v: u32) {
657 let mut i = 0;
658 let mut mk = 0;
659 loop {
660 let b = if i != 0 { k + i - 1 } else { k };
661 let a = 1 << b;
662 if n <= mk + 3 * a {
663 self.write_quniform(n - mk, v - mk);
664 break;
665 } else {
666 let t = v >= mk + a;
667 self.bool(t, 16384);
668 if t {
669 i += 1;
670 mk += a;
671 } else {
672 self.literal(b, v - mk);
673 break;
674 }
675 }
676 }
677 }
678 /// Returns `QOD_BITRES` bits for symbol `v` in `[0, n-1]` with parameter `k`
679 /// as finite subexponential
680 ///
681 /// - `n`: size of interval
682 /// - `k`: "parameter"
683 /// - `v`: value to encode
684 fn count_subexp(&self, n: u32, k: u8, v: u32) -> u32 {
685 let mut i = 0;
686 let mut mk = 0;
687 let mut bits = 0;
688 loop {
689 let b = if i != 0 { k + i - 1 } else { k };
690 let a = 1 << b;
691 if n <= mk + 3 * a {
692 bits += self.count_quniform(n - mk, v - mk);
693 break;
694 } else {
695 let t = v >= mk + a;
696 bits += 1 << OD_BITRES;
697 if t {
698 i += 1;
699 mk += a;
700 } else {
701 bits += (b as u32) << OD_BITRES;
702 break;
703 }
704 }
705 }
706 bits
707 }
708 /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite
709 /// subexponential based on a reference `r` also in `[0, n-1]`.
710 ///
711 /// - `v`: value to encode
712 /// - `n`: size of interval
713 /// - `k`: "parameter"
714 /// - `r`: reference
715 fn write_unsigned_subexp_with_ref(&mut self, v: u32, n: u32, k: u8, r: u32) {
716 if (r << 1) <= n {
717 self.write_subexp(n, k, Self::recenter(r, v));
718 } else {
719 self.write_subexp(n, k, Self::recenter(n - 1 - r, n - 1 - v));
720 }
721 }
722 /// Returns `QOD_BITRES` bits for symbol `v` in `[0, n-1]`
723 /// with parameter `k` as finite subexponential based on a
724 /// reference `r` also in `[0, n-1]`.
725 ///
726 /// - `v`: value to encode
727 /// - `n`: size of interval
728 /// - `k`: "parameter"
729 /// - `r`: reference
730 fn count_unsigned_subexp_with_ref(
731 &self, v: u32, n: u32, k: u8, r: u32,
732 ) -> u32 {
733 if (r << 1) <= n {
734 self.count_subexp(n, k, Self::recenter(r, v))
735 } else {
736 self.count_subexp(n, k, Self::recenter(n - 1 - r, n - 1 - v))
737 }
738 }
739 /// Write symbol `v` in `[-(n-1), n-1]` with parameter `k` as finite
740 /// subexponential based on a reference `r` also in `[-(n-1), n-1]`.
741 ///
742 /// - `v`: value to encode
743 /// - `n`: size of interval
744 /// - `k`: "parameter"
745 /// - `r`: reference
746 fn write_signed_subexp_with_ref(
747 &mut self, v: i32, low: i32, high: i32, k: u8, r: i32,
748 ) {
749 self.write_unsigned_subexp_with_ref(
750 (v - low) as u32,
751 (high - low) as u32,
752 k,
753 (r - low) as u32,
754 );
755 }
756 /// Returns `QOD_BITRES` bits for symbol `v` in `[-(n-1), n-1]`
757 /// with parameter `k` as finite subexponential based on a
758 /// reference `r` also in `[-(n-1), n-1]`.
759 ///
760 /// - `v`: value to encode
761 /// - `n`: size of interval
762 /// - `k`: "parameter"
763 /// - `r`: reference
764
765 fn count_signed_subexp_with_ref(
766 &self, v: i32, low: i32, high: i32, k: u8, r: i32,
767 ) -> u32 {
768 self.count_unsigned_subexp_with_ref(
769 (v - low) as u32,
770 (high - low) as u32,
771 k,
772 (r - low) as u32,
773 )
774 }
775 /// Returns the number of bits "used" by the encoded symbols so far.
776 /// This same number can be computed in either the encoder or the
777 /// decoder, and is suitable for making coding decisions. The value
778 /// will be the same whether using an `Encoder` or `Recorder`.
779 ///
780 /// Return: The integer number of bits.
781 /// This will always be slightly larger than the exact value (e.g., all
782 /// rounding error is in the positive direction).
783 fn tell(&mut self) -> u32 {
784 // The 10 here counteracts the offset of -9 baked into cnt, and adds 1 extra
785 // bit, which we reserve for terminating the stream.
786 (((self.stream_bits()) as i32) + (self.cnt as i32) + 10) as u32
787 + (self.fake_bits_frac >> 8)
788 }
789 /// Returns the number of bits "used" by the encoded symbols so far.
790 /// This same number can be computed in either the encoder or the
791 /// decoder, and is suitable for making coding decisions. The value
792 /// will be the same whether using an `Encoder` or `Recorder`.
793 ///
794 /// Return: The number of bits scaled by `2**OD_BITRES`.
795 /// This will always be slightly larger than the exact value (e.g., all
796 /// rounding error is in the positive direction).
797 fn tell_frac(&mut self) -> u32 {
798 Self::frac_compute(self.tell(), self.rng as u32) + self.fake_bits_frac
799 }
800 /// Save current point in coding/recording to a checkpoint that can
801 /// be restored later. A `WriterCheckpoint` can be generated for an
802 /// `Encoder` or `Recorder`, but can only be used to rollback the `Writer`
803 /// instance from which it was generated.
804 fn checkpoint(&mut self) -> WriterCheckpoint {
805 StorageBackend::checkpoint(self)
806 }
807 /// Roll back a given `Writer` to the state saved in the `WriterCheckpoint`
808 ///
809 /// - 'wc': Saved `Writer` state/posiiton to restore
810 fn rollback(&mut self, wc: &WriterCheckpoint) {
811 StorageBackend::rollback(self, wc)
812 }
813}
814
/// Writes finite subexponential codes as plain (not range-coded) bits.
pub trait BCodeWriter {
  /// Recenter a non-negative value `v` around reference `r`, so values
  /// close to `r` map to small results.
  fn recenter_nonneg(&mut self, r: u16, v: u16) -> u16;
  /// Recenter a value `v` in `[0, n-1]` around a reference `r` also in
  /// `[0, n-1]`, mirroring both when `r` lies in the upper half.
  fn recenter_finite_nonneg(&mut self, n: u16, r: u16, v: u16) -> u16;
  /// Write a value `v` in `[0, n-1]` quasi-uniformly.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_quniform(&mut self, n: u16, v: u16) -> Result<(), std::io::Error>;
  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite
  /// subexponential.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_subexpfin(
    &mut self, n: u16, k: u16, v: u16,
  ) -> Result<(), std::io::Error>;
  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite
  /// subexponential, recentered around reference `r` in `[0, n-1]`.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_refsubexpfin(
    &mut self, n: u16, k: u16, r: i16, v: i16,
  ) -> Result<(), std::io::Error>;
  /// Signed variant: `v` and `r` lie in `[-(n-1), n-1]`.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_s_refsubexpfin(
    &mut self, n: u16, k: u16, r: i16, v: i16,
  ) -> Result<(), std::io::Error>;
}
841
842impl<W: io::Write> BCodeWriter for BitWriter<W, BigEndian> {
843 fn recenter_nonneg(&mut self, r: u16, v: u16) -> u16 {
844 /* Recenters a non-negative literal v around a reference r */
845 if v > (r << 1) {
846 v
847 } else if v >= r {
848 (v - r) << 1
849 } else {
850 ((r - v) << 1) - 1
851 }
852 }
853 fn recenter_finite_nonneg(&mut self, n: u16, r: u16, v: u16) -> u16 {
854 /* Recenters a non-negative literal v in [0, n-1] around a
855 reference r also in [0, n-1] */
856 if (r << 1) <= n {
857 self.recenter_nonneg(r, v)
858 } else {
859 self.recenter_nonneg(n - 1 - r, n - 1 - v)
860 }
861 }
862 fn write_quniform(&mut self, n: u16, v: u16) -> Result<(), std::io::Error> {
863 if n > 1 {
864 let l = 16 - n.leading_zeros() as u8;
865 let m = (1 << l) - n;
866 if v < m {
867 self.write(l as u32 - 1, v)
868 } else {
869 self.write(l as u32 - 1, m + ((v - m) >> 1))?;
870 self.write(1, (v - m) & 1)
871 }
872 } else {
873 Ok(())
874 }
875 }
876 fn write_subexpfin(
877 &mut self, n: u16, k: u16, v: u16,
878 ) -> Result<(), std::io::Error> {
879 /* Finite subexponential code that codes a symbol v in [0, n-1] with parameter k */
880 let mut i = 0;
881 let mut mk = 0;
882 loop {
883 let b = if i > 0 { k + i - 1 } else { k };
884 let a = 1 << b;
885 if n <= mk + 3 * a {
886 return self.write_quniform(n - mk, v - mk);
887 } else {
888 let t = v >= mk + a;
889 self.write_bit(t)?;
890 if t {
891 i += 1;
892 mk += a;
893 } else {
894 return self.write(b as u32, v - mk);
895 }
896 }
897 }
898 }
899 fn write_refsubexpfin(
900 &mut self, n: u16, k: u16, r: i16, v: i16,
901 ) -> Result<(), std::io::Error> {
902 /* Finite subexponential code that codes a symbol v in [0, n-1] with
903 parameter k based on a reference ref also in [0, n-1].
904 Recenters symbol around r first and then uses a finite subexponential code. */
905 let recentered_v = self.recenter_finite_nonneg(n, r as u16, v as u16);
906 self.write_subexpfin(n, k, recentered_v)
907 }
908 fn write_s_refsubexpfin(
909 &mut self, n: u16, k: u16, r: i16, v: i16,
910 ) -> Result<(), std::io::Error> {
911 /* Signed version of the above function */
912 self.write_refsubexpfin(
913 (n << 1) - 1,
914 k,
915 r + (n - 1) as i16,
916 v + (n - 1) as i16,
917 )
918 }
919}
920
921pub(crate) fn cdf_to_pdf<const CDF_LEN: usize>(
922 cdf: &[u16; CDF_LEN],
923) -> [u16; CDF_LEN] {
924 let mut pdf: [u16; CDF_LEN] = [0; CDF_LEN];
925 let mut z: u16 = 32768u16 >> EC_PROB_SHIFT;
926 for (d: &mut u16, &a: u16) in pdf.iter_mut().zip(cdf.iter()) {
927 *d = z - (a >> EC_PROB_SHIFT);
928 z = a >> EC_PROB_SHIFT;
929 }
930 pdf
931}
932
933pub(crate) mod rust {
934 // Function to update the CDF for Writer calls that do so.
935 #[inline]
936 pub fn update_cdf<const N: usize>(cdf: &mut [u16; N], val: u32) {
937 use crate::context::CDF_LEN_MAX;
938 let nsymbs = cdf.len();
939 let mut rate = 3 + (nsymbs >> 1).min(2);
940 if let Some(count) = cdf.last_mut() {
941 rate += (*count >> 4) as usize;
942 *count += 1 - (*count >> 5);
943 } else {
944 return;
945 }
946 // Single loop (faster)
947 for (i, v) in
948 cdf[..nsymbs - 1].iter_mut().enumerate().take(CDF_LEN_MAX - 1)
949 {
950 if i as u32 >= val {
951 *v -= *v >> rate;
952 } else {
953 *v += (32768 - *v) >> rate;
954 }
955 }
956 }
957}
958
// Round-trip tests: symbols are encoded with `WriterEncoder` and decoded back
// with the local `Reader`, and the decoded values are checked against the input.
#[cfg(test)]
mod test {
  use super::*;

  // Width in bits of the decoder's `dif` window.
  const WINDOW_SIZE: i16 = 32;
  // Sentinel bit count stored in `cnt` once the input is exhausted, so that
  // `normalize` (which refills only when `cnt < 0`) stops asking for more data.
  const LOTS_OF_BITS: i16 = 0x4000;

  /// Minimal test-only range decoder used to check the encoder's output.
  #[derive(Debug)]
  struct Reader<'a> {
    // Encoded bytes being decoded.
    buf: &'a [u8],
    // Index of the next unread byte in `buf`.
    bptr: usize,
    // Decoder window; its top 16 bits are compared against split points.
    dif: ec_window,
    // Current range, kept with bit 15 set by `normalize`.
    rng: u16,
    // Bit budget tracking when `refill` must pull in more bytes.
    cnt: i16,
  }

  impl<'a> Reader<'a> {
    /// Creates a decoder over `buf` and pre-loads the first bytes.
    fn new(buf: &'a [u8]) -> Self {
      let mut r = Reader {
        buf,
        bptr: 0,
        dif: (1 << (WINDOW_SIZE - 1)) - 1,
        rng: 0x8000,
        cnt: -15,
      };
      r.refill();
      r
    }

    /// XORs further input bytes into `dif`, from the highest free byte
    /// position downward. Once the input runs out, `cnt` is set to
    /// `LOTS_OF_BITS` so that no further refills are attempted.
    fn refill(&mut self) {
      let mut s = WINDOW_SIZE - 9 - (self.cnt + 15);
      while s >= 0 && self.bptr < self.buf.len() {
        assert!(s <= WINDOW_SIZE - 8);
        self.dif ^= (self.buf[self.bptr] as ec_window) << s;
        self.cnt += 8;
        s -= 8;
        self.bptr += 1;
      }
      if self.bptr >= self.buf.len() {
        self.cnt = LOTS_OF_BITS;
      }
    }

    /// Stores the new `dif`/`rng` after shifting both left until bit 15 of
    /// `rng` is set. Triggers a refill when the bit budget goes negative.
    fn normalize(&mut self, dif: ec_window, rng: u32) {
      assert!(rng <= 65536);
      // Number of left shifts that bring the highest set bit of `rng` to bit 15.
      let d = rng.leading_zeros() - 16;
      self.cnt -= d as i16;
      /*This is equivalent to shifting in 1's instead of 0's.*/
      self.dif = ((dif + 1) << d) - 1;
      self.rng = (rng << d) as u16;
      if self.cnt < 0 {
        self.refill()
      }
    }

    /// Decodes one binary decision using probability parameter `f` (< 32768).
    /// Returns `true` when the coded value falls in the lower sub-interval of
    /// width `v`, and `false` when it falls in the remaining upper part.
    fn bool(&mut self, f: u32) -> bool {
      assert!(f < 32768);
      let r = self.rng as u32;
      assert!(self.dif >> (WINDOW_SIZE - 16) < r);
      assert!(32768 <= r);
      let v = (((r >> 8) * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT))
        + EC_MIN_PROB;
      let vw = v << (WINDOW_SIZE - 16);
      let (dif, rng, ret) = if self.dif >= vw {
        (self.dif - vw, r - v, false)
      } else {
        (self.dif, v, true)
      };
      self.normalize(dif, rng);
      ret
    }

    /// Decodes one symbol from the inverse-CDF table `icdf`.
    ///
    /// The split point `v` for each candidate index is recomputed, and the
    /// scan stops at the first index whose `v` is not above the top 16 bits
    /// of `dif` (`c`). `EC_MIN_PROB * (n - ret)` reserves a minimum width for
    /// each remaining symbol. The new range is `u - v`, the width between the
    /// previous and the chosen split point.
    fn symbol(&mut self, icdf: &[u16]) -> i32 {
      let r = self.rng as u32;
      assert!(self.dif >> (WINDOW_SIZE - 16) < r);
      assert!(32768 <= r);
      let n = icdf.len() as u32 - 1;
      let c = self.dif >> (WINDOW_SIZE - 16);
      let mut v = self.rng as u32;
      let mut ret = 0i32;
      let mut u = v;
      v = ((r >> 8) * (icdf[ret as usize] as u32 >> EC_PROB_SHIFT))
        >> (7 - EC_PROB_SHIFT);
      v += EC_MIN_PROB * (n - ret as u32);
      while c < v {
        u = v;
        ret += 1;
        v = ((r >> 8) * (icdf[ret as usize] as u32 >> EC_PROB_SHIFT))
          >> (7 - EC_PROB_SHIFT);
        v += EC_MIN_PROB * (n - ret as u32);
      }
      assert!(v < u);
      assert!(u <= r);
      let new_dif = self.dif - (v << (WINDOW_SIZE - 16));
      self.normalize(new_dif, u - v);
      ret
    }
  }

  // Encodes a fixed sequence of booleans and checks it decodes back unchanged.
  #[test]
  fn booleans() {
    let mut w = WriterEncoder::new();

    w.bool(false, 1);
    w.bool(true, 2);
    w.bool(false, 3);
    w.bool(true, 1);
    w.bool(true, 2);
    w.bool(false, 3);

    let b = w.done();

    let mut r = Reader::new(&b);

    assert!(!r.bool(1));
    assert!(r.bool(2));
    assert!(!r.bool(3));
    assert!(r.bool(1));
    assert!(r.bool(2));
    assert!(!r.bool(3));
  }

  // Encodes multi-symbol values with a fixed (non-adapting) table and checks
  // they decode back unchanged.
  #[test]
  fn cdf() {
    let cdf = [7296, 3819, 1716, 0];

    let mut w = WriterEncoder::new();

    w.symbol(0, &cdf);
    w.symbol(0, &cdf);
    w.symbol(0, &cdf);
    w.symbol(1, &cdf);
    w.symbol(1, &cdf);
    w.symbol(1, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);

    let b = w.done();

    let mut r = Reader::new(&b);

    assert_eq!(r.symbol(&cdf), 0);
    assert_eq!(r.symbol(&cdf), 0);
    assert_eq!(r.symbol(&cdf), 0);
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
  }

  // Interleaves multi-symbol and boolean writes to check that both paths
  // share coder state correctly.
  #[test]
  fn mixed() {
    let cdf = [7296, 3819, 1716, 0];

    let mut w = WriterEncoder::new();

    w.symbol(0, &cdf);
    w.bool(true, 2);
    w.symbol(0, &cdf);
    w.bool(true, 2);
    w.symbol(0, &cdf);
    w.bool(true, 2);
    w.symbol(1, &cdf);
    w.bool(true, 1);
    w.symbol(1, &cdf);
    w.bool(false, 2);
    w.symbol(1, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);

    let b = w.done();

    let mut r = Reader::new(&b);

    assert_eq!(r.symbol(&cdf), 0);
    assert!(r.bool(2));
    assert_eq!(r.symbol(&cdf), 0);
    assert!(r.bool(2));
    assert_eq!(r.symbol(&cdf), 0);
    assert!(r.bool(2));
    assert_eq!(r.symbol(&cdf), 1);
    assert!(r.bool(1));
    assert_eq!(r.symbol(&cdf), 1);
    assert!(!r.bool(2));
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
  }
}
1154