1/*!
2Types and routines that support the wire format of finite automata.
3
4Currently, this module just exports a few error types and some small helpers
5for deserializing [dense DFAs](crate::dfa::dense::DFA) using correct alignment.
6*/
7
8/*
9A collection of helper functions, types and traits for serializing automata.
10
11This crate defines its own bespoke serialization mechanism for some structures
12provided in the public API, namely, DFAs. A bespoke mechanism was developed
13primarily because structures like automata demand a specific binary format.
14Attempting to encode their rich structure in an existing serialization
15format is just not feasible. Moreover, the format for each structure is
16generally designed such that deserialization is cheap. More specifically, that
17deserialization can be done in constant time. (The idea being that you can
18embed it into your binary or mmap it, and then use it immediately.)
19
In order to achieve this, the dense and sparse DFAs in this crate use an
in-memory representation that very closely corresponds to their binary
serialized form. This pervades and complicates everything, and in some cases,
requires dealing with alignment and reasoning about safety.
24
25This technique does have major advantages. In particular, it permits doing
26the potentially costly work of compiling a finite state machine in an offline
27manner, and then loading it at runtime not only without having to re-compile
28the regex, but even without the code required to do the compilation. This, for
29example, permits one to use a pre-compiled DFA not only in environments without
30Rust's standard library, but also in environments without a heap.
31
In the code below, whenever we insert some kind of padding, it's to enforce a
4-byte alignment, unless otherwise noted. This is because u32 is the only state
ID type supported. (In a previous version of this library, DFAs were generic
over the state ID representation.)
36
Also, serialization generally requires the caller to specify endianness,
whereas deserialization always assumes native endianness (otherwise cheap
deserialization would be impossible). This implies that serializing a structure
generally requires serializing both its big-endian and little-endian variants,
and then loading the correct one based on the target's endianness.
42*/
43
44use core::{
45 cmp,
46 convert::{TryFrom, TryInto},
47 mem::size_of,
48};
49
50#[cfg(feature = "alloc")]
51use alloc::{vec, vec::Vec};
52
53use crate::util::{
54 int::Pointer,
55 primitives::{PatternID, PatternIDError, StateID, StateIDError},
56};
57
/// A hack to align a smaller type `B` with a bigger type `T`.
///
/// The usual use of this is with `B = [u8]` and `T = u32`. That is,
/// it permits aligning a sequence of bytes on a 4-byte boundary. This
/// is useful in contexts where one wants to embed a serialized [dense
/// DFA](crate::dfa::dense::DFA) into a Rust program while guaranteeing the
/// alignment required for the DFA.
///
/// This works because the struct is `repr(C)`. Its alignment is the maximum
/// alignment of its fields. The zero-length `[T; 0]` array contributes `T`'s
/// alignment while occupying no space. As a result, `bytes` is laid out at
/// offset `0`, which is an address aligned to `T`.
///
/// See [`dense::DFA::from_bytes`](crate::dfa::dense::DFA::from_bytes) for an
/// example of how to use this type.
#[repr(C)]
#[derive(Debug)]
pub struct AlignAs<B: ?Sized, T> {
    /// A zero-sized field indicating the alignment we want.
    pub _align: [T; 0],
    /// A possibly non-sized field containing a sequence of bytes.
    pub bytes: B,
}
76
/// The error returned when serializing an object from this crate fails.
///
/// In this crate, "serialization" always means turning a structure (such as
/// a DFA) into a bespoke binary format held in a `&[u8]`. Serialization is
/// therefore almost always infallible. The one way it can fail is when the
/// caller supplies a destination buffer that is too small, and this error
/// reports exactly that.
///
/// `SerializeError` offers no introspection. All you can do with it is turn
/// it into a human readable message.
///
/// This type exists in every configuration, but it only implements the
/// `std::error::Error` trait when the `std` feature is enabled.
#[derive(Debug)]
pub struct SerializeError {
    /// A description of the thing that did not fit in the caller's buffer.
    ///
    /// The only serialization failure today is a caller mistake: handing us
    /// a destination buffer that is too small for the serialized object.
    /// Conceptually this is the only failure that should exist, since every
    /// valid value of a type ought to be serializable.
    ///
    /// The public API partially relies on this. For example, the
    /// `to_bytes_{big,little}_endian` routines return a `Vec<u8>` and promise
    /// never to panic or error. They can only make that promise because they
    /// always allocate a big enough `Vec<u8>` themselves.
    ///
    /// Adding any new kind of serialization error therefore needs careful
    /// thought.
    what: &'static str,
}

impl SerializeError {
    /// Reports that the destination buffer is too small to hold `what`.
    pub(crate) fn buffer_too_small(what: &'static str) -> SerializeError {
        Self { what }
    }
}

impl core::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let SerializeError { what } = *self;
        write!(f, "destination buffer is too small to write {}", what)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SerializeError {}
125
/// An error that occurs when deserializing an object defined in this crate.
///
/// Serialization, as used in this crate, universally refers to the process
/// of transforming a structure (like a DFA) into a custom binary format
/// represented by `&[u8]`. Deserialization, then, refers to the process of
/// cheaply converting this binary format back to the object's in-memory
/// representation as defined in this crate. To the extent possible,
/// deserialization will report this error whenever this process fails.
///
/// A `DeserializeError` provides no introspection capabilities. Its only
/// supported operation is conversion to a human readable error message.
///
/// This type is defined in all configurations, but it only implements the
/// `std::error::Error` trait when the `std` feature is enabled.
#[derive(Debug)]
pub struct DeserializeError(DeserializeErrorKind);

/// The specific reason a deserialization attempt failed.
///
/// This enum is not `pub`, which is consistent with `DeserializeError`
/// exposing nothing beyond its `Display` message.
#[derive(Debug)]
enum DeserializeErrorKind {
    /// A fixed message that is displayed verbatim.
    Generic { msg: &'static str },
    /// The input ended before `what` could be read.
    BufferTooSmall { what: &'static str },
    /// The integer read for `what` does not fit in a `usize`.
    InvalidUsize { what: &'static str },
    /// The serialized version number differs from the supported one.
    VersionMismatch { expected: u32, found: u32 },
    /// The endianness check value differs from the expected value.
    EndianMismatch { expected: u32, found: u32 },
    /// The slice starting at `address` is not aligned to `alignment` bytes.
    AlignmentMismatch { alignment: usize, address: usize },
    /// The serialized label differs from the `expected` label.
    LabelMismatch { expected: &'static str },
    /// Arithmetic on untrusted sizes for `what` overflowed.
    ArithmeticOverflow { what: &'static str },
    /// Reading a pattern ID for `what` failed with `err`.
    PatternID { err: PatternIDError, what: &'static str },
    /// Reading a state ID for `what` failed with `err`.
    StateID { err: StateIDError, what: &'static str },
}
157
158impl DeserializeError {
159 pub(crate) fn generic(msg: &'static str) -> DeserializeError {
160 DeserializeError(DeserializeErrorKind::Generic { msg })
161 }
162
163 pub(crate) fn buffer_too_small(what: &'static str) -> DeserializeError {
164 DeserializeError(DeserializeErrorKind::BufferTooSmall { what })
165 }
166
167 fn invalid_usize(what: &'static str) -> DeserializeError {
168 DeserializeError(DeserializeErrorKind::InvalidUsize { what })
169 }
170
171 fn version_mismatch(expected: u32, found: u32) -> DeserializeError {
172 DeserializeError(DeserializeErrorKind::VersionMismatch {
173 expected,
174 found,
175 })
176 }
177
178 fn endian_mismatch(expected: u32, found: u32) -> DeserializeError {
179 DeserializeError(DeserializeErrorKind::EndianMismatch {
180 expected,
181 found,
182 })
183 }
184
185 fn alignment_mismatch(
186 alignment: usize,
187 address: usize,
188 ) -> DeserializeError {
189 DeserializeError(DeserializeErrorKind::AlignmentMismatch {
190 alignment,
191 address,
192 })
193 }
194
195 fn label_mismatch(expected: &'static str) -> DeserializeError {
196 DeserializeError(DeserializeErrorKind::LabelMismatch { expected })
197 }
198
199 fn arithmetic_overflow(what: &'static str) -> DeserializeError {
200 DeserializeError(DeserializeErrorKind::ArithmeticOverflow { what })
201 }
202
203 fn pattern_id_error(
204 err: PatternIDError,
205 what: &'static str,
206 ) -> DeserializeError {
207 DeserializeError(DeserializeErrorKind::PatternID { err, what })
208 }
209
210 pub(crate) fn state_id_error(
211 err: StateIDError,
212 what: &'static str,
213 ) -> DeserializeError {
214 DeserializeError(DeserializeErrorKind::StateID { err, what })
215 }
216}
217
218#[cfg(feature = "std")]
219impl std::error::Error for DeserializeError {}
220
221impl core::fmt::Display for DeserializeError {
222 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
223 use self::DeserializeErrorKind::*;
224
225 match self.0 {
226 Generic { msg } => write!(f, "{}", msg),
227 BufferTooSmall { what } => {
228 write!(f, "buffer is too small to read {}", what)
229 }
230 InvalidUsize { what } => {
231 write!(f, "{} is too big to fit in a usize", what)
232 }
233 VersionMismatch { expected, found } => write!(
234 f,
235 "unsupported version: \
236 expected version {} but found version {}",
237 expected, found,
238 ),
239 EndianMismatch { expected, found } => write!(
240 f,
241 "endianness mismatch: expected 0x{:X} but got 0x{:X}. \
242 (Are you trying to load an object serialized with a \
243 different endianness?)",
244 expected, found,
245 ),
246 AlignmentMismatch { alignment, address } => write!(
247 f,
248 "alignment mismatch: slice starts at address \
249 0x{:X}, which is not aligned to a {} byte boundary",
250 address, alignment,
251 ),
252 LabelMismatch { expected } => write!(
253 f,
254 "label mismatch: start of serialized object should \
255 contain a NUL terminated {:?} label, but a different \
256 label was found",
257 expected,
258 ),
259 ArithmeticOverflow { what } => {
260 write!(f, "arithmetic overflow for {}", what)
261 }
262 PatternID { ref err, what } => {
263 write!(f, "failed to read pattern ID for {}: {}", what, err)
264 }
265 StateID { ref err, what } => {
266 write!(f, "failed to read state ID for {}: {}", what, err)
267 }
268 }
269 }
270}
271
272/// Safely converts a `&[u32]` to `&[StateID]` with zero cost.
273#[cfg_attr(feature = "perf-inline", inline(always))]
274pub(crate) fn u32s_to_state_ids(slice: &[u32]) -> &[StateID] {
275 // SAFETY: This is safe because StateID is defined to have the same memory
276 // representation as a u32 (it is repr(transparent)). While not every u32
277 // is a "valid" StateID, callers are not permitted to rely on the validity
278 // of StateIDs for memory safety. It can only lead to logical errors. (This
279 // is why StateID::new_unchecked is safe.)
280 unsafe {
281 core::slice::from_raw_parts(
282 slice.as_ptr().cast::<StateID>(),
283 slice.len(),
284 )
285 }
286}
287
288/// Safely converts a `&mut [u32]` to `&mut [StateID]` with zero cost.
289pub(crate) fn u32s_to_state_ids_mut(slice: &mut [u32]) -> &mut [StateID] {
290 // SAFETY: This is safe because StateID is defined to have the same memory
291 // representation as a u32 (it is repr(transparent)). While not every u32
292 // is a "valid" StateID, callers are not permitted to rely on the validity
293 // of StateIDs for memory safety. It can only lead to logical errors. (This
294 // is why StateID::new_unchecked is safe.)
295 unsafe {
296 core::slice::from_raw_parts_mut(
297 slice.as_mut_ptr().cast::<StateID>(),
298 slice.len(),
299 )
300 }
301}
302
303/// Safely converts a `&[u32]` to `&[PatternID]` with zero cost.
304#[cfg_attr(feature = "perf-inline", inline(always))]
305pub(crate) fn u32s_to_pattern_ids(slice: &[u32]) -> &[PatternID] {
306 // SAFETY: This is safe because PatternID is defined to have the same
307 // memory representation as a u32 (it is repr(transparent)). While not
308 // every u32 is a "valid" PatternID, callers are not permitted to rely
309 // on the validity of PatternIDs for memory safety. It can only lead to
310 // logical errors. (This is why PatternID::new_unchecked is safe.)
311 unsafe {
312 core::slice::from_raw_parts(
313 slice.as_ptr().cast::<PatternID>(),
314 slice.len(),
315 )
316 }
317}
318
319/// Checks that the given slice has an alignment that matches `T`.
320///
321/// This is useful for checking that a slice has an appropriate alignment
322/// before casting it to a &[T]. Note though that alignment is not itself
323/// sufficient to perform the cast for any `T`.
324pub(crate) fn check_alignment<T>(
325 slice: &[u8],
326) -> Result<(), DeserializeError> {
327 let alignment = core::mem::align_of::<T>();
328 let address = slice.as_ptr().as_usize();
329 if address % alignment == 0 {
330 return Ok(());
331 }
332 Err(DeserializeError::alignment_mismatch(alignment, address))
333}
334
/// Counts the NUL bytes at the start of `slice`, stopping at the first
/// non-NUL byte and never counting more than 7.
///
/// A serialized object may be preceded by NUL padding so that the object
/// itself starts at a correctly aligned address. Any such padding comes
/// immediately before the label.
///
/// Returns the number of bytes consumed from `slice`.
pub(crate) fn skip_initial_padding(slice: &[u8]) -> usize {
    slice.iter().take(7).take_while(|&&byte| byte == 0).count()
}
351
/// Allocates a zeroed byte buffer of `size` usable bytes. It is preceded by
/// just enough leading padding that `buf[padding..]` is aligned for `T`.
///
/// The returned tuple is `(buf, padding)`. Callers must treat the first
/// `padding` bytes as reserved and never overwrite them. The following
/// always holds:
///
/// ```ignore
/// let (buf, padding) = alloc_aligned_buffer::<StateID>(SIZE);
/// assert_eq!(SIZE, buf[padding..].len());
/// ```
///
/// The padding is usually zero.
///
/// `T` must have an alignment of at most `8`. That limit is somewhat
/// arbitrary, since nothing in this crate needs more. The sanity assertions
/// below rely on it.
#[cfg(feature = "alloc")]
pub(crate) fn alloc_aligned_buffer<T>(size: usize) -> (Vec<u8>, usize) {
    // NOTE: This is a kludge. There is no easy way to allocate a Vec<u8>
    // with an alignment guaranteed to exceed 1. Allocating a Vec<u32> and
    // transmuting it to a Vec<u8> is not sound, because reallocating or
    // dropping that Vec<u8> would use a different alignment than the one the
    // memory was allocated with. A dedicated wrapper type could manage this,
    // but that seems like more machinery than it's worth.
    let align = core::mem::align_of::<T>();

    // Optimistic attempt: a plain allocation is often already aligned.
    let first = vec![0; size];
    if first.as_ptr().as_usize() % align == 0 {
        return (first, 0);
    }

    // Second attempt: make a brand new allocation with room for the largest
    // padding we could need. Growing `first` in place is not an option,
    // because a reallocation could land on an address with yet another
    // alignment. Note that `first` is still alive here, so this really is a
    // separate allocation.
    let extra = align - 1;
    let mut second = vec![0; size + extra];
    let address = second.as_ptr().as_usize();
    // The padding arithmetic below assumes `address` is misaligned, so
    // handle the lucky case where the new allocation happens to be aligned.
    if address % align == 0 {
        second.truncate(size);
        return (second, 0);
    }
    // Round `address` down to a multiple of `align`, step up to the next
    // multiple, and the distance from `address` is the padding.
    let rounded_down = address & !(align - 1);
    let padding = rounded_down
        .checked_add(align)
        .unwrap()
        .checked_sub(address)
        .unwrap();
    assert!(padding <= 7, "padding of {} is bigger than 7", padding);
    assert!(
        padding <= extra,
        "padding of {} is bigger than extra {} bytes",
        padding,
        extra
    );
    second.truncate(size + padding);
    assert_eq!(size + padding, second.len());
    assert_eq!(
        0,
        second[padding..].as_ptr().as_usize() % align,
        "expected end of initial padding to be aligned to {}",
        align,
    );
    (second, padding)
}
416
417/// Reads a NUL terminated label starting at the beginning of the given slice.
418///
419/// If a NUL terminated label could not be found, then an error is returned.
420/// Similarly, if a label is found but doesn't match the expected label, then
421/// an error is returned.
422///
423/// Upon success, the total number of bytes read (including padding bytes) is
424/// returned.
425pub(crate) fn read_label(
426 slice: &[u8],
427 expected_label: &'static str,
428) -> Result<usize, DeserializeError> {
429 // Set an upper bound on how many bytes we scan for a NUL. Since no label
430 // in this crate is longer than 256 bytes, if we can't find one within that
431 // range, then we have corrupted data.
432 let first_nul =
433 slice[..cmp::min(slice.len(), 256)].iter().position(|&b| b == 0);
434 let first_nul = match first_nul {
435 Some(first_nul) => first_nul,
436 None => {
437 return Err(DeserializeError::generic(
438 "could not find NUL terminated label \
439 at start of serialized object",
440 ));
441 }
442 };
443 let len = first_nul + padding_len(first_nul);
444 if slice.len() < len {
445 return Err(DeserializeError::generic(
446 "could not find properly sized label at start of serialized object"
447 ));
448 }
449 if expected_label.as_bytes() != &slice[..first_nul] {
450 return Err(DeserializeError::label_mismatch(expected_label));
451 }
452 Ok(len)
453}
454
455/// Writes the given label to the buffer as a NUL terminated string. The label
456/// given must not contain NUL, otherwise this will panic. Similarly, the label
457/// must not be longer than 255 bytes, otherwise this will panic.
458///
459/// Additional NUL bytes are written as necessary to ensure that the number of
460/// bytes written is always a multiple of 4.
461///
462/// Upon success, the total number of bytes written (including padding) is
463/// returned.
464pub(crate) fn write_label(
465 label: &str,
466 dst: &mut [u8],
467) -> Result<usize, SerializeError> {
468 let nwrite = write_label_len(label);
469 if dst.len() < nwrite {
470 return Err(SerializeError::buffer_too_small("label"));
471 }
472 dst[..label.len()].copy_from_slice(label.as_bytes());
473 for i in 0..(nwrite - label.len()) {
474 dst[label.len() + i] = 0;
475 }
476 assert_eq!(nwrite % 4, 0);
477 Ok(nwrite)
478}
479
480/// Returns the total number of bytes (including padding) that would be written
481/// for the given label. This panics if the given label contains a NUL byte or
482/// is longer than 255 bytes. (The size restriction exists so that searching
483/// for a label during deserialization can be done in small bounded space.)
484pub(crate) fn write_label_len(label: &str) -> usize {
485 if label.len() > 255 {
486 panic!("label must not be longer than 255 bytes");
487 }
488 if label.as_bytes().iter().position(|&b| b == 0).is_some() {
489 panic!("label must not contain NUL bytes");
490 }
491 let label_len = label.len() + 1; // +1 for the NUL terminator
492 label_len + padding_len(label_len)
493}
494
495/// Reads the endianness check from the beginning of the given slice and
496/// confirms that the endianness of the serialized object matches the expected
497/// endianness. If the slice is too small or if the endianness check fails,
498/// this returns an error.
499///
500/// Upon success, the total number of bytes read is returned.
501pub(crate) fn read_endianness_check(
502 slice: &[u8],
503) -> Result<usize, DeserializeError> {
504 let (n, nr) = try_read_u32(slice, "endianness check")?;
505 assert_eq!(nr, write_endianness_check_len());
506 if n != 0xFEFF {
507 return Err(DeserializeError::endian_mismatch(0xFEFF, n));
508 }
509 Ok(nr)
510}
511
512/// Writes 0xFEFF as an integer using the given endianness.
513///
514/// This is useful for writing into the header of a serialized object. It can
515/// be read during deserialization as a sanity check to ensure the proper
516/// endianness is used.
517///
518/// Upon success, the total number of bytes written is returned.
519pub(crate) fn write_endianness_check<E: Endian>(
520 dst: &mut [u8],
521) -> Result<usize, SerializeError> {
522 let nwrite = write_endianness_check_len();
523 if dst.len() < nwrite {
524 return Err(SerializeError::buffer_too_small("endianness check"));
525 }
526 E::write_u32(0xFEFF, dst);
527 Ok(nwrite)
528}
529
/// The size, in bytes, of the encoded endianness marker: one `u32`.
pub(crate) fn write_endianness_check_len() -> usize {
    size_of::<u32>()
}
534
535/// Reads a version number from the beginning of the given slice and confirms
536/// that is matches the expected version number given. If the slice is too
537/// small or if the version numbers aren't equivalent, this returns an error.
538///
539/// Upon success, the total number of bytes read is returned.
540///
541/// N.B. Currently, we require that the version number is exactly equivalent.
542/// In the future, if we bump the version number without a semver bump, then
543/// we'll need to relax this a bit and support older versions.
544pub(crate) fn read_version(
545 slice: &[u8],
546 expected_version: u32,
547) -> Result<usize, DeserializeError> {
548 let (n, nr) = try_read_u32(slice, "version")?;
549 assert_eq!(nr, write_version_len());
550 if n != expected_version {
551 return Err(DeserializeError::version_mismatch(expected_version, n));
552 }
553 Ok(nr)
554}
555
556/// Writes the given version number to the beginning of the given slice.
557///
558/// This is useful for writing into the header of a serialized object. It can
559/// be read during deserialization as a sanity check to ensure that the library
560/// code supports the format of the serialized object.
561///
562/// Upon success, the total number of bytes written is returned.
563pub(crate) fn write_version<E: Endian>(
564 version: u32,
565 dst: &mut [u8],
566) -> Result<usize, SerializeError> {
567 let nwrite = write_version_len();
568 if dst.len() < nwrite {
569 return Err(SerializeError::buffer_too_small("version number"));
570 }
571 E::write_u32(version, dst);
572 Ok(nwrite)
573}
574
/// The size, in bytes, of an encoded version number: one `u32`.
pub(crate) fn write_version_len() -> usize {
    size_of::<u32>()
}
579
580/// Reads a pattern ID from the given slice. If the slice has insufficient
581/// length, then this panics. If the deserialized integer exceeds the pattern
582/// ID limit for the current target, then this returns an error.
583///
584/// Upon success, this also returns the number of bytes read.
585pub(crate) fn read_pattern_id(
586 slice: &[u8],
587 what: &'static str,
588) -> Result<(PatternID, usize), DeserializeError> {
589 let bytes: [u8; PatternID::SIZE] =
590 slice[..PatternID::SIZE].try_into().unwrap();
591 let pid = PatternID::from_ne_bytes(bytes)
592 .map_err(|err| DeserializeError::pattern_id_error(err, what))?;
593 Ok((pid, PatternID::SIZE))
594}
595
596/// Reads a pattern ID from the given slice. If the slice has insufficient
597/// length, then this panics. Otherwise, the deserialized integer is assumed
598/// to be a valid pattern ID.
599///
600/// This also returns the number of bytes read.
601pub(crate) fn read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize) {
602 let pid = PatternID::from_ne_bytes_unchecked(
603 slice[..PatternID::SIZE].try_into().unwrap(),
604 );
605 (pid, PatternID::SIZE)
606}
607
608/// Write the given pattern ID to the beginning of the given slice of bytes
609/// using the specified endianness. The given slice must have length at least
610/// `PatternID::SIZE`, or else this panics. Upon success, the total number of
611/// bytes written is returned.
612pub(crate) fn write_pattern_id<E: Endian>(
613 pid: PatternID,
614 dst: &mut [u8],
615) -> usize {
616 E::write_u32(pid.as_u32(), dst);
617 PatternID::SIZE
618}
619
620/// Attempts to read a state ID from the given slice. If the slice has an
621/// insufficient number of bytes or if the state ID exceeds the limit for
622/// the current target, then this returns an error.
623///
624/// Upon success, this also returns the number of bytes read.
625pub(crate) fn try_read_state_id(
626 slice: &[u8],
627 what: &'static str,
628) -> Result<(StateID, usize), DeserializeError> {
629 if slice.len() < StateID::SIZE {
630 return Err(DeserializeError::buffer_too_small(what));
631 }
632 read_state_id(slice, what)
633}
634
635/// Reads a state ID from the given slice. If the slice has insufficient
636/// length, then this panics. If the deserialized integer exceeds the state ID
637/// limit for the current target, then this returns an error.
638///
639/// Upon success, this also returns the number of bytes read.
640pub(crate) fn read_state_id(
641 slice: &[u8],
642 what: &'static str,
643) -> Result<(StateID, usize), DeserializeError> {
644 let bytes: [u8; StateID::SIZE] =
645 slice[..StateID::SIZE].try_into().unwrap();
646 let sid = StateID::from_ne_bytes(bytes)
647 .map_err(|err| DeserializeError::state_id_error(err, what))?;
648 Ok((sid, StateID::SIZE))
649}
650
651/// Reads a state ID from the given slice. If the slice has insufficient
652/// length, then this panics. Otherwise, the deserialized integer is assumed
653/// to be a valid state ID.
654///
655/// This also returns the number of bytes read.
656pub(crate) fn read_state_id_unchecked(slice: &[u8]) -> (StateID, usize) {
657 let sid = StateID::from_ne_bytes_unchecked(
658 slice[..StateID::SIZE].try_into().unwrap(),
659 );
660 (sid, StateID::SIZE)
661}
662
663/// Write the given state ID to the beginning of the given slice of bytes
664/// using the specified endianness. The given slice must have length at least
665/// `StateID::SIZE`, or else this panics. Upon success, the total number of
666/// bytes written is returned.
667pub(crate) fn write_state_id<E: Endian>(
668 sid: StateID,
669 dst: &mut [u8],
670) -> usize {
671 E::write_u32(sid.as_u32(), dst);
672 StateID::SIZE
673}
674
675/// Try to read a u16 as a usize from the beginning of the given slice in
676/// native endian format. If the slice has fewer than 2 bytes or if the
677/// deserialized number cannot be represented by usize, then this returns an
678/// error. The error message will include the `what` description of what is
679/// being deserialized, for better error messages. `what` should be a noun in
680/// singular form.
681///
682/// Upon success, this also returns the number of bytes read.
683pub(crate) fn try_read_u16_as_usize(
684 slice: &[u8],
685 what: &'static str,
686) -> Result<(usize, usize), DeserializeError> {
687 try_read_u16(slice, what).and_then(|(n, nr)| {
688 usize::try_from(n)
689 .map(|n| (n, nr))
690 .map_err(|_| DeserializeError::invalid_usize(what))
691 })
692}
693
694/// Try to read a u32 as a usize from the beginning of the given slice in
695/// native endian format. If the slice has fewer than 4 bytes or if the
696/// deserialized number cannot be represented by usize, then this returns an
697/// error. The error message will include the `what` description of what is
698/// being deserialized, for better error messages. `what` should be a noun in
699/// singular form.
700///
701/// Upon success, this also returns the number of bytes read.
702pub(crate) fn try_read_u32_as_usize(
703 slice: &[u8],
704 what: &'static str,
705) -> Result<(usize, usize), DeserializeError> {
706 try_read_u32(slice, what).and_then(|(n, nr)| {
707 usize::try_from(n)
708 .map(|n| (n, nr))
709 .map_err(|_| DeserializeError::invalid_usize(what))
710 })
711}
712
713/// Try to read a u16 from the beginning of the given slice in native endian
714/// format. If the slice has fewer than 2 bytes, then this returns an error.
715/// The error message will include the `what` description of what is being
716/// deserialized, for better error messages. `what` should be a noun in
717/// singular form.
718///
719/// Upon success, this also returns the number of bytes read.
720pub(crate) fn try_read_u16(
721 slice: &[u8],
722 what: &'static str,
723) -> Result<(u16, usize), DeserializeError> {
724 check_slice_len(slice, size_of::<u16>(), what)?;
725 Ok((read_u16(slice), size_of::<u16>()))
726}
727
728/// Try to read a u32 from the beginning of the given slice in native endian
729/// format. If the slice has fewer than 4 bytes, then this returns an error.
730/// The error message will include the `what` description of what is being
731/// deserialized, for better error messages. `what` should be a noun in
732/// singular form.
733///
734/// Upon success, this also returns the number of bytes read.
735pub(crate) fn try_read_u32(
736 slice: &[u8],
737 what: &'static str,
738) -> Result<(u32, usize), DeserializeError> {
739 check_slice_len(slice, size_of::<u32>(), what)?;
740 Ok((read_u32(slice), size_of::<u32>()))
741}
742
743/// Try to read a u128 from the beginning of the given slice in native endian
744/// format. If the slice has fewer than 16 bytes, then this returns an error.
745/// The error message will include the `what` description of what is being
746/// deserialized, for better error messages. `what` should be a noun in
747/// singular form.
748///
749/// Upon success, this also returns the number of bytes read.
750pub(crate) fn try_read_u128(
751 slice: &[u8],
752 what: &'static str,
753) -> Result<(u128, usize), DeserializeError> {
754 check_slice_len(slice, size_of::<u128>(), what)?;
755 Ok((read_u128(slice), size_of::<u128>()))
756}
757
/// Decodes a native-endian `u16` from the first 2 bytes of `slice`. This
/// panics if `slice` is shorter than that.
///
/// Sparse DFA searching decodes integers from the automaton while it runs,
/// which is why this is force-inlined under the `perf-inline` feature.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn read_u16(slice: &[u8]) -> u16 {
    let (head, _) = slice.split_at(size_of::<u16>());
    u16::from_ne_bytes(head.try_into().unwrap())
}
768
/// Decodes a native-endian `u32` from the first 4 bytes of `slice`. This
/// panics if `slice` is shorter than that.
///
/// Sparse DFA searching decodes integers from the automaton while it runs,
/// which is why this is force-inlined under the `perf-inline` feature.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn read_u32(slice: &[u8]) -> u32 {
    let (head, _) = slice.split_at(size_of::<u32>());
    u32::from_ne_bytes(head.try_into().unwrap())
}
779
/// Decodes a native-endian `u128` from the first 16 bytes of `slice`. This
/// panics if `slice` is shorter than that.
pub(crate) fn read_u128(slice: &[u8]) -> u128 {
    let (head, _) = slice.split_at(size_of::<u128>());
    u128::from_ne_bytes(head.try_into().unwrap())
}
786
787/// Checks that the given slice has some minimal length. If it's smaller than
788/// the bound given, then a "buffer too small" error is returned with `what`
789/// describing what the buffer represents.
790pub(crate) fn check_slice_len<T>(
791 slice: &[T],
792 at_least_len: usize,
793 what: &'static str,
794) -> Result<(), DeserializeError> {
795 if slice.len() < at_least_len {
796 return Err(DeserializeError::buffer_too_small(what));
797 }
798 Ok(())
799}
800
801/// Multiply the given numbers, and on overflow, return an error that includes
802/// 'what' in the error message.
803///
804/// This is useful when doing arithmetic with untrusted data.
805pub(crate) fn mul(
806 a: usize,
807 b: usize,
808 what: &'static str,
809) -> Result<usize, DeserializeError> {
810 match a.checked_mul(b) {
811 Some(c) => Ok(c),
812 None => Err(DeserializeError::arithmetic_overflow(what)),
813 }
814}
815
816/// Add the given numbers, and on overflow, return an error that includes
817/// 'what' in the error message.
818///
819/// This is useful when doing arithmetic with untrusted data.
820pub(crate) fn add(
821 a: usize,
822 b: usize,
823 what: &'static str,
824) -> Result<usize, DeserializeError> {
825 match a.checked_add(b) {
826 Some(c) => Ok(c),
827 None => Err(DeserializeError::arithmetic_overflow(what)),
828 }
829}
830
831/// Shift `a` left by `b`, and on overflow, return an error that includes
832/// 'what' in the error message.
833///
834/// This is useful when doing arithmetic with untrusted data.
835pub(crate) fn shl(
836 a: usize,
837 b: usize,
838 what: &'static str,
839) -> Result<usize, DeserializeError> {
840 let amount = u32::try_from(b)
841 .map_err(|_| DeserializeError::arithmetic_overflow(what))?;
842 match a.checked_shl(amount) {
843 Some(c) => Ok(c),
844 None => Err(DeserializeError::arithmetic_overflow(what)),
845 }
846}
847
/// Computes how many padding bytes must follow `non_padding_len` bytes so
/// that the combined length is a multiple of 4. The result is always in the
/// range `0..4`.
pub(crate) fn padding_len(non_padding_len: usize) -> usize {
    // The two's-complement negation, modulo 4, is exactly the distance to
    // the next multiple of 4.
    non_padding_len.wrapping_neg() & 0b11
}
854
/// Writes fixed-width unsigned integers into byte buffers, using the byte
/// order selected by the implementing type.
///
/// The surface is deliberately tiny. It covers only what the serializers
/// here need, much like a very small slice of the `byteorder` crate.
pub(crate) trait Endian {
    /// Encodes `value` as 2 bytes in this byte order and stores them at the
    /// start of `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than 2 bytes.
    fn write_u16(value: u16, buf: &mut [u8]);

    /// Encodes `value` as 4 bytes in this byte order and stores them at the
    /// start of `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than 4 bytes.
    fn write_u32(value: u32, buf: &mut [u8]);

    /// Encodes `value` as 8 bytes in this byte order and stores them at the
    /// start of `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than 8 bytes.
    fn write_u64(value: u64, buf: &mut [u8]);

    /// Encodes `value` as 16 bytes in this byte order and stores them at the
    /// start of `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than 16 bytes.
    fn write_u128(value: u128, buf: &mut [u8]);
}
880
/// Marker selecting little-endian byte order when writing.
///
/// It is an empty enum, so no value of it can ever exist. It serves only as
/// a type to implement the `Endian` trait for.
pub(crate) enum LE {}
/// Marker selecting big-endian byte order when writing.
///
/// Like `LE`, it is uninhabited and used purely at the type level.
pub(crate) enum BE {}

/// The target's own byte order. This alias is `BE` on big-endian targets.
#[cfg(target_endian = "big")]
pub(crate) type NE = BE;
/// The target's own byte order. This alias is `LE` on little-endian targets.
#[cfg(target_endian = "little")]
pub(crate) type NE = LE;
890
891impl Endian for LE {
892 fn write_u16(n: u16, dst: &mut [u8]) {
893 dst[..2].copy_from_slice(&n.to_le_bytes());
894 }
895
896 fn write_u32(n: u32, dst: &mut [u8]) {
897 dst[..4].copy_from_slice(&n.to_le_bytes());
898 }
899
900 fn write_u64(n: u64, dst: &mut [u8]) {
901 dst[..8].copy_from_slice(&n.to_le_bytes());
902 }
903
904 fn write_u128(n: u128, dst: &mut [u8]) {
905 dst[..16].copy_from_slice(&n.to_le_bytes());
906 }
907}
908
909impl Endian for BE {
910 fn write_u16(n: u16, dst: &mut [u8]) {
911 dst[..2].copy_from_slice(&n.to_be_bytes());
912 }
913
914 fn write_u32(n: u32, dst: &mut [u8]) {
915 dst[..4].copy_from_slice(&n.to_be_bytes());
916 }
917
918 fn write_u64(n: u64, dst: &mut [u8]) {
919 dst[..8].copy_from_slice(&n.to_be_bytes());
920 }
921
922 fn write_u128(n: u128, dst: &mut [u8]) {
923 dst[..16].copy_from_slice(&n.to_be_bytes());
924 }
925}
926
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    #[test]
    fn labels() {
        let mut buf = [0; 1024];

        let nwrite = write_label("fooba", &mut buf).unwrap();
        assert_eq!(nwrite, 8);
        assert_eq!(&buf[..nwrite], b"fooba\x00\x00\x00");

        let nread = read_label(&buf, "fooba").unwrap();
        assert_eq!(nread, 8);
    }

    #[test]
    #[should_panic]
    fn bad_label_interior_nul() {
        // interior NULs are not allowed
        write_label("foo\x00bar", &mut [0; 1024]).unwrap();
    }

    #[test]
    fn bad_label_almost_too_long() {
        // ok
        write_label(&"z".repeat(255), &mut [0; 1024]).unwrap();
    }

    #[test]
    #[should_panic]
    fn bad_label_too_long() {
        // labels longer than 255 bytes are banned
        write_label(&"z".repeat(256), &mut [0; 1024]).unwrap();
    }

    #[test]
    fn padding() {
        assert_eq!(0, padding_len(8));
        assert_eq!(3, padding_len(9));
        assert_eq!(2, padding_len(10));
        assert_eq!(1, padding_len(11));
        assert_eq!(0, padding_len(12));
        assert_eq!(3, padding_len(13));
        assert_eq!(2, padding_len(14));
        assert_eq!(1, padding_len(15));
        assert_eq!(0, padding_len(16));
    }

    // Covers lengths below 8, including the empty case.
    #[test]
    fn padding_small() {
        assert_eq!(0, padding_len(0));
        assert_eq!(3, padding_len(1));
        assert_eq!(2, padding_len(2));
        assert_eq!(1, padding_len(3));
        assert_eq!(0, padding_len(4));
        assert_eq!(3, padding_len(5));
        assert_eq!(2, padding_len(6));
        assert_eq!(1, padding_len(7));
    }

    // The checked helpers must report overflow instead of wrapping. Results
    // go through `.ok()` so these assertions don't depend on the error type's
    // trait impls.
    #[test]
    fn checked_arithmetic() {
        assert_eq!(Some(6), mul(2, 3, "test").ok());
        assert!(mul(usize::MAX, 2, "test").is_err());
        assert_eq!(Some(5), add(2, 3, "test").ok());
        assert!(add(usize::MAX, 1, "test").is_err());
        assert_eq!(Some(8), shl(1, 3, "test").ok());
        assert!(shl(1, usize::BITS as usize, "test").is_err());
    }

    // `LE`/`BE` produce the expected byte layouts, and `NE` writes decode
    // back through the native-endian readers.
    #[test]
    fn endian_round_trip() {
        let mut buf = [0u8; 16];

        LE::write_u32(0x01020304, &mut buf);
        assert_eq!(&buf[..4], &[4u8, 3, 2, 1][..]);
        BE::write_u32(0x01020304, &mut buf);
        assert_eq!(&buf[..4], &[1u8, 2, 3, 4][..]);

        NE::write_u16(0x1234, &mut buf);
        assert_eq!(0x1234, read_u16(&buf));
        NE::write_u32(0xDEAD_BEEF, &mut buf);
        assert_eq!(0xDEAD_BEEF, read_u32(&buf));
        let big = 0x0102_0304_0506_0708_090A_0B0C_0D0E_0F10u128;
        NE::write_u128(big, &mut buf);
        assert_eq!(big, read_u128(&buf));
    }
}
976