1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19// THE SOFTWARE.
20
21//! Utility functions for HTML escaping. Only useful when building your own
22//! HTML renderer.
23
24use std::fmt::{Arguments, Write as FmtWrite};
25use std::io::{self, ErrorKind, Write};
26use std::str::from_utf8;
27
/// Builds `HREF_SAFE`: entry `b` is 1 if the ASCII byte `b` may be copied
/// into an href verbatim, and 0 if `escape_href` must escape it.
///
/// Safe bytes are the printable ASCII range `!`..=`~`, minus
/// `"` `&` `'` `<` `>` `[` `\` `]` `` ` `` `{` `|` `}`.
/// Control characters, space and DEL are unsafe.
const fn create_href_safe_table() -> [u8; 128] {
    let mut table = [0u8; 128];
    let mut byte = 0;
    while byte < 128 {
        let safe = match byte as u8 {
            b'"' | b'&' | b'\'' | b'<' | b'>' => false,
            b'[' | b'\\' | b']' | b'`' | b'{' | b'|' | b'}' => false,
            b'!'..=b'~' => true,
            _ => false,
        };
        table[byte] = safe as u8;
        byte += 1;
    }
    table
}

/// Lookup table for `escape_href`, indexed by ASCII byte value
/// (1 = copy verbatim, 0 = escape).
static HREF_SAFE: [u8; 128] = create_href_safe_table();
39
/// Uppercase hexadecimal digits used to percent-encode bytes in `escape_href`.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// HTML character reference written in place of `&` inside an href.
static AMP_ESCAPE: &str = "&amp;";
/// HTML character reference written in place of `'` inside an href, so the
/// output is also safe inside single-quoted attribute values.
static SINGLE_QUOTE_ESCAPE: &str = "&#x27;";
43
/// Adapter that lets any `std::io::Write` sink be used as a `StrWrite`.
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `Write` and one for types of the form `&mut W`
/// where `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `Write` types instead.
#[derive(Debug)]
pub struct WriteWrapper<W>(pub W);
50
/// A minimal string-oriented counterpart of `std::io::Write`.
///
/// Unlike `std::io::Write`, this trait can be implemented for `String`, so
/// the same escaping code can target both in-memory strings and I/O sinks.
pub trait StrWrite {
    /// Writes the whole string slice `s` to the sink.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the underlying sink fails.
    fn write_str(&mut self, s: &str) -> io::Result<()>;

    /// Writes formatted output (as produced by `format_args!`) to the sink.
    ///
    /// Having this method makes the `write!` macro usable with implementors.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the underlying sink or a formatting impl fails.
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()>;
}
58
59impl<W> StrWrite for WriteWrapper<W>
60where
61 W: Write,
62{
63 #[inline]
64 fn write_str(&mut self, s: &str) -> io::Result<()> {
65 self.0.write_all(buf:s.as_bytes())
66 }
67
68 #[inline]
69 fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
70 self.0.write_fmt(args)
71 }
72}
73
74impl<'w> StrWrite for String {
75 #[inline]
76 fn write_str(&mut self, s: &str) -> io::Result<()> {
77 self.push_str(string:s);
78 Ok(())
79 }
80
81 #[inline]
82 fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
83 // FIXME: translate fmt error to io error?
84 FmtWrite::write_fmt(self, args).map_err(|_| ErrorKind::Other.into())
85 }
86}
87
88impl<W> StrWrite for &'_ mut W
89where
90 W: StrWrite,
91{
92 #[inline]
93 fn write_str(&mut self, s: &str) -> io::Result<()> {
94 (**self).write_str(s)
95 }
96
97 #[inline]
98 fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
99 (**self).write_fmt(args)
100 }
101}
102
103/// Writes an href to the buffer, escaping href unsafe bytes.
104pub fn escape_href<W>(mut w: W, s: &str) -> io::Result<()>
105where
106 W: StrWrite,
107{
108 let bytes = s.as_bytes();
109 let mut mark = 0;
110 for i in 0..bytes.len() {
111 let c = bytes[i];
112 if c >= 0x80 || HREF_SAFE[c as usize] == 0 {
113 // character needing escape
114
115 // write partial substring up to mark
116 if mark < i {
117 w.write_str(&s[mark..i])?;
118 }
119 match c {
120 b'&' => {
121 w.write_str(AMP_ESCAPE)?;
122 }
123 b'\'' => {
124 w.write_str(SINGLE_QUOTE_ESCAPE)?;
125 }
126 _ => {
127 let mut buf = [0u8; 3];
128 buf[0] = b'%';
129 buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF];
130 buf[2] = HEX_CHARS[(c as usize) & 0xF];
131 let escaped = from_utf8(&buf).unwrap();
132 w.write_str(escaped)?;
133 }
134 }
135 mark = i + 1; // all escaped characters are ASCII
136 }
137 }
138 w.write_str(&s[mark..])
139}
140
/// Builds the table mapping each byte to its index in `HTML_ESCAPES`.
///
/// Zero means "copy the byte unchanged". `"`, `&`, `<` and `>` map to the
/// positions of `&quot;`, `&amp;`, `&lt;` and `&gt;` respectively.
const fn create_html_escape_table() -> [u8; 256] {
    let mut table = [0u8; 256];
    let mut byte = 0;
    while byte < 256 {
        table[byte] = match byte as u8 {
            b'"' => 1,
            b'&' => 2,
            b'<' => 3,
            b'>' => 4,
            _ => 0,
        };
        byte += 1;
    }
    table
}

/// Byte -> escape-index lookup table, evaluated at compile time.
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table();

/// Replacement strings, indexed by the values stored in `HTML_ESCAPE_TABLE`.
static HTML_ESCAPES: [&str; 5] = ["", "&quot;", "&amp;", "&lt;", "&gt;"];
153
/// Writes the given string to the Write sink, replacing special HTML bytes
/// (<, >, &, ") by escape sequences.
///
/// The replacements are `&lt;`, `&gt;`, `&amp;` and `&quot;` respectively.
/// All other bytes, including `'`, are written unchanged.
///
/// When compiled for `x86_64` with the `simd` feature enabled, this delegates
/// to `simd::escape_html`. Otherwise it uses the scalar implementation.
///
/// # Errors
///
/// Returns the first error reported by `w`.
pub fn escape_html<W: StrWrite>(w: W, s: &str) -> io::Result<()> {
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    {
        simd::escape_html(w, s)
    }
    #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
    {
        escape_html_scalar(w, s)
    }
}
166
167fn escape_html_scalar<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
168 let bytes: &[u8] = s.as_bytes();
169 let mut mark: usize = 0;
170 let mut i: usize = 0;
171 while i < s.len() {
172 match bytesIter<'_, u8>[i..]
173 .iter()
174 .position(|&c: u8| HTML_ESCAPE_TABLE[c as usize] != 0)
175 {
176 Some(pos: usize) => {
177 i += pos;
178 }
179 None => break,
180 }
181 let c: u8 = bytes[i];
182 let escape: u8 = HTML_ESCAPE_TABLE[c as usize];
183 let escape_seq: &str = HTML_ESCAPES[escape as usize];
184 w.write_str(&s[mark..i])?;
185 w.write_str(escape_seq)?;
186 i += 1;
187 mark = i; // all escaped characters are ASCII
188 }
189 w.write_str(&s[mark..])
190}
191
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
mod simd {
    use super::StrWrite;
    use std::arch::x86_64::*;
    use std::io;
    use std::mem::size_of;

    /// Number of bytes scanned per SIMD step: the width of one SSE register.
    const VECTOR_SIZE: usize = size_of::<__m128i>();

    /// SIMD-accelerated counterpart of `super::escape_html_scalar`.
    ///
    /// Writes `s` to `w`, replacing `"`, `&`, `<` and `>` with the entries of
    /// `super::HTML_ESCAPES`. Falls back to the scalar implementation when
    /// the CPU lacks SSSE3 (detected at runtime) or `s` is shorter than
    /// `VECTOR_SIZE` bytes.
    pub(super) fn escape_html<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
        // The SIMD accelerated code uses the PSHUFB instruction, which is part
        // of the SSSE3 instruction set. Further, we can only use this code if
        // the buffer is at least one VECTOR_SIZE in length to prevent reading
        // out of bounds. If either of these conditions is not met, we fall back
        // to scalar code.
        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
            let bytes = s.as_bytes();
            let mut mark = 0;

            // SAFETY: SSSE3 availability was checked at runtime and
            // `bytes.len() >= VECTOR_SIZE`, as `foreach_special_simd` requires.
            // The callback receives strictly increasing indices `i < bytes.len()`
            // of ASCII bytes. So `mark <= i`, and both `mark..i` and `mark..`
            // are in bounds and lie on char boundaries.
            unsafe {
                foreach_special_simd(bytes, 0, |i| {
                    let byte = *bytes.get_unchecked(i) as usize;
                    let replacement =
                        super::HTML_ESCAPES[super::HTML_ESCAPE_TABLE[byte] as usize];
                    w.write_str(s.get_unchecked(mark..i))?;
                    mark = i + 1; // all escaped characters are ASCII
                    w.write_str(replacement)
                })?;
                w.write_str(s.get_unchecked(mark..))
            }
        } else {
            super::escape_html_scalar(w, s)
        }
    }

    /// Creates the lookup table for use in `compute_mask`.
    ///
    /// Entry `n` holds the HTML special byte whose lower nibble is `n`, or 0
    /// if there is none. Entry 0 is set to 127 instead of 0, so that a NUL
    /// byte in the input (which selects entry 0) never compares equal to its
    /// looked-up value.
    const fn create_lookup() -> [u8; 16] {
        let mut table = [0; 16];
        table[(b'<' & 0x0f) as usize] = b'<';
        table[(b'>' & 0x0f) as usize] = b'>';
        table[(b'&' & 0x0f) as usize] = b'&';
        table[(b'"' & 0x0f) as usize] = b'"';
        table[0] = 0b0111_1111;
        table
    }

    /// Computes a byte mask at the given offset in the byte buffer. Its first
    /// 16 (least significant) bits say whether each of the 16 bytes starting
    /// at `bytes[offset]` is an HTML special byte (&, <, ", >). For example,
    /// the mask `(1 << 3)` states that there is an HTML byte at `offset + 3`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `bytes.len() >= offset + VECTOR_SIZE` and
    /// that the CPU supports SSSE3.
    #[target_feature(enable = "ssse3")]
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);

        let table = create_lookup();
        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
        let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;

        // Load the vector from memory.
        let vector = _mm_loadu_si128(raw_ptr);
        // For every input byte, PSHUFB selects `lookup[byte & 0x0f]`. The
        // exception is bytes with their most significant bit set, which
        // produce 0. Bytes that share their lower nibble with an HTML special
        // byte get mapped to that special byte; all HTML special bytes have
        // distinct lower nibbles. Bytes with a zero lower nibble get mapped
        // to 127, and all remaining bytes get mapped to 0.
        let expected = _mm_shuffle_epi8(lookup, vector);
        // Compare the original vector to the mapped output:
        // - Bytes that shared a lower nibble with an HTML special byte match
        //   *only* if they are that special byte.
        // - Bytes with their most significant bit set were mapped to 0 and
        //   are themselves non-zero, so they never match.
        // - Bytes with a zero lower nibble (including NUL) were mapped to 127,
        //   whose lower nibble is not zero, so they never match either.
        // - All other bytes are non-zero and were mapped to 0, so they also
        //   do not match.
        let matches = _mm_cmpeq_epi8(expected, vector);

        // Translate matches to a bitmask, where every 1 corresponds to a HTML
        // special character and a 0 is a non-HTML byte.
        _mm_movemask_epi8(matches)
    }

    /// Calls `callback` with the index of every byte in `bytes[offset..]` that
    /// is ", &, <, or >, and for no other byte, in increasing index order.
    /// Stops at and returns the first error produced by `callback`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the CPU supports SSSE3, that
    /// `bytes.len() >= VECTOR_SIZE` and that `offset <= bytes.len()`.
    /// Undefined behaviour may occur otherwise.
    #[target_feature(enable = "ssse3")]
    unsafe fn foreach_special_simd<F>(
        bytes: &[u8],
        mut offset: usize,
        mut callback: F,
    ) -> io::Result<()>
    where
        F: FnMut(usize) -> io::Result<()>,
    {
        // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE
        // (16) bytes at a time, starting at the given offset. For each chunk we
        // compute a bitmask indicating whether the corresponding byte is a HTML
        // special byte. We then iterate over all the 1 bits in this mask and
        // call the callback function with the corresponding index in the buffer.
        // When the number of HTML special bytes in the buffer is relatively low,
        // this lets us go through the buffer quickly without a table lookup for
        // every single byte.

        debug_assert!(bytes.len() >= VECTOR_SIZE);
        let upperbound = bytes.len() - VECTOR_SIZE;
        while offset < upperbound {
            let mut mask = compute_mask(bytes, offset);
            while mask != 0 {
                let ix = mask.trailing_zeros();
                callback(offset + ix as usize)?;
                // Clear the lowest set bit.
                mask ^= mask & -mask;
            }
            offset += VECTOR_SIZE;
        }

        // Final iteration. We align the read with the end of the slice and
        // shift off the bytes at the start that we have already scanned.
        let mut mask = compute_mask(bytes, upperbound);
        mask >>= offset - upperbound;
        while mask != 0 {
            let ix = mask.trailing_zeros();
            callback(offset + ix as usize)?;
            mask ^= mask & -mask;
        }
        Ok(())
    }

    #[cfg(test)]
    mod html_scan_tests {
        #[test]
        fn multichunk() {
            let mut vec = Vec::new();
            unsafe {
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
                    vec.push(ix);
                    Ok(())
                })
                .unwrap();
            }
            assert_eq!(vec, vec![0, 14, 15, 19]);
        }

        // Only match these bytes, and when we match them, match them
        // VECTOR_SIZE times.
        #[test]
        fn only_right_bytes_matched() {
            for b in 0..=255u8 {
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"';
                let vek = vec![b; super::VECTOR_SIZE];
                let mut match_count = 0;
                unsafe {
                    super::foreach_special_simd(&vek, 0, |_| {
                        match_count += 1;
                        Ok(())
                    })
                    .unwrap();
                }
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
                assert_eq!(
                    (match_count == super::VECTOR_SIZE),
                    right_byte,
                    "match_count: {}, byte: {:?}",
                    match_count,
                    b as char
                );
            }
        }
    }
}
358
#[cfg(test)]
mod test {
    use super::{escape_href, escape_html};

    #[test]
    fn check_href_escape() {
        let mut s = String::new();
        escape_href(&mut s, "&^_").unwrap();
        assert_eq!(s.as_str(), "&amp;^_");
    }

    #[test]
    fn check_href_percent_encoding() {
        // Space and `<` are ASCII-unsafe; `\u{e9}` is two UTF-8 bytes, each
        // encoded separately.
        let mut s = String::new();
        escape_href(&mut s, "a b<\u{e9}").unwrap();
        assert_eq!(s, "a%20b%3C%C3%A9");
    }

    #[test]
    fn check_html_escape() {
        let mut s = String::new();
        escape_html(&mut s, "<a href=\"x\">&</a>").unwrap();
        assert_eq!(s, "&lt;a href=&quot;x&quot;&gt;&amp;&lt;/a&gt;");
    }

    #[test]
    fn check_html_escape_edges() {
        let mut empty = String::new();
        escape_html(&mut empty, "").unwrap();
        assert_eq!(empty, "");

        // Longer than one SIMD vector, with the special byte at the very end.
        let mut tail = String::new();
        escape_html(&mut tail, "aaaaaaaaaaaaaaaaaaa&").unwrap();
        assert_eq!(tail, "aaaaaaaaaaaaaaaaaaa&amp;");
    }
}
370