1/*
2A collection of helper functions, types and traits for serializing automata.
3
4This crate defines its own bespoke serialization mechanism for some structures
5provided in the public API, namely, DFAs. A bespoke mechanism was developed
6primarily because structures like automata demand a specific binary format.
7Attempting to encode their rich structure in an existing serialization
8format is just not feasible. Moreover, the format for each structure is
9generally designed such that deserialization is cheap. More specifically, that
10deserialization can be done in constant time. (The idea being that you can
11embed it into your binary or mmap it, and then use it immediately.)
12
In order to achieve this, most of the structures in this crate use an in-memory
representation that very closely corresponds to their binary serialized form.
15This pervades and complicates everything, and in some cases, requires dealing
16with alignment and reasoning about safety.
17
18This technique does have major advantages. In particular, it permits doing
19the potentially costly work of compiling a finite state machine in an offline
20manner, and then loading it at runtime not only without having to re-compile
21the regex, but even without the code required to do the compilation. This, for
22example, permits one to use a pre-compiled DFA not only in environments without
23Rust's standard library, but also in environments without a heap.
24
25In the code below, whenever we insert some kind of padding, it's to enforce a
264-byte alignment, unless otherwise noted. Namely, u32 is the only state ID type
27supported. (In a previous version of this library, DFAs were generic over the
28state ID representation.)
29
Also, serialization generally requires the caller to specify endianness,
whereas deserialization always assumes native endianness (otherwise cheap
32deserialization would be impossible). This implies that serializing a structure
33generally requires serializing both its big-endian and little-endian variants,
34and then loading the correct one based on the target's endianness.
35*/
36
37use core::{
38 cmp,
39 convert::{TryFrom, TryInto},
40 mem::size_of,
41};
42
43#[cfg(feature = "alloc")]
44use alloc::{vec, vec::Vec};
45
46use crate::util::id::{PatternID, PatternIDError, StateID, StateIDError};
47
/// The error reported when serializing an object from this crate fails.
///
/// In this crate, "serialization" always means turning a structure (such as
/// a DFA) into a custom binary format held in a `&[u8]`. That process is
/// infallible for the most part. The one way it can fail is when the buffer
/// supplied by the caller is too small, and that is what this error reports.
///
/// No introspection is offered on a `SerializeError`. The only thing that
/// can be done with one is to render it as a human readable message.
///
/// `std::error::Error` is implemented for this type only when the `std`
/// feature is enabled. The type itself exists in every configuration.
#[derive(Debug)]
pub struct SerializeError {
    /// A description of the item that did not fit in the destination buffer.
    ///
    /// A too-small destination buffer is, at present, the only kind of
    /// serialization error, and it is always the caller's doing. This is
    /// conceptually sound: every valid inhabitant of a type ought to be
    /// serializable.
    ///
    /// The public API leans on this. The `to_bytes_{big,little}_endian`
    /// routines hand back a `Vec<u8>` and promise to neither panic nor
    /// fail, which is only possible because the implementation always
    /// allocates a `Vec<u8>` that is large enough.
    ///
    /// Consequently, introducing any other kind of serialization error
    /// requires careful consideration.
    what: &'static str,
}

impl SerializeError {
    /// Creates the error for a destination buffer that is too small to hold
    /// `what`.
    pub(crate) fn buffer_too_small(what: &'static str) -> SerializeError {
        SerializeError { what: what }
    }
}

impl core::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.write_str("destination buffer is too small to write ")?;
        f.write_str(self.what)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SerializeError {}
96
97/// An error that occurs when deserializing an object defined in this crate.
98///
99/// Serialization, as used in this crate, universally refers to the process
100/// of transforming a structure (like a DFA) into a custom binary format
101/// represented by `&[u8]`. Deserialization, then, refers to the process of
102/// cheaply converting this binary format back to the object's in-memory
103/// representation as defined in this crate. To the extent possible,
104/// deserialization will report this error whenever this process fails.
105///
106/// A `DeserializeError` provides no introspection capabilities. Its only
107/// supported operation is conversion to a human readable error message.
108///
109/// This error type implements the `std::error::Error` trait only when the
110/// `std` feature is enabled. Otherwise, this type is defined in all
111/// configurations.
112#[derive(Debug)]
113pub struct DeserializeError(DeserializeErrorKind);
114
115#[derive(Debug)]
116enum DeserializeErrorKind {
117 Generic { msg: &'static str },
118 BufferTooSmall { what: &'static str },
119 InvalidUsize { what: &'static str },
120 InvalidVarint { what: &'static str },
121 VersionMismatch { expected: u32, found: u32 },
122 EndianMismatch { expected: u32, found: u32 },
123 AlignmentMismatch { alignment: usize, address: usize },
124 LabelMismatch { expected: &'static str },
125 ArithmeticOverflow { what: &'static str },
126 PatternID { err: PatternIDError, what: &'static str },
127 StateID { err: StateIDError, what: &'static str },
128}
129
130impl DeserializeError {
131 pub(crate) fn generic(msg: &'static str) -> DeserializeError {
132 DeserializeError(DeserializeErrorKind::Generic { msg })
133 }
134
135 pub(crate) fn buffer_too_small(what: &'static str) -> DeserializeError {
136 DeserializeError(DeserializeErrorKind::BufferTooSmall { what })
137 }
138
139 pub(crate) fn invalid_usize(what: &'static str) -> DeserializeError {
140 DeserializeError(DeserializeErrorKind::InvalidUsize { what })
141 }
142
143 fn invalid_varint(what: &'static str) -> DeserializeError {
144 DeserializeError(DeserializeErrorKind::InvalidVarint { what })
145 }
146
147 fn version_mismatch(expected: u32, found: u32) -> DeserializeError {
148 DeserializeError(DeserializeErrorKind::VersionMismatch {
149 expected,
150 found,
151 })
152 }
153
154 fn endian_mismatch(expected: u32, found: u32) -> DeserializeError {
155 DeserializeError(DeserializeErrorKind::EndianMismatch {
156 expected,
157 found,
158 })
159 }
160
161 fn alignment_mismatch(
162 alignment: usize,
163 address: usize,
164 ) -> DeserializeError {
165 DeserializeError(DeserializeErrorKind::AlignmentMismatch {
166 alignment,
167 address,
168 })
169 }
170
171 fn label_mismatch(expected: &'static str) -> DeserializeError {
172 DeserializeError(DeserializeErrorKind::LabelMismatch { expected })
173 }
174
175 fn arithmetic_overflow(what: &'static str) -> DeserializeError {
176 DeserializeError(DeserializeErrorKind::ArithmeticOverflow { what })
177 }
178
179 pub(crate) fn pattern_id_error(
180 err: PatternIDError,
181 what: &'static str,
182 ) -> DeserializeError {
183 DeserializeError(DeserializeErrorKind::PatternID { err, what })
184 }
185
186 pub(crate) fn state_id_error(
187 err: StateIDError,
188 what: &'static str,
189 ) -> DeserializeError {
190 DeserializeError(DeserializeErrorKind::StateID { err, what })
191 }
192}
193
194#[cfg(feature = "std")]
195impl std::error::Error for DeserializeError {}
196
197impl core::fmt::Display for DeserializeError {
198 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
199 use self::DeserializeErrorKind::*;
200
201 match self.0 {
202 Generic { msg } => write!(f, "{}", msg),
203 BufferTooSmall { what } => {
204 write!(f, "buffer is too small to read {}", what)
205 }
206 InvalidUsize { what } => {
207 write!(f, "{} is too big to fit in a usize", what)
208 }
209 InvalidVarint { what } => {
210 write!(f, "could not decode valid varint for {}", what)
211 }
212 VersionMismatch { expected, found } => write!(
213 f,
214 "unsupported version: \
215 expected version {} but found version {}",
216 expected, found,
217 ),
218 EndianMismatch { expected, found } => write!(
219 f,
220 "endianness mismatch: expected 0x{:X} but got 0x{:X}. \
221 (Are you trying to load an object serialized with a \
222 different endianness?)",
223 expected, found,
224 ),
225 AlignmentMismatch { alignment, address } => write!(
226 f,
227 "alignment mismatch: slice starts at address \
228 0x{:X}, which is not aligned to a {} byte boundary",
229 address, alignment,
230 ),
231 LabelMismatch { expected } => write!(
232 f,
233 "label mismatch: start of serialized object should \
234 contain a NUL terminated {:?} label, but a different \
235 label was found",
236 expected,
237 ),
238 ArithmeticOverflow { what } => {
239 write!(f, "arithmetic overflow for {}", what)
240 }
241 PatternID { ref err, what } => {
242 write!(f, "failed to read pattern ID for {}: {}", what, err)
243 }
244 StateID { ref err, what } => {
245 write!(f, "failed to read state ID for {}: {}", what, err)
246 }
247 }
248 }
249}
250
251/// Checks that the given slice has an alignment that matches `T`.
252///
253/// This is useful for checking that a slice has an appropriate alignment
254/// before casting it to a &[T]. Note though that alignment is not itself
255/// sufficient to perform the cast for any `T`.
256pub fn check_alignment<T>(slice: &[u8]) -> Result<(), DeserializeError> {
257 let alignment: usize = core::mem::align_of::<T>();
258 let address: usize = slice.as_ptr() as usize;
259 if address % alignment == 0 {
260 return Ok(());
261 }
262 Err(DeserializeError::alignment_mismatch(alignment, address))
263}
264
/// Counts the NUL padding bytes at the start of the given slice, looking at
/// no more than the first 7 bytes. The count may be zero.
///
/// A serialized object may, in theory, need to be preceded by NUL bytes so
/// that it begins at a correctly aligned address. Such padding bytes are
/// expected to sit immediately before the label.
///
/// The return value is the number of bytes read from the given slice.
pub fn skip_initial_padding(slice: &[u8]) -> usize {
    slice.iter().take(7).take_while(|&&byte| byte == 0).count()
}
281
282/// Allocate a byte buffer of the given size, along with some initial padding
283/// such that `buf[padding..]` has the same alignment as `T`, where the
284/// alignment of `T` must be at most `8`. In particular, callers should treat
285/// the first N bytes (second return value) as padding bytes that must not be
286/// overwritten. In all cases, the following identity holds:
287///
288/// ```ignore
289/// let (buf, padding) = alloc_aligned_buffer::<StateID>(SIZE);
290/// assert_eq!(SIZE, buf[padding..].len());
291/// ```
292///
293/// In practice, padding is often zero.
294///
295/// The requirement for `8` as a maximum here is somewhat arbitrary. In
296/// practice, we never need anything bigger in this crate, and so this function
297/// does some sanity asserts under the assumption of a max alignment of `8`.
298#[cfg(feature = "alloc")]
pub fn alloc_aligned_buffer<T>(size: usize) -> (Vec<u8>, usize) {
    // FIXME: This is a kludge because there's no easy way to allocate a
    // Vec<u8> with an alignment guaranteed to be greater than 1. We could
    // create a Vec<u32>, but this cannot be safely transmuted to a Vec<u8>
    // without concern, since reallocing or dropping the Vec<u8> is UB
    // (different alignment than the initial allocation). We could define a
    // wrapper type to manage this for us, but it seems like more machinery
    // than it's worth.
    //
    // The buffer is allocated exactly once, at its worst-case length, and is
    // only ever shrunk afterwards. Growing the buffer *after* inspecting its
    // address could make it reallocate to a different address, at which
    // point the padding computed from the old address would be meaningless.
    let align = core::mem::align_of::<T>();
    // At most `align - 1` padding bytes are ever needed to reach the next
    // multiple of `align`. (`align` is never zero.)
    let slack = align - 1;
    let mut buf = vec![0; size.checked_add(slack).unwrap()];
    let address = buf.as_ptr() as usize;
    // Distance from `address` to the next multiple of `align`, which is
    // zero when `address` is already aligned.
    let padding = (align - (address % align)) % align;
    assert!(padding <= 7, "padding of {} is bigger than 7", padding);
    // `padding <= slack`, so this only drops surplus trailing zeros and
    // never moves the allocation.
    buf.truncate(size + padding);
    assert_eq!(size + padding, buf.len());
    assert_eq!(
        0,
        buf[padding..].as_ptr() as usize % align,
        "expected end of initial padding to be aligned to {}",
        align,
    );
    (buf, padding)
}
331
332/// Reads a NUL terminated label starting at the beginning of the given slice.
333///
334/// If a NUL terminated label could not be found, then an error is returned.
335/// Similary, if a label is found but doesn't match the expected label, then
336/// an error is returned.
337///
338/// Upon success, the total number of bytes read (including padding bytes) is
339/// returned.
340pub fn read_label(
341 slice: &[u8],
342 expected_label: &'static str,
343) -> Result<usize, DeserializeError> {
344 // Set an upper bound on how many bytes we scan for a NUL. Since no label
345 // in this crate is longer than 256 bytes, if we can't find one within that
346 // range, then we have corrupted data.
347 let first_nul =
348 slice[..cmp::min(slice.len(), 256)].iter().position(|&b| b == 0);
349 let first_nul = match first_nul {
350 Some(first_nul) => first_nul,
351 None => {
352 return Err(DeserializeError::generic(
353 "could not find NUL terminated label \
354 at start of serialized object",
355 ));
356 }
357 };
358 let len = first_nul + padding_len(first_nul);
359 if slice.len() < len {
360 return Err(DeserializeError::generic(
361 "could not find properly sized label at start of serialized object"
362 ));
363 }
364 if expected_label.as_bytes() != &slice[..first_nul] {
365 return Err(DeserializeError::label_mismatch(expected_label));
366 }
367 Ok(len)
368}
369
370/// Writes the given label to the buffer as a NUL terminated string. The label
371/// given must not contain NUL, otherwise this will panic. Similarly, the label
372/// must not be longer than 255 bytes, otherwise this will panic.
373///
374/// Additional NUL bytes are written as necessary to ensure that the number of
375/// bytes written is always a multiple of 4.
376///
377/// Upon success, the total number of bytes written (including padding) is
378/// returned.
379pub fn write_label(
380 label: &str,
381 dst: &mut [u8],
382) -> Result<usize, SerializeError> {
383 let nwrite: usize = write_label_len(label);
384 if dst.len() < nwrite {
385 return Err(SerializeError::buffer_too_small(what:"label"));
386 }
387 dst[..label.len()].copy_from_slice(src:label.as_bytes());
388 for i: usize in 0..(nwrite - label.len()) {
389 dst[label.len() + i] = 0;
390 }
391 assert_eq!(nwrite % 4, 0);
392 Ok(nwrite)
393}
394
395/// Returns the total number of bytes (including padding) that would be written
396/// for the given label. This panics if the given label contains a NUL byte or
397/// is longer than 255 bytes. (The size restriction exists so that searching
398/// for a label during deserialization can be done in small bounded space.)
399pub fn write_label_len(label: &str) -> usize {
400 if label.len() > 255 {
401 panic!("label must not be longer than 255 bytes");
402 }
403 if label.as_bytes().iter().position(|&b: u8| b == 0).is_some() {
404 panic!("label must not contain NUL bytes");
405 }
406 let label_len: usize = label.len() + 1; // +1 for the NUL terminator
407 label_len + padding_len(non_padding_len:label_len)
408}
409
410/// Reads the endianness check from the beginning of the given slice and
411/// confirms that the endianness of the serialized object matches the expected
412/// endianness. If the slice is too small or if the endianness check fails,
413/// this returns an error.
414///
415/// Upon success, the total number of bytes read is returned.
416pub fn read_endianness_check(slice: &[u8]) -> Result<usize, DeserializeError> {
417 let (n: u32, nr: usize) = try_read_u32(slice, what:"endianness check")?;
418 assert_eq!(nr, write_endianness_check_len());
419 if n != 0xFEFF {
420 return Err(DeserializeError::endian_mismatch(expected:0xFEFF, found:n));
421 }
422 Ok(nr)
423}
424
425/// Writes 0xFEFF as an integer using the given endianness.
426///
427/// This is useful for writing into the header of a serialized object. It can
428/// be read during deserialization as a sanity check to ensure the proper
429/// endianness is used.
430///
431/// Upon success, the total number of bytes written is returned.
432pub fn write_endianness_check<E: Endian>(
433 dst: &mut [u8],
434) -> Result<usize, SerializeError> {
435 let nwrite: usize = write_endianness_check_len();
436 if dst.len() < nwrite {
437 return Err(SerializeError::buffer_too_small(what:"endianness check"));
438 }
439 E::write_u32(n:0xFEFF, dst);
440 Ok(nwrite)
441}
442
/// Returns how many bytes the endianness check occupies when written, which
/// is the size of a `u32`.
pub fn write_endianness_check_len() -> usize {
    const MARKER_LEN: usize = size_of::<u32>();
    MARKER_LEN
}
447
448/// Reads a version number from the beginning of the given slice and confirms
449/// that is matches the expected version number given. If the slice is too
450/// small or if the version numbers aren't equivalent, this returns an error.
451///
452/// Upon success, the total number of bytes read is returned.
453///
454/// N.B. Currently, we require that the version number is exactly equivalent.
455/// In the future, if we bump the version number without a semver bump, then
456/// we'll need to relax this a bit and support older versions.
457pub fn read_version(
458 slice: &[u8],
459 expected_version: u32,
460) -> Result<usize, DeserializeError> {
461 let (n: u32, nr: usize) = try_read_u32(slice, what:"version")?;
462 assert_eq!(nr, write_version_len());
463 if n != expected_version {
464 return Err(DeserializeError::version_mismatch(expected_version, found:n));
465 }
466 Ok(nr)
467}
468
469/// Writes the given version number to the beginning of the given slice.
470///
471/// This is useful for writing into the header of a serialized object. It can
472/// be read during deserialization as a sanity check to ensure that the library
473/// code supports the format of the serialized object.
474///
475/// Upon success, the total number of bytes written is returned.
476pub fn write_version<E: Endian>(
477 version: u32,
478 dst: &mut [u8],
479) -> Result<usize, SerializeError> {
480 let nwrite: usize = write_version_len();
481 if dst.len() < nwrite {
482 return Err(SerializeError::buffer_too_small(what:"version number"));
483 }
484 E::write_u32(n:version, dst);
485 Ok(nwrite)
486}
487
/// Returns how many bytes the version number occupies when written, which is
/// the size of a `u32`.
pub fn write_version_len() -> usize {
    const VERSION_LEN: usize = size_of::<u32>();
    VERSION_LEN
}
492
493/// Reads a pattern ID from the given slice. If the slice has insufficient
494/// length, then this panics. If the deserialized integer exceeds the pattern
495/// ID limit for the current target, then this returns an error.
496///
497/// Upon success, this also returns the number of bytes read.
498pub fn read_pattern_id(
499 slice: &[u8],
500 what: &'static str,
501) -> Result<(PatternID, usize), DeserializeError> {
502 let bytes: [u8; PatternID::SIZE] =
503 slice[..PatternID::SIZE].try_into().unwrap();
504 let pid: PatternID = PatternID::from_ne_bytes(bytes)
505 .map_err(|err: PatternIDError| DeserializeError::pattern_id_error(err, what))?;
506 Ok((pid, PatternID::SIZE))
507}
508
509/// Reads a pattern ID from the given slice. If the slice has insufficient
510/// length, then this panics. Otherwise, the deserialized integer is assumed
511/// to be a valid pattern ID.
512///
513/// This also returns the number of bytes read.
514pub fn read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize) {
515 let pid: PatternID = PatternID::from_ne_bytes_unchecked(
516 bytes:slice[..PatternID::SIZE].try_into().unwrap(),
517 );
518 (pid, PatternID::SIZE)
519}
520
521/// Write the given pattern ID to the beginning of the given slice of bytes
522/// using the specified endianness. The given slice must have length at least
523/// `PatternID::SIZE`, or else this panics. Upon success, the total number of
524/// bytes written is returned.
525pub fn write_pattern_id<E: Endian>(pid: PatternID, dst: &mut [u8]) -> usize {
526 E::write_u32(n:pid.as_u32(), dst);
527 PatternID::SIZE
528}
529
530/// Attempts to read a state ID from the given slice. If the slice has an
531/// insufficient number of bytes or if the state ID exceeds the limit for
532/// the current target, then this returns an error.
533///
534/// Upon success, this also returns the number of bytes read.
535pub fn try_read_state_id(
536 slice: &[u8],
537 what: &'static str,
538) -> Result<(StateID, usize), DeserializeError> {
539 if slice.len() < StateID::SIZE {
540 return Err(DeserializeError::buffer_too_small(what));
541 }
542 read_state_id(slice, what)
543}
544
545/// Reads a state ID from the given slice. If the slice has insufficient
546/// length, then this panics. If the deserialized integer exceeds the state ID
547/// limit for the current target, then this returns an error.
548///
549/// Upon success, this also returns the number of bytes read.
550pub fn read_state_id(
551 slice: &[u8],
552 what: &'static str,
553) -> Result<(StateID, usize), DeserializeError> {
554 let bytes: [u8; StateID::SIZE] =
555 slice[..StateID::SIZE].try_into().unwrap();
556 let sid: StateID = StateID::from_ne_bytes(bytes)
557 .map_err(|err: StateIDError| DeserializeError::state_id_error(err, what))?;
558 Ok((sid, StateID::SIZE))
559}
560
561/// Reads a state ID from the given slice. If the slice has insufficient
562/// length, then this panics. Otherwise, the deserialized integer is assumed
563/// to be a valid state ID.
564///
565/// This also returns the number of bytes read.
566pub fn read_state_id_unchecked(slice: &[u8]) -> (StateID, usize) {
567 let sid: StateID = StateID::from_ne_bytes_unchecked(
568 bytes:slice[..StateID::SIZE].try_into().unwrap(),
569 );
570 (sid, StateID::SIZE)
571}
572
573/// Write the given state ID to the beginning of the given slice of bytes
574/// using the specified endianness. The given slice must have length at least
575/// `StateID::SIZE`, or else this panics. Upon success, the total number of
576/// bytes written is returned.
577pub fn write_state_id<E: Endian>(sid: StateID, dst: &mut [u8]) -> usize {
578 E::write_u32(n:sid.as_u32(), dst);
579 StateID::SIZE
580}
581
582/// Try to read a u16 as a usize from the beginning of the given slice in
583/// native endian format. If the slice has fewer than 2 bytes or if the
584/// deserialized number cannot be represented by usize, then this returns an
585/// error. The error message will include the `what` description of what is
586/// being deserialized, for better error messages. `what` should be a noun in
587/// singular form.
588///
589/// Upon success, this also returns the number of bytes read.
590pub fn try_read_u16_as_usize(
591 slice: &[u8],
592 what: &'static str,
593) -> Result<(usize, usize), DeserializeError> {
594 try_read_u16(slice, what).and_then(|(n: u16, nr: usize)| {
595 usize::try_from(n)
596 .map(|n| (n, nr))
597 .map_err(|_| DeserializeError::invalid_usize(what))
598 })
599}
600
601/// Try to read a u32 as a usize from the beginning of the given slice in
602/// native endian format. If the slice has fewer than 4 bytes or if the
603/// deserialized number cannot be represented by usize, then this returns an
604/// error. The error message will include the `what` description of what is
605/// being deserialized, for better error messages. `what` should be a noun in
606/// singular form.
607///
608/// Upon success, this also returns the number of bytes read.
609pub fn try_read_u32_as_usize(
610 slice: &[u8],
611 what: &'static str,
612) -> Result<(usize, usize), DeserializeError> {
613 try_read_u32(slice, what).and_then(|(n: u32, nr: usize)| {
614 usize::try_from(n)
615 .map(|n| (n, nr))
616 .map_err(|_| DeserializeError::invalid_usize(what))
617 })
618}
619
620/// Try to read a u16 from the beginning of the given slice in native endian
621/// format. If the slice has fewer than 2 bytes, then this returns an error.
622/// The error message will include the `what` description of what is being
623/// deserialized, for better error messages. `what` should be a noun in
624/// singular form.
625///
626/// Upon success, this also returns the number of bytes read.
627pub fn try_read_u16(
628 slice: &[u8],
629 what: &'static str,
630) -> Result<(u16, usize), DeserializeError> {
631 if slice.len() < size_of::<u16>() {
632 return Err(DeserializeError::buffer_too_small(what));
633 }
634 Ok((read_u16(slice), size_of::<u16>()))
635}
636
637/// Try to read a u32 from the beginning of the given slice in native endian
638/// format. If the slice has fewer than 4 bytes, then this returns an error.
639/// The error message will include the `what` description of what is being
640/// deserialized, for better error messages. `what` should be a noun in
641/// singular form.
642///
643/// Upon success, this also returns the number of bytes read.
644pub fn try_read_u32(
645 slice: &[u8],
646 what: &'static str,
647) -> Result<(u32, usize), DeserializeError> {
648 if slice.len() < size_of::<u32>() {
649 return Err(DeserializeError::buffer_too_small(what));
650 }
651 Ok((read_u32(slice), size_of::<u32>()))
652}
653
/// Decodes a u16, in native endian format, from the first 2 bytes of the
/// given slice. This panics when the slice holds fewer than 2 bytes.
///
/// Marked as inline to speed up sparse searching which decodes integers from
/// its automaton at search time.
#[inline(always)]
pub fn read_u16(slice: &[u8]) -> u16 {
    let mut raw = [0u8; size_of::<u16>()];
    raw.copy_from_slice(&slice[..size_of::<u16>()]);
    u16::from_ne_bytes(raw)
}
664
/// Decodes a u32, in native endian format, from the first 4 bytes of the
/// given slice. This panics when the slice holds fewer than 4 bytes.
///
/// Marked as inline to speed up sparse searching which decodes integers from
/// its automaton at search time.
#[inline(always)]
pub fn read_u32(slice: &[u8]) -> u32 {
    let mut raw = [0u8; size_of::<u32>()];
    raw.copy_from_slice(&slice[..size_of::<u32>()]);
    u32::from_ne_bytes(raw)
}
675
/// Decodes a u64, in native endian format, from the first 8 bytes of the
/// given slice. This panics when the slice holds fewer than 8 bytes.
///
/// Marked as inline to speed up sparse searching which decodes integers from
/// its automaton at search time.
#[inline(always)]
pub fn read_u64(slice: &[u8]) -> u64 {
    let mut raw = [0u8; size_of::<u64>()];
    raw.copy_from_slice(&slice[..size_of::<u64>()]);
    u64::from_ne_bytes(raw)
}
686
687/// Write a variable sized integer and return the total number of bytes
688/// written. If the slice was not big enough to contain the bytes, then this
689/// returns an error including the "what" description in it. This does no
690/// padding.
691///
692/// See: https://developers.google.com/protocol-buffers/docs/encoding#varints
693#[allow(dead_code)]
694pub fn write_varu64(
695 mut n: u64,
696 what: &'static str,
697 dst: &mut [u8],
698) -> Result<usize, SerializeError> {
699 let mut i: usize = 0;
700 while n >= 0b1000_0000 {
701 if i >= dst.len() {
702 return Err(SerializeError::buffer_too_small(what));
703 }
704 dst[i] = (n as u8) | 0b1000_0000;
705 n >>= 7;
706 i += 1;
707 }
708 if i >= dst.len() {
709 return Err(SerializeError::buffer_too_small(what));
710 }
711 dst[i] = n as u8;
712 Ok(i + 1)
713}
714
/// Computes how many bytes are needed to encode `n` as a variable sized
/// integer: one byte for every 7 bits, and never fewer than one byte.
///
/// See: https://developers.google.com/protocol-buffers/docs/encoding#varints
#[allow(dead_code)]
pub fn write_varu64_len(mut n: u64) -> usize {
    let mut len: usize = 1;
    while n >> 7 != 0 {
        n >>= 7;
        len += 1;
    }
    len
}
728
729/// Like read_varu64, but attempts to cast the result to usize. If the integer
730/// cannot fit into a usize, then an error is returned.
731#[allow(dead_code)]
732pub fn read_varu64_as_usize(
733 slice: &[u8],
734 what: &'static str,
735) -> Result<(usize, usize), DeserializeError> {
736 let (n: u64, nread: usize) = read_varu64(slice, what)?;
737 let n: usize = usize::try_from(n)
738 .map_err(|_| DeserializeError::invalid_usize(what))?;
739 Ok((n, nread))
740}
741
742/// Reads a variable sized integer from the beginning of slice, and returns the
743/// integer along with the total number of bytes read. If a valid variable
744/// sized integer could not be found, then an error is returned that includes
745/// the "what" description in it.
746///
747/// https://developers.google.com/protocol-buffers/docs/encoding#varints
748#[allow(dead_code)]
749pub fn read_varu64(
750 slice: &[u8],
751 what: &'static str,
752) -> Result<(u64, usize), DeserializeError> {
753 let mut n: u64 = 0;
754 let mut shift: u32 = 0;
755 // The biggest possible value is u64::MAX, which needs all 64 bits which
756 // requires 10 bytes (because 7 * 9 < 64). We use a limit to avoid reading
757 // an unnecessary number of bytes.
758 let limit: usize = cmp::min(v1:slice.len(), v2:10);
759 for (i: usize, &b: u8) in slice[..limit].iter().enumerate() {
760 if b < 0b1000_0000 {
761 return match (b as u64).checked_shl(shift) {
762 None => Err(DeserializeError::invalid_varint(what)),
763 Some(b: u64) => Ok((n | b, i + 1)),
764 };
765 }
766 match ((b as u64) & 0b0111_1111).checked_shl(shift) {
767 None => return Err(DeserializeError::invalid_varint(what)),
768 Some(b: u64) => n |= b,
769 }
770 shift += 7;
771 }
772 Err(DeserializeError::invalid_varint(what))
773}
774
775/// Checks that the given slice has some minimal length. If it's smaller than
776/// the bound given, then a "buffer too small" error is returned with `what`
777/// describing what the buffer represents.
778pub fn check_slice_len<T>(
779 slice: &[T],
780 at_least_len: usize,
781 what: &'static str,
782) -> Result<(), DeserializeError> {
783 if slice.len() < at_least_len {
784 return Err(DeserializeError::buffer_too_small(what));
785 }
786 Ok(())
787}
788
789/// Multiply the given numbers, and on overflow, return an error that includes
790/// 'what' in the error message.
791///
792/// This is useful when doing arithmetic with untrusted data.
793pub fn mul(
794 a: usize,
795 b: usize,
796 what: &'static str,
797) -> Result<usize, DeserializeError> {
798 match a.checked_mul(b) {
799 Some(c: usize) => Ok(c),
800 None => Err(DeserializeError::arithmetic_overflow(what)),
801 }
802}
803
804/// Add the given numbers, and on overflow, return an error that includes
805/// 'what' in the error message.
806///
807/// This is useful when doing arithmetic with untrusted data.
808pub fn add(
809 a: usize,
810 b: usize,
811 what: &'static str,
812) -> Result<usize, DeserializeError> {
813 match a.checked_add(b) {
814 Some(c: usize) => Ok(c),
815 None => Err(DeserializeError::arithmetic_overflow(what)),
816 }
817}
818
819/// Shift `a` left by `b`, and on overflow, return an error that includes
820/// 'what' in the error message.
821///
822/// This is useful when doing arithmetic with untrusted data.
823pub fn shl(
824 a: usize,
825 b: usize,
826 what: &'static str,
827) -> Result<usize, DeserializeError> {
828 let amount: u32 = u32::try_from(b)
829 .map_err(|_| DeserializeError::arithmetic_overflow(what))?;
830 match a.checked_shl(amount) {
831 Some(c: usize) => Ok(c),
832 None => Err(DeserializeError::arithmetic_overflow(what)),
833 }
834}
835
/// A minimal abstraction for writing integers in a chosen byte order.
///
/// This resembles what byteorder offers, but only the very small subset
/// that is needed here is provided.
pub trait Endian {
    /// Stores `n` in the first 2 bytes of `dst` using this byte order.
    /// This panics when `dst` holds fewer than 2 bytes.
    fn write_u16(n: u16, dst: &mut [u8]);

    /// Stores `n` in the first 4 bytes of `dst` using this byte order.
    /// This panics when `dst` holds fewer than 4 bytes.
    fn write_u32(n: u32, dst: &mut [u8]);

    /// Stores `n` in the first 8 bytes of `dst` using this byte order.
    /// This panics when `dst` holds fewer than 8 bytes.
    fn write_u64(n: u64, dst: &mut [u8]);
}

/// Marker type for writing in little endian byte order.
pub enum LE {}
/// Marker type for writing in big endian byte order.
pub enum BE {}

/// The byte order of the compilation target.
#[cfg(target_endian = "little")]
pub type NE = LE;
/// The byte order of the compilation target.
#[cfg(target_endian = "big")]
pub type NE = BE;

impl Endian for LE {
    fn write_u16(n: u16, dst: &mut [u8]) {
        let bytes = n.to_le_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u32(n: u32, dst: &mut [u8]) {
        let bytes = n.to_le_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u64(n: u64, dst: &mut [u8]) {
        let bytes = n.to_le_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }
}

impl Endian for BE {
    fn write_u16(n: u16, dst: &mut [u8]) {
        let bytes = n.to_be_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u32(n: u32, dst: &mut [u8]) {
        let bytes = n.to_be_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u64(n: u64, dst: &mut [u8]) {
        let bytes = n.to_be_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }
}
894
/// Computes how many bytes must be appended to something of the given length
/// so that the total length becomes a multiple of 4. The result is always in
/// the range `0..=3`.
pub fn padding_len(non_padding_len: usize) -> usize {
    match non_padding_len % 4 {
        0 => 0,
        remainder => 4 - remainder,
    }
}
901
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    /// A label round trips, and both directions agree on its padded length.
    #[test]
    fn labels() {
        let mut scratch = [0; 1024];

        let written = write_label("fooba", &mut scratch).unwrap();
        assert_eq!(written, 8);
        assert_eq!(&scratch[..written], b"fooba\x00\x00\x00");

        let consumed = read_label(&scratch, "fooba").unwrap();
        assert_eq!(consumed, 8);
    }

    /// A label containing a NUL byte is rejected with a panic.
    #[test]
    #[should_panic]
    fn bad_label_interior_nul() {
        let mut scratch = [0; 1024];
        write_label("foo\x00bar", &mut scratch).unwrap();
    }

    /// A label of exactly 255 bytes is still accepted.
    #[test]
    fn bad_label_almost_too_long() {
        let label = "z".repeat(255);
        let mut scratch = [0; 1024];
        write_label(&label, &mut scratch).unwrap();
    }

    /// A label of 256 bytes is rejected with a panic.
    #[test]
    #[should_panic]
    fn bad_label_too_long() {
        let label = "z".repeat(256);
        let mut scratch = [0; 1024];
        write_label(&label, &mut scratch).unwrap();
    }

    /// Padding always rounds a length up to the next multiple of 4.
    #[test]
    fn padding() {
        let cases: [(usize, usize); 9] = [
            (8, 0),
            (9, 3),
            (10, 2),
            (11, 1),
            (12, 0),
            (13, 3),
            (14, 2),
            (15, 1),
            (16, 0),
        ];
        for &(len, expected) in cases.iter() {
            assert_eq!(expected, padding_len(len));
        }
    }
}
951