1 | // Copyright 2015 Google Inc. All rights reserved. |
2 | // |
3 | // Permission is hereby granted, free of charge, to any person obtaining a copy |
4 | // of this software and associated documentation files (the "Software"), to deal |
5 | // in the Software without restriction, including without limitation the rights |
6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
7 | // copies of the Software, and to permit persons to whom the Software is |
8 | // furnished to do so, subject to the following conditions: |
9 | // |
10 | // The above copyright notice and this permission notice shall be included in |
11 | // all copies or substantial portions of the Software. |
12 | // |
13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
19 | // THE SOFTWARE. |
20 | |
21 | //! Utility functions for HTML escaping. Only useful when building your own |
22 | //! HTML renderer. |
23 | |
24 | use std::fmt::{Arguments, Write as FmtWrite}; |
25 | use std::io::{self, ErrorKind, Write}; |
26 | use std::str::from_utf8; |
27 | |
// Lookup table for href escaping: 1 means the ASCII byte may appear in an
// href verbatim, 0 means it must be escaped. Bytes >= 0x80 are always
// escaped (checked separately in `escape_href`). Note that `^` (0x5E), the
// brackets, backslash, backtick and braces are all unsafe and get
// percent-encoded, while `&` and `'` (also 0 here) get entity escapes.
#[rustfmt::skip]
static HREF_SAFE: [u8; 128] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0,
];
39 | |
// Hex digits used when percent-encoding href-unsafe bytes.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
// Entity replacements written by `escape_href` for `&` and `'`. These must
// be the HTML entities themselves — writing the raw characters would leave
// the output unescaped inside an HTML attribute.
static AMP_ESCAPE: &str = "&amp;";
static SINGLE_QUOTE_ESCAPE: &str = "&#x27;";
43 | |
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `Write` and types of the form `&mut W` where
/// `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `Write` types.
#[derive (Debug)]
pub struct WriteWrapper<W>(pub W);
50 | |
/// Trait that allows writing string slices. This is basically an extension
/// of `std::io::Write` in order to include `String`.
pub trait StrWrite {
    /// Writes the entire string slice to the sink.
    fn write_str(&mut self, s: &str) -> io::Result<()>;

    /// Writes formatted output (as produced by `format_args!`) to the sink.
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()>;
}
58 | |
59 | impl<W> StrWrite for WriteWrapper<W> |
60 | where |
61 | W: Write, |
62 | { |
63 | #[inline ] |
64 | fn write_str(&mut self, s: &str) -> io::Result<()> { |
65 | self.0.write_all(buf:s.as_bytes()) |
66 | } |
67 | |
68 | #[inline ] |
69 | fn write_fmt(&mut self, args: Arguments) -> io::Result<()> { |
70 | self.0.write_fmt(args) |
71 | } |
72 | } |
73 | |
74 | impl<'w> StrWrite for String { |
75 | #[inline ] |
76 | fn write_str(&mut self, s: &str) -> io::Result<()> { |
77 | self.push_str(string:s); |
78 | Ok(()) |
79 | } |
80 | |
81 | #[inline ] |
82 | fn write_fmt(&mut self, args: Arguments) -> io::Result<()> { |
83 | // FIXME: translate fmt error to io error? |
84 | FmtWrite::write_fmt(self, args).map_err(|_| ErrorKind::Other.into()) |
85 | } |
86 | } |
87 | |
/// Forwarding impl so that `&mut W` can be passed wherever `W: StrWrite`
/// is expected; both methods simply delegate to the pointee.
impl<W> StrWrite for &'_ mut W
where
    W: StrWrite,
{
    #[inline ]
    fn write_str(&mut self, s: &str) -> io::Result<()> {
        (**self).write_str(s)
    }

    #[inline ]
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
        (**self).write_fmt(args)
    }
}
102 | |
103 | /// Writes an href to the buffer, escaping href unsafe bytes. |
104 | pub fn escape_href<W>(mut w: W, s: &str) -> io::Result<()> |
105 | where |
106 | W: StrWrite, |
107 | { |
108 | let bytes = s.as_bytes(); |
109 | let mut mark = 0; |
110 | for i in 0..bytes.len() { |
111 | let c = bytes[i]; |
112 | if c >= 0x80 || HREF_SAFE[c as usize] == 0 { |
113 | // character needing escape |
114 | |
115 | // write partial substring up to mark |
116 | if mark < i { |
117 | w.write_str(&s[mark..i])?; |
118 | } |
119 | match c { |
120 | b'&' => { |
121 | w.write_str(AMP_ESCAPE)?; |
122 | } |
123 | b' \'' => { |
124 | w.write_str(SINGLE_QUOTE_ESCAPE)?; |
125 | } |
126 | _ => { |
127 | let mut buf = [0u8; 3]; |
128 | buf[0] = b'%' ; |
129 | buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF]; |
130 | buf[2] = HEX_CHARS[(c as usize) & 0xF]; |
131 | let escaped = from_utf8(&buf).unwrap(); |
132 | w.write_str(escaped)?; |
133 | } |
134 | } |
135 | mark = i + 1; // all escaped characters are ASCII |
136 | } |
137 | } |
138 | w.write_str(&s[mark..]) |
139 | } |
140 | |
/// Builds the 256-entry lookup table used by the HTML escapers: each of the
/// four HTML-special bytes maps to its index in `HTML_ESCAPES`, and every
/// other byte maps to 0 (no escaping needed).
const fn create_html_escape_table() -> [u8; 256] {
    let mut lut = [0u8; 256];
    lut[b'&' as usize] = 2;
    lut[b'<' as usize] = 3;
    lut[b'>' as usize] = 4;
    lut[b'"' as usize] = 1;
    lut
}

// Computed once at compile time; shared by the scalar and SIMD paths.
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table();
151 | |
// Replacement strings indexed by the values of `HTML_ESCAPE_TABLE`
// (index 0 is unused). These must be the HTML entities themselves; the
// original line had been corrupted by entity rendering and did not parse.
static HTML_ESCAPES: [&str; 5] = ["", "&quot;", "&amp;", "&lt;", "&gt;"];
153 | |
/// Writes the given string to the Write sink, replacing special HTML bytes
/// (<, >, &, ") by escape sequences.
///
/// Dispatch is resolved at compile time: with the `simd` feature on
/// x86_64 the vectorized implementation is used (it still falls back to
/// the scalar loop at runtime if SSSE3 is unavailable or the input is
/// short); otherwise the scalar implementation is used directly.
pub fn escape_html<W: StrWrite>(w: W, s: &str) -> io::Result<()> {
    #[cfg (all(target_arch = "x86_64" , feature = "simd" ))]
    {
        simd::escape_html(w, s)
    }
    #[cfg (not(all(target_arch = "x86_64" , feature = "simd" )))]
    {
        escape_html_scalar(w, s)
    }
}
166 | |
167 | fn escape_html_scalar<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> { |
168 | let bytes: &[u8] = s.as_bytes(); |
169 | let mut mark: usize = 0; |
170 | let mut i: usize = 0; |
171 | while i < s.len() { |
172 | match bytesIter<'_, u8>[i..] |
173 | .iter() |
174 | .position(|&c: u8| HTML_ESCAPE_TABLE[c as usize] != 0) |
175 | { |
176 | Some(pos: usize) => { |
177 | i += pos; |
178 | } |
179 | None => break, |
180 | } |
181 | let c: u8 = bytes[i]; |
182 | let escape: u8 = HTML_ESCAPE_TABLE[c as usize]; |
183 | let escape_seq: &str = HTML_ESCAPES[escape as usize]; |
184 | w.write_str(&s[mark..i])?; |
185 | w.write_str(escape_seq)?; |
186 | i += 1; |
187 | mark = i; // all escaped characters are ASCII |
188 | } |
189 | w.write_str(&s[mark..]) |
190 | } |
191 | |
#[cfg (all(target_arch = "x86_64" , feature = "simd" ))]
mod simd {
    //! Vectorized (SSSE3) HTML escaping: scans 16 bytes at a time with a
    //! PSHUFB-based nibble lookup to find the bytes that need escaping.
    use super::StrWrite;
    use std::arch::x86_64::*;
    use std::io;
    use std::mem::size_of;

    // 16: the width of one __m128i vector in bytes.
    const VECTOR_SIZE: usize = size_of::<__m128i>();

    pub(super) fn escape_html<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
        // The SIMD accelerated code uses the PSHUFB instruction, which is part
        // of the SSSE3 instruction set. Further, we can only use this code if
        // the buffer is at least one VECTOR_SIZE in length to prevent reading
        // out of bounds. If either of these conditions is not met, we fall back
        // to scalar code.
        if is_x86_feature_detected!("ssse3" ) && s.len() >= VECTOR_SIZE {
            let bytes = s.as_bytes();
            let mut mark = 0;

            unsafe {
                // The callback receives the index of each special byte; we
                // flush the clean segment before it and write the entity.
                foreach_special_simd(bytes, 0, |i| {
                    let escape_ix = *bytes.get_unchecked(i) as usize;
                    let replacement =
                        super::HTML_ESCAPES[super::HTML_ESCAPE_TABLE[escape_ix] as usize];
                    w.write_str(&s.get_unchecked(mark..i))?;
                    mark = i + 1; // all escaped characters are ASCII
                    w.write_str(replacement)
                })?;
                w.write_str(&s.get_unchecked(mark..))
            }
        } else {
            super::escape_html_scalar(w, s)
        }
    }

    /// Creates the lookup table for use in `compute_mask`.
    const fn create_lookup() -> [u8; 16] {
        let mut table = [0; 16];
        table[(b'<' & 0x0f) as usize] = b'<' ;
        table[(b'>' & 0x0f) as usize] = b'>' ;
        table[(b'&' & 0x0f) as usize] = b'&' ;
        table[(b'"' & 0x0f) as usize] = b'"' ;
        // Slot 0 must not stay 0: a zero would compare equal to NUL bytes
        // and to shuffle output for high-bit bytes. 127 is not a special
        // byte, so those lanes can never match.
        table[0] = 0b0111_1111;
        table
    }

    #[target_feature (enable = "ssse3" )]
    /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant)
    /// bits correspond to whether there is an HTML special byte (&, <, ", >) at the 16 bytes
    /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte
    /// at `offset + 3`. It is only safe to call this function when
    /// `bytes.len() >= offset + VECTOR_SIZE`.
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);

        let table = create_lookup();
        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
        let raw_ptr = bytes.as_ptr().offset(offset as isize) as *const __m128i;

        // Load the vector from memory.
        let vector = _mm_loadu_si128(raw_ptr);
        // We take the least significant 4 bits of every byte and use them as indices
        // to map into the lookup vector.
        // Note that shuffle maps bytes with their most significant bit set to lookup[0].
        // Bytes that share their lower nibble with an HTML special byte get mapped to that
        // corresponding special byte. Note that all HTML special bytes have distinct lower
        // nibbles. Other bytes either get mapped to 0 or 127.
        let expected = _mm_shuffle_epi8(lookup, vector);
        // We compare the original vector to the mapped output. Bytes that shared a lower
        // nibble with an HTML special byte match *only* if they are that special byte. Bytes
        // that have either a 0 lower nibble or their most significant bit set were mapped to
        // 127 and will hence never match. All other bytes have non-zero lower nibbles but
        // were mapped to 0 and will therefore also not match.
        let matches = _mm_cmpeq_epi8(expected, vector);

        // Translate matches to a bitmask, where every 1 corresponds to a HTML special character
        // and a 0 is a non-HTML byte.
        _mm_movemask_epi8(matches)
    }

    /// Calls the given function with the index of every byte in the given byteslice
    /// that is either ", &, <, or > and for no other byte.
    /// Make sure to only call this when `bytes.len() >= 16`, undefined behaviour may
    /// occur otherwise.
    #[target_feature (enable = "ssse3" )]
    unsafe fn foreach_special_simd<F>(
        bytes: &[u8],
        mut offset: usize,
        mut callback: F,
    ) -> io::Result<()>
    where
        F: FnMut(usize) -> io::Result<()>,
    {
        // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16)
        // bytes at a time starting at the given offset. For each chunk, we compute a
        // a bitmask indicating whether the corresponding byte is a HTML special byte.
        // We then iterate over all the 1 bits in this mask and call the callback function
        // with the corresponding index in the buffer.
        // When the number of HTML special bytes in the buffer is relatively low, this
        // allows us to quickly go through the buffer without a lookup and for every
        // single byte.

        debug_assert!(bytes.len() >= VECTOR_SIZE);
        let upperbound = bytes.len() - VECTOR_SIZE;
        while offset < upperbound {
            let mut mask = compute_mask(bytes, offset);
            while mask != 0 {
                let ix = mask.trailing_zeros();
                callback(offset + ix as usize)?;
                mask ^= mask & -mask; // clear the least significant set bit
            }
            offset += VECTOR_SIZE;
        }

        // Final iteration. We align the read with the end of the slice and
        // shift off the bytes at start we have already scanned.
        let mut mask = compute_mask(bytes, upperbound);
        mask >>= offset - upperbound;
        while mask != 0 {
            let ix = mask.trailing_zeros();
            callback(offset + ix as usize)?;
            mask ^= mask & -mask; // clear the least significant set bit
        }
        Ok(())
    }

    #[cfg (test)]
    mod html_scan_tests {
        #[test ]
        fn multichunk() {
            let mut vec = Vec::new();
            unsafe {
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&" .as_bytes(), 0, |ix| {
                    Ok(vec.push(ix))
                })
                .unwrap();
            }
            assert_eq!(vec, vec![0, 14, 15, 19]);
        }

        // only match these bytes, and when we match them, match them VECTOR_SIZE times
        #[test ]
        fn only_right_bytes_matched() {
            for b in 0..255u8 {
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"' ;
                let vek = vec![b; super::VECTOR_SIZE];
                let mut match_count = 0;
                unsafe {
                    super::foreach_special_simd(&vek, 0, |_| {
                        match_count += 1;
                        Ok(())
                    })
                    .unwrap();
                }
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
                assert_eq!(
                    (match_count == super::VECTOR_SIZE),
                    right_byte,
                    "match_count: {}, byte: {:?}" ,
                    match_count,
                    b as char
                );
            }
        }
    }
}
358 | |
#[cfg(test)]
mod test {
    pub use super::escape_href;

    #[test]
    fn check_href_escape() {
        let mut s = String::new();
        escape_href(&mut s, "&^_").unwrap();
        // `&` becomes its HTML entity, `^` (href-unsafe) is percent-encoded,
        // and `_` (href-safe) passes through unchanged.
        assert_eq!(s.as_str(), "&amp;%5E_");
    }
}
370 | |