1 | // Copyright 2015 Google Inc. All rights reserved. |
2 | // |
3 | // Permission is hereby granted, free of charge, to any person obtaining a copy |
4 | // of this software and associated documentation files (the "Software"), to deal |
5 | // in the Software without restriction, including without limitation the rights |
6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
7 | // copies of the Software, and to permit persons to whom the Software is |
8 | // furnished to do so, subject to the following conditions: |
9 | // |
10 | // The above copyright notice and this permission notice shall be included in |
11 | // all copies or substantial portions of the Software. |
12 | // |
13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
19 | // THE SOFTWARE. |
20 | |
21 | //! Utility functions for HTML escaping. Only useful when building your own |
22 | //! HTML renderer. |
23 | |
24 | use std::fmt::{self, Arguments}; |
25 | use std::io::{self, Write}; |
26 | use std::str::from_utf8; |
27 | |
/// Lookup table over the ASCII range used by `escape_href`: a `1` marks a byte
/// that is written to an href verbatim, a `0` marks a byte that gets escaped.
/// Each row covers 16 consecutive byte values.
#[rustfmt::skip]
static HREF_SAFE: [u8; 128] = [
    // 0x00..=0x0F
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // 0x10..=0x1F
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // 0x20..=0x2F: space ! " # $ % & ' ( ) * + , - . /
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x30..=0x3F: 0-9 : ; < = > ?
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    // 0x40..=0x4F: @ A-O
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x50..=0x5F: P-Z [ \ ] ^ _
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,
    // 0x60..=0x6F: ` a-o
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x70..=0x7F: p-z { | } ~ DEL
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0,
];

/// Uppercase hexadecimal digits, indexed by nibble value, used for `%XX` escapes.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// HTML entity written in place of `&` inside an href.
static AMP_ESCAPE: &str = "&amp;";
/// HTML character reference written in place of `'` inside an href.
static SINGLE_QUOTE_ESCAPE: &str = "&#x27;";
43 | |
/// Adapter that lets an `io::Write` sink be used where a `StrWrite` is expected.
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `io::Write` and for types of the form `&mut W`
/// where `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `io::Write` types.
#[derive(Debug)]
pub struct IoWriter<W>(pub W);
50 | |
/// A sink that accepts string slices.
///
/// This is essentially `std::io::Write` widened so that a `String` can be
/// a target as well.
pub trait StrWrite {
    /// Error type returned when a write fails.
    type Error;

    /// Writes the string slice `s` to the sink.
    fn write_str(&mut self, s: &str) -> Result<(), Self::Error>;

    /// Writes pre-formatted arguments to the sink. Having this method is what
    /// makes the `write!` macro usable on implementors.
    fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error>;
}
59 | |
60 | impl<W> StrWrite for IoWriter<W> |
61 | where |
62 | W: Write, |
63 | { |
64 | type Error = io::Error; |
65 | |
66 | #[inline ] |
67 | fn write_str(&mut self, s: &str) -> io::Result<()> { |
68 | self.0.write_all(buf:s.as_bytes()) |
69 | } |
70 | |
71 | #[inline ] |
72 | fn write_fmt(&mut self, args: Arguments) -> io::Result<()> { |
73 | self.0.write_fmt(args) |
74 | } |
75 | } |
76 | |
/// Adapter that lets a `fmt::Write` sink be used where a `StrWrite` is expected.
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `fmt::Write` and for types of the form `&mut W`
/// where `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `fmt::Write` types.
#[derive(Debug)]
pub struct FmtWriter<W>(pub W);
83 | |
84 | impl<W> StrWrite for FmtWriter<W> |
85 | where |
86 | W: fmt::Write, |
87 | { |
88 | type Error = fmt::Error; |
89 | |
90 | #[inline ] |
91 | fn write_str(&mut self, s: &str) -> fmt::Result { |
92 | self.0.write_str(s) |
93 | } |
94 | |
95 | #[inline ] |
96 | fn write_fmt(&mut self, args: Arguments) -> fmt::Result { |
97 | self.0.write_fmt(args) |
98 | } |
99 | } |
100 | |
101 | impl StrWrite for String { |
102 | type Error = fmt::Error; |
103 | |
104 | #[inline ] |
105 | fn write_str(&mut self, s: &str) -> fmt::Result { |
106 | self.push_str(string:s); |
107 | Ok(()) |
108 | } |
109 | |
110 | #[inline ] |
111 | fn write_fmt(&mut self, args: Arguments) -> fmt::Result { |
112 | fmt::Write::write_fmt(self, args) |
113 | } |
114 | } |
115 | |
116 | impl<W> StrWrite for &'_ mut W |
117 | where |
118 | W: StrWrite, |
119 | { |
120 | type Error = W::Error; |
121 | |
122 | #[inline ] |
123 | fn write_str(&mut self, s: &str) -> Result<(), Self::Error> { |
124 | (**self).write_str(s) |
125 | } |
126 | |
127 | #[inline ] |
128 | fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error> { |
129 | (**self).write_fmt(args) |
130 | } |
131 | } |
132 | |
133 | /// Writes an href to the buffer, escaping href unsafe bytes. |
134 | pub fn escape_href<W>(mut w: W, s: &str) -> Result<(), W::Error> |
135 | where |
136 | W: StrWrite, |
137 | { |
138 | let bytes = s.as_bytes(); |
139 | let mut mark = 0; |
140 | for i in 0..bytes.len() { |
141 | let c = bytes[i]; |
142 | if c >= 0x80 || HREF_SAFE[c as usize] == 0 { |
143 | // character needing escape |
144 | |
145 | // write partial substring up to mark |
146 | if mark < i { |
147 | w.write_str(&s[mark..i])?; |
148 | } |
149 | match c { |
150 | b'&' => { |
151 | w.write_str(AMP_ESCAPE)?; |
152 | } |
153 | b' \'' => { |
154 | w.write_str(SINGLE_QUOTE_ESCAPE)?; |
155 | } |
156 | _ => { |
157 | let mut buf = [0u8; 3]; |
158 | buf[0] = b'%' ; |
159 | buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF]; |
160 | buf[2] = HEX_CHARS[(c as usize) & 0xF]; |
161 | let escaped = from_utf8(&buf).unwrap(); |
162 | w.write_str(escaped)?; |
163 | } |
164 | } |
165 | mark = i + 1; // all escaped characters are ASCII |
166 | } |
167 | } |
168 | w.write_str(&s[mark..]) |
169 | } |
170 | |
/// Builds a 256-entry table that maps each byte to an index into
/// `HTML_ESCAPES`; an entry of `0` means the byte needs no escaping.
///
/// `&`, `<` and `>` are always flagged. When `body` is `false` the table
/// additionally flags `"` and `'`.
const fn create_html_escape_table(body: bool) -> [u8; 256] {
    let mut table = [0u8; 256];
    table[b'&' as usize] = 1;
    table[b'<' as usize] = 2;
    table[b'>' as usize] = 3;
    if !body {
        table[b'"' as usize] = 4;
        table[b'\'' as usize] = 5;
    }
    table
}
182 | |
183 | static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(body:false); |
184 | static HTML_BODY_TEXT_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(body:true); |
185 | |
186 | static HTML_ESCAPES: [&str; 6] = ["" , "&" , "<" , ">" , """ , "'" ]; |
187 | |
188 | /// Writes the given string to the Write sink, replacing special HTML bytes |
189 | /// (<, >, &, ", ') by escape sequences. |
190 | /// |
191 | /// Use this function to write output to quoted HTML attributes. |
192 | /// Since this function doesn't escape spaces, unquoted attributes |
193 | /// cannot be used. For example: |
194 | /// |
195 | /// ```rust |
196 | /// let mut value = String::new(); |
197 | /// pulldown_cmark_escape::escape_html(&mut value, "two words" ) |
198 | /// .expect("writing to a string is infallible" ); |
199 | /// // This is okay. |
200 | /// let ok = format!("<a title='{value}'>test</a>" ); |
201 | /// // This is not okay. |
202 | /// //let not_ok = format!("<a title={value}>test</a>"); |
203 | /// ```` |
204 | pub fn escape_html<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> { |
205 | #[cfg (all(target_arch = "x86_64" , feature = "simd" ))] |
206 | { |
207 | simd::escape_html(w, s, &HTML_ESCAPE_TABLE) |
208 | } |
209 | #[cfg (not(all(target_arch = "x86_64" , feature = "simd" )))] |
210 | { |
211 | escape_html_scalar(w, s, &HTML_ESCAPE_TABLE) |
212 | } |
213 | } |
214 | |
215 | /// For use in HTML body text, writes the given string to the Write sink, |
216 | /// replacing special HTML bytes (<, >, &) by escape sequences. |
217 | /// |
218 | /// <div class="warning"> |
219 | /// |
220 | /// This function should be used for escaping text nodes, not attributes. |
221 | /// In the below example, the word "foo" is an attribute, and the word |
222 | /// "bar" is an text node. The word "bar" could be escaped by this function, |
223 | /// but the word "foo" must be escaped using [`escape_html`]. |
224 | /// |
225 | /// ```html |
226 | /// <span class="foo">bar</span> |
227 | /// ``` |
228 | /// |
229 | /// If you aren't sure what the difference is, use [`escape_html`]. |
230 | /// It should always be correct, but will produce larger output. |
231 | /// |
232 | /// </div> |
233 | pub fn escape_html_body_text<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> { |
234 | #[cfg (all(target_arch = "x86_64" , feature = "simd" ))] |
235 | { |
236 | simd::escape_html(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE) |
237 | } |
238 | #[cfg (not(all(target_arch = "x86_64" , feature = "simd" )))] |
239 | { |
240 | escape_html_scalar(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE) |
241 | } |
242 | } |
243 | |
244 | fn escape_html_scalar<W: StrWrite>( |
245 | mut w: W, |
246 | s: &str, |
247 | table: &'static [u8; 256], |
248 | ) -> Result<(), W::Error> { |
249 | let bytes: &[u8] = s.as_bytes(); |
250 | let mut mark: usize = 0; |
251 | let mut i: usize = 0; |
252 | while i < s.len() { |
253 | match bytes[i..].iter().position(|&c: u8| table[c as usize] != 0) { |
254 | Some(pos: usize) => { |
255 | i += pos; |
256 | } |
257 | None => break, |
258 | } |
259 | let c: u8 = bytes[i]; |
260 | let escape: u8 = table[c as usize]; |
261 | let escape_seq: &str = HTML_ESCAPES[escape as usize]; |
262 | w.write_str(&s[mark..i])?; |
263 | w.write_str(escape_seq)?; |
264 | i += 1; |
265 | mark = i; // all escaped characters are ASCII |
266 | } |
267 | w.write_str(&s[mark..]) |
268 | } |
269 | |
/// SSSE3-accelerated HTML escaping, compiled only on x86_64 with the `simd`
/// feature enabled.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
mod simd {
    use super::StrWrite;
    use std::arch::x86_64::*;
    use std::mem::size_of;

    /// Number of bytes examined per SIMD step (the size of an `__m128i`).
    const VECTOR_SIZE: usize = size_of::<__m128i>();

    /// Writes `s` to `w`, replacing every byte that has a non-zero entry in
    /// `table` with the matching `HTML_ESCAPES` string.
    ///
    /// Falls back to the scalar implementation when SSSE3 is not available at
    /// runtime or when `s` is shorter than `VECTOR_SIZE` bytes.
    pub(super) fn escape_html<W: StrWrite>(
        mut w: W,
        s: &str,
        table: &'static [u8; 256],
    ) -> Result<(), W::Error> {
        // The SIMD accelerated code uses the PSHUFB instruction, which is part
        // of the SSSE3 instruction set. Further, we can only use this code if
        // the buffer is at least one VECTOR_SIZE in length to prevent reading
        // out of bounds. If either of these conditions is not met, we fall back
        // to scalar code.
        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
            let bytes = s.as_bytes();
            let mut mark = 0;

            // SAFETY: SSSE3 support and `s.len() >= VECTOR_SIZE` were checked
            // just above, which is what `foreach_special_simd` requires. The
            // callback only receives indices of ASCII special bytes within
            // `bytes`, so every unchecked index and range below is in bounds
            // and falls on a char boundary.
            unsafe {
                foreach_special_simd(bytes, 0, |i| {
                    let escape_ix = *bytes.get_unchecked(i) as usize;
                    let entry = table[escape_ix] as usize;
                    w.write_str(s.get_unchecked(mark..i))?;
                    mark = i + 1; // all escaped characters are ASCII
                    if entry == 0 {
                        // The scan reports every special byte, but this table
                        // does not escape this one: write it through as-is.
                        w.write_str(s.get_unchecked(i..mark))
                    } else {
                        let replacement = super::HTML_ESCAPES[entry];
                        w.write_str(replacement)
                    }
                })?;
                w.write_str(s.get_unchecked(mark..))
            }
        } else {
            super::escape_html_scalar(w, s, table)
        }
    }

    /// Creates the lookup table for use in `compute_mask`.
    ///
    /// Each HTML special byte is stored at the index given by its lower nibble.
    /// Slot 0 holds 127 so that bytes with a zero lower nibble never compare
    /// equal to their lookup result.
    const fn create_lookup() -> [u8; 16] {
        let mut table = [0; 16];
        table[(b'<' & 0x0f) as usize] = b'<';
        table[(b'>' & 0x0f) as usize] = b'>';
        table[(b'&' & 0x0f) as usize] = b'&';
        table[(b'"' & 0x0f) as usize] = b'"';
        table[(b'\'' & 0x0f) as usize] = b'\'';
        table[0] = 0b0111_1111;
        table
    }

    /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant)
    /// bits correspond to whether there is an HTML special byte (&, <, >, ", ') at the 16 bytes
    /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte
    /// at `offset + 3`.
    ///
    /// # Safety
    ///
    /// It is only safe to call this function when `bytes.len() >= offset + VECTOR_SIZE`
    /// and the CPU supports SSSE3.
    #[target_feature(enable = "ssse3")]
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);

        let table = create_lookup();
        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
        let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;

        // Load the vector from memory.
        let vector = _mm_loadu_si128(raw_ptr);
        // We take the least significant 4 bits of every byte and use them as indices
        // to map into the lookup vector.
        // Note that PSHUFB writes 0 for every byte whose most significant bit is set.
        // Bytes that share their lower nibble with an HTML special byte get mapped to that
        // corresponding special byte. Note that all HTML special bytes have distinct lower
        // nibbles. Bytes with a zero lower nibble get mapped to lookup[0] = 127, and all
        // remaining bytes get mapped to 0.
        let expected = _mm_shuffle_epi8(lookup, vector);
        // We compare the original vector to the mapped output. Bytes that shared a lower
        // nibble with an HTML special byte match *only* if they are that special byte. Bytes
        // with their most significant bit set were mapped to 0 and, being non-zero
        // themselves, never match. Bytes with a zero lower nibble (NUL included) were mapped
        // to 127, whose lower nibble is non-zero, so they never match either. All other
        // bytes have non-zero lower nibbles but were mapped to 0 and will therefore also
        // not match.
        let matches = _mm_cmpeq_epi8(expected, vector);

        // Translate matches to a bitmask, where every 1 corresponds to a HTML special character
        // and a 0 is a non-HTML byte.
        _mm_movemask_epi8(matches)
    }

    /// Calls the given function with the index of every byte in the given byteslice
    /// that is either ", ', &, <, or > and for no other byte. Stops at, and returns,
    /// the first error produced by the callback.
    ///
    /// # Safety
    ///
    /// Make sure to only call this when `bytes.len() >= 16` and the CPU supports
    /// SSSE3; undefined behaviour may occur otherwise.
    #[target_feature(enable = "ssse3")]
    unsafe fn foreach_special_simd<E, F>(
        bytes: &[u8],
        mut offset: usize,
        mut callback: F,
    ) -> Result<(), E>
    where
        F: FnMut(usize) -> Result<(), E>,
    {
        // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16)
        // bytes at a time starting at the given offset. For each chunk, we compute a
        // bitmask indicating whether the corresponding byte is a HTML special byte.
        // We then iterate over all the 1 bits in this mask and call the callback function
        // with the corresponding index in the buffer.
        // When the number of HTML special bytes in the buffer is relatively low, this
        // allows us to quickly go through the buffer without a lookup for every
        // single byte.

        debug_assert!(bytes.len() >= VECTOR_SIZE);
        let upperbound = bytes.len() - VECTOR_SIZE;
        while offset < upperbound {
            let mut mask = compute_mask(bytes, offset);
            while mask != 0 {
                let ix = mask.trailing_zeros();
                callback(offset + ix as usize)?;
                // Clear the lowest set bit.
                mask ^= mask & -mask;
            }
            offset += VECTOR_SIZE;
        }

        // Final iteration. We align the read with the end of the slice and
        // shift off the bytes at start we have already scanned.
        let mut mask = compute_mask(bytes, upperbound);
        mask >>= offset - upperbound;
        while mask != 0 {
            let ix = mask.trailing_zeros();
            callback(offset + ix as usize)?;
            // Clear the lowest set bit.
            mask ^= mask & -mask;
        }
        Ok(())
    }

    #[cfg(test)]
    mod html_scan_tests {
        /// Special bytes are reported at the right indices across more than one chunk.
        #[test]
        fn multichunk() {
            let mut vec = Vec::new();
            unsafe {
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
                    // `Vec::push` returns `()`, which is exactly the `Ok` payload wanted here.
                    #[allow(clippy::unit_arg)]
                    Ok::<_, std::fmt::Error>(vec.push(ix))
                })
                .unwrap();
            }
            assert_eq!(vec, vec![0, 9, 14, 15, 19]);
        }

        // only match these bytes, and when we match them, match them VECTOR_SIZE times
        #[test]
        fn only_right_bytes_matched() {
            // Inclusive range so that byte 255 is checked as well.
            for b in 0..=u8::MAX {
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"' || b == b'\'';
                let vek = vec![b; super::VECTOR_SIZE];
                let mut match_count = 0;
                unsafe {
                    super::foreach_special_simd(&vek, 0, |_| {
                        match_count += 1;
                        Ok::<_, std::fmt::Error>(())
                    })
                    .unwrap();
                }
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
                assert_eq!(
                    (match_count == super::VECTOR_SIZE),
                    right_byte,
                    "match_count: {}, byte: {:?}",
                    match_count,
                    b as char
                );
            }
        }
    }
}
445 | |
#[cfg(test)]
mod test {
    pub use super::{escape_href, escape_html, escape_html_body_text};

    /// In an href, `&` becomes an entity while `^` and `_` pass through.
    #[test]
    fn check_href_escape() {
        let mut s = String::new();
        escape_href(&mut s, "&^_").unwrap();
        assert_eq!(s.as_str(), "&amp;^_");
    }

    /// Attribute escaping replaces `&`, `"` and `'`.
    #[test]
    fn check_attr_escape() {
        let mut s = String::new();
        escape_html(&mut s, r##"&^"'_"##).unwrap();
        assert_eq!(s.as_str(), "&amp;^&quot;&#39;_");
    }

    /// Body-text escaping replaces `&` but leaves quotes untouched.
    #[test]
    fn check_body_escape() {
        let mut s = String::new();
        escape_html_body_text(&mut s, r##"&^"'_"##).unwrap();
        assert_eq!(s.as_str(), r##"&amp;^"'_"##);
    }
}
471 | |