1use std::{fmt, str};
2
3#[allow(unused)]
4pub(crate) fn write_escaped_str(mut dest: impl fmt::Write, src: &str) -> fmt::Result {
5 // This implementation reads one byte after another.
6 // It's not very fast, but should work well enough until portable SIMD gets stabilized.
7
8 let mut escaped_buf: [u8; 8] = ESCAPED_BUF_INIT;
9 let mut last: usize = 0;
10
11 for (index: usize, byte: u8) in src.bytes().enumerate() {
12 if let Some(escaped: [u8; 2]) = get_escaped(byte) {
13 [escaped_buf[2], escaped_buf[3]] = escaped;
14 write_str_if_nonempty(&mut dest, &src[last..index])?;
15 // SAFETY: the content of `escaped_buf` is pure ASCII
16 dest.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?;
17 last = index + 1;
18 }
19 }
20 write_str_if_nonempty(&mut dest, &src[last..])
21}
22
23#[allow(unused)]
24pub(crate) fn write_escaped_char(mut dest: impl fmt::Write, c: char) -> fmt::Result {
25 if !c.is_ascii() {
26 dest.write_char(c)
27 } else if let Some(escaped: [u8; 2]) = get_escaped(byte:c as u8) {
28 let mut escaped_buf: [u8; 8] = ESCAPED_BUF_INIT;
29 [escaped_buf[2], escaped_buf[3]] = escaped;
30 // SAFETY: the content of `escaped_buf` is pure ASCII
31 dest.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })
32 } else {
33 // RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()`
34 dest.write_char(c)
35 }
36}
37
38/// Returns the decimal representation of the codepoint if the character needs HTML escaping.
39#[inline(always)]
40fn get_escaped(byte: u8) -> Option<[u8; 2]> {
41 match byte {
42 MIN_CHAR..=MAX_CHAR => match TABLE.lookup[(byte - MIN_CHAR) as usize] {
43 0 => None,
44 escaped: u16 => Some(escaped.to_ne_bytes()),
45 },
46 _ => None,
47 }
48}
49
/// Forwards `input` to `output`, but skips the `write_str` call entirely when `input`
/// is empty.
#[inline(always)]
fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result {
    if input.is_empty() {
        return Ok(());
    }
    output.write_str(input)
}
58
/// Bytes that must be replaced by a character reference when writing HTML.
/// Their order does not matter: `MIN_CHAR`, `MAX_CHAR` and `TABLE` are derived by scanning them.
const CHARS: &[u8] = b"\"&'<>";
61
62/// The character with the lowest codepoint that needs HTML escaping.
63const MIN_CHAR: u8 = {
64 let mut v: u8 = u8::MAX;
65 let mut i: usize = 0;
66 while i < CHARS.len() {
67 if v > CHARS[i] {
68 v = CHARS[i];
69 }
70 i += 1;
71 }
72 v
73};
74
75/// The character with the highest codepoint that needs HTML escaping.
76const MAX_CHAR: u8 = {
77 let mut v: u8 = u8::MIN;
78 let mut i: usize = 0;
79 while i < CHARS.len() {
80 if v < CHARS[i] {
81 v = CHARS[i];
82 }
83 i += 1;
84 }
85 v
86};
87
88/// Number of codepoints between the lowest and highest character that needs escaping, incl.
89const CHAR_RANGE: usize = (MAX_CHAR - MIN_CHAR + 1) as usize;
90
/// Escape lookup table, indexed by `byte - MIN_CHAR`.
struct Table {
    /// Zero-sized field: it adds no bytes, but it raises the struct's alignment to that of
    /// `usize`. NOTE(review): presumably so that `lookup` starts word-aligned for cheaper
    /// loads — confirm.
    _align: [usize; 0],
    /// One entry per codepoint in `MIN_CHAR..=MAX_CHAR`. Each entry is either two ASCII
    /// decimal digits packed with `u16::from_ne_bytes`, or `0` if the byte needs no escaping.
    lookup: [u16; CHAR_RANGE],
}
95
96/// For characters that need HTML escaping, the codepoint is formatted as decimal digits,
97/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
98const TABLE: Table = {
99 let mut table: Table = Table {
100 _align: [],
101 lookup: [0; CHAR_RANGE],
102 };
103 let mut i: usize = 0;
104 while i < CHARS.len() {
105 let c: u8 = CHARS[i];
106 let h: u8 = c / 10 + b'0';
107 let l: u8 = c % 10 + b'0';
108 table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]);
109 i += 1;
110 }
111 table
112};
113
// RATIONALE: llvm generates better code if the buffer is register sized
/// Template for a decimal character reference. It holds `&#`, two `_` placeholders that
/// callers overwrite with digits, and `;`. Three zero bytes then pad it to eight bytes.
const ESCAPED_BUF_INIT: [u8; 8] = [b'&', b'#', b'_', b'_', b';', 0, 0, 0];
/// Number of leading bytes of [`ESCAPED_BUF_INIT`] that form the reference `&#__;`.
const ESCAPED_BUF_LEN: usize = "&#__;".len();
117
#[test]
fn test_simple_html_string_escaping() {
    let mut buf = String::new();
    write_escaped_str(&mut buf, "<script>").unwrap();
    assert_eq!(buf, "&#60;script&#62;");

    buf.clear();
    write_escaped_str(&mut buf, "s<crip>t").unwrap();
    assert_eq!(buf, "s&#60;crip&#62;t");

    buf.clear();
    write_escaped_str(&mut buf, "s<cripcripcripcripcripcripcripcripcripcrip>t").unwrap();
    assert_eq!(buf, "s&#60;cripcripcripcripcripcripcripcripcripcrip&#62;t");
}

/// Covers the cases the simple test above does not exercise:
/// - empty input and input with nothing to escape;
/// - every escaped character, including consecutive ones;
/// - non-ASCII text around escapes;
/// - the single-character writer.
#[test]
fn test_html_escaping_edge_cases() {
    let mut buf = String::new();
    write_escaped_str(&mut buf, "").unwrap();
    assert_eq!(buf, "");

    buf.clear();
    write_escaped_str(&mut buf, "plain text").unwrap();
    assert_eq!(buf, "plain text");

    buf.clear();
    write_escaped_str(&mut buf, "\"&'<>").unwrap();
    assert_eq!(buf, "&#34;&#38;&#39;&#60;&#62;");

    buf.clear();
    write_escaped_str(&mut buf, "ä<ö>ü").unwrap();
    assert_eq!(buf, "ä&#60;ö&#62;ü");

    buf.clear();
    for c in "a<&é\"'>".chars() {
        write_escaped_char(&mut buf, c).unwrap();
    }
    assert_eq!(buf, "a&#60;&#38;é&#34;&#39;&#62;");
}
132