1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19// THE SOFTWARE.
20
21//! Utility functions for HTML escaping. Only useful when building your own
22//! HTML renderer.
23
24use std::fmt::{self, Arguments};
25use std::io::{self, Write};
26use std::str::from_utf8;
27
/// Lookup table indexed by ASCII byte value (`0x00..=0x7F`).
///
/// An entry of `1` marks a byte that `escape_href` copies to the output
/// unchanged; an entry of `0` marks a byte that `escape_href` replaces with an
/// escape sequence. Bytes `>= 0x80` are not covered by this table and are
/// always escaped by `escape_href`.
#[rustfmt::skip]
static HREF_SAFE: [u8; 128] = [
    // 0x00..=0x0F
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // 0x10..=0x1F
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // 0x20..=0x2F: space, '"', '&' and '\'' are unsafe
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x30..=0x3F: '<' and '>' are unsafe
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    // 0x40..=0x4F
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x50..=0x5F: '[', '\\' and ']' are unsafe
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,
    // 0x60..=0x6F: '`' is unsafe
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x70..=0x7F: '{', '|', '}' and DEL are unsafe
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0,
];
39
/// Uppercase hexadecimal digits indexed by nibble value; used to build the
/// two digits of a `%XX` percent-escape.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// Character reference written in place of `&` inside an href. Emitting the
/// bare character would let it start a character reference in the attribute.
static AMP_ESCAPE: &str = "&amp;";
/// Character reference written in place of `'` inside an href. Emitting the
/// bare character would let it terminate a single-quoted attribute value.
static SINGLE_QUOTE_ESCAPE: &str = "&#x27;";
43
/// Adapter that lets a `std::io::Write` sink be used as a [`StrWrite`].
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `io::Write` and for types of the form `&mut W`
/// where `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `Write` types.
#[derive(Debug)]
pub struct IoWriter<W>(pub W);
50
/// Trait that allows writing string slices. This is basically an extension
/// of `std::io::Write` in order to include `String`.
pub trait StrWrite {
    /// Error reported by the sink when a write fails.
    type Error;

    /// Writes the whole string slice `s` to the sink.
    fn write_str(&mut self, s: &str) -> Result<(), Self::Error>;
    /// Writes pre-built format arguments to the sink. The `write!` macro
    /// expands to a call of a method with this name, so it can target any
    /// `StrWrite` value.
    fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error>;
}
59
60impl<W> StrWrite for IoWriter<W>
61where
62 W: Write,
63{
64 type Error = io::Error;
65
66 #[inline]
67 fn write_str(&mut self, s: &str) -> io::Result<()> {
68 self.0.write_all(buf:s.as_bytes())
69 }
70
71 #[inline]
72 fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
73 self.0.write_fmt(args)
74 }
75}
76
/// Adapter that lets a `std::fmt::Write` sink be used as a [`StrWrite`].
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `fmt::Write` and for types of the form `&mut W`
/// where `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `fmt::Write` types.
#[derive(Debug)]
pub struct FmtWriter<W>(pub W);
83
84impl<W> StrWrite for FmtWriter<W>
85where
86 W: fmt::Write,
87{
88 type Error = fmt::Error;
89
90 #[inline]
91 fn write_str(&mut self, s: &str) -> fmt::Result {
92 self.0.write_str(s)
93 }
94
95 #[inline]
96 fn write_fmt(&mut self, args: Arguments) -> fmt::Result {
97 self.0.write_fmt(args)
98 }
99}
100
101impl StrWrite for String {
102 type Error = fmt::Error;
103
104 #[inline]
105 fn write_str(&mut self, s: &str) -> fmt::Result {
106 self.push_str(string:s);
107 Ok(())
108 }
109
110 #[inline]
111 fn write_fmt(&mut self, args: Arguments) -> fmt::Result {
112 fmt::Write::write_fmt(self, args)
113 }
114}
115
116impl<W> StrWrite for &'_ mut W
117where
118 W: StrWrite,
119{
120 type Error = W::Error;
121
122 #[inline]
123 fn write_str(&mut self, s: &str) -> Result<(), Self::Error> {
124 (**self).write_str(s)
125 }
126
127 #[inline]
128 fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error> {
129 (**self).write_fmt(args)
130 }
131}
132
133/// Writes an href to the buffer, escaping href unsafe bytes.
134pub fn escape_href<W>(mut w: W, s: &str) -> Result<(), W::Error>
135where
136 W: StrWrite,
137{
138 let bytes = s.as_bytes();
139 let mut mark = 0;
140 for i in 0..bytes.len() {
141 let c = bytes[i];
142 if c >= 0x80 || HREF_SAFE[c as usize] == 0 {
143 // character needing escape
144
145 // write partial substring up to mark
146 if mark < i {
147 w.write_str(&s[mark..i])?;
148 }
149 match c {
150 b'&' => {
151 w.write_str(AMP_ESCAPE)?;
152 }
153 b'\'' => {
154 w.write_str(SINGLE_QUOTE_ESCAPE)?;
155 }
156 _ => {
157 let mut buf = [0u8; 3];
158 buf[0] = b'%';
159 buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF];
160 buf[2] = HEX_CHARS[(c as usize) & 0xF];
161 let escaped = from_utf8(&buf).unwrap();
162 w.write_str(escaped)?;
163 }
164 }
165 mark = i + 1; // all escaped characters are ASCII
166 }
167 }
168 w.write_str(&s[mark..])
169}
170
/// Builds a 256-entry table that maps every byte value to an index into
/// `HTML_ESCAPES`. An entry of `0` means the byte is written unchanged.
///
/// `&`, `<` and `>` are always mapped to an escape. When `body` is `false`
/// the quote characters `"` and `'` are mapped as well; when `body` is `true`
/// they are left at `0`.
const fn create_html_escape_table(body: bool) -> [u8; 256] {
    let mut table = [0u8; 256];
    table[b'&' as usize] = 1;
    table[b'<' as usize] = 2;
    table[b'>' as usize] = 3;
    if !body {
        table[b'"' as usize] = 4;
        table[b'\'' as usize] = 5;
    }
    table
}

// Rust has no named call arguments: the previous `body:` labels on these two
// calls were editor hints that had leaked into the source and did not parse.

/// Escape table that also covers `"` and `'`.
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(false);
/// Escape table that covers only `&`, `<` and `>`.
static HTML_BODY_TEXT_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(true);

/// Replacement text for each non-zero escape-table entry; index 0 is unused.
static HTML_ESCAPES: [&str; 6] = ["", "&amp;", "&lt;", "&gt;", "&quot;", "&#39;"];
187
188/// Writes the given string to the Write sink, replacing special HTML bytes
189/// (<, >, &, ", ') by escape sequences.
190///
191/// Use this function to write output to quoted HTML attributes.
192/// Since this function doesn't escape spaces, unquoted attributes
193/// cannot be used. For example:
194///
195/// ```rust
196/// let mut value = String::new();
197/// pulldown_cmark_escape::escape_html(&mut value, "two words")
198/// .expect("writing to a string is infallible");
199/// // This is okay.
200/// let ok = format!("<a title='{value}'>test</a>");
201/// // This is not okay.
202/// //let not_ok = format!("<a title={value}>test</a>");
203/// ````
204pub fn escape_html<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> {
205 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
206 {
207 simd::escape_html(w, s, &HTML_ESCAPE_TABLE)
208 }
209 #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
210 {
211 escape_html_scalar(w, s, &HTML_ESCAPE_TABLE)
212 }
213}
214
215/// For use in HTML body text, writes the given string to the Write sink,
216/// replacing special HTML bytes (<, >, &) by escape sequences.
217///
218/// <div class="warning">
219///
220/// This function should be used for escaping text nodes, not attributes.
221/// In the below example, the word "foo" is an attribute, and the word
222/// "bar" is an text node. The word "bar" could be escaped by this function,
223/// but the word "foo" must be escaped using [`escape_html`].
224///
225/// ```html
226/// <span class="foo">bar</span>
227/// ```
228///
229/// If you aren't sure what the difference is, use [`escape_html`].
230/// It should always be correct, but will produce larger output.
231///
232/// </div>
233pub fn escape_html_body_text<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> {
234 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
235 {
236 simd::escape_html(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE)
237 }
238 #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
239 {
240 escape_html_scalar(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE)
241 }
242}
243
244fn escape_html_scalar<W: StrWrite>(
245 mut w: W,
246 s: &str,
247 table: &'static [u8; 256],
248) -> Result<(), W::Error> {
249 let bytes: &[u8] = s.as_bytes();
250 let mut mark: usize = 0;
251 let mut i: usize = 0;
252 while i < s.len() {
253 match bytes[i..].iter().position(|&c: u8| table[c as usize] != 0) {
254 Some(pos: usize) => {
255 i += pos;
256 }
257 None => break,
258 }
259 let c: u8 = bytes[i];
260 let escape: u8 = table[c as usize];
261 let escape_seq: &str = HTML_ESCAPES[escape as usize];
262 w.write_str(&s[mark..i])?;
263 w.write_str(escape_seq)?;
264 i += 1;
265 mark = i; // all escaped characters are ASCII
266 }
267 w.write_str(&s[mark..])
268}
269
// SSSE3-accelerated HTML escaping. This module is only compiled on x86_64
// targets when the `simd` feature is enabled.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
mod simd {
    use super::StrWrite;
    use std::arch::x86_64::*;
    use std::mem::size_of;

    /// Number of bytes scanned per SIMD step: the size of one `__m128i`.
    const VECTOR_SIZE: usize = size_of::<__m128i>();

    /// Escapes `s` into `w` using `table` to decide which bytes are replaced
    /// by an entry of `HTML_ESCAPES`.
    ///
    /// Uses the SIMD scanner when SSSE3 is detected at run time and `s` is at
    /// least `VECTOR_SIZE` bytes long; otherwise delegates to
    /// `escape_html_scalar`.
    pub(super) fn escape_html<W: StrWrite>(
        mut w: W,
        s: &str,
        table: &'static [u8; 256],
    ) -> Result<(), W::Error> {
        // The SIMD accelerated code uses the PSHUFB instruction, which is part
        // of the SSSE3 instruction set. Further, we can only use this code if
        // the buffer is at least one VECTOR_SIZE in length to prevent reading
        // out of bounds. If either of these conditions is not met, we fall back
        // to scalar code.
        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
            let bytes = s.as_bytes();
            let mut mark = 0;

            // NOTE(review): the soundness of this block appears to rest on
            // (1) the feature and length checks above, which match the
            // documented requirements of `foreach_special_simd`, and (2) the
            // scanner only reporting in-bounds indices of ASCII bytes, so the
            // unchecked slices start and end on character boundaries.
            // Confirm before changing the scanner.
            unsafe {
                foreach_special_simd(bytes, 0, |i| {
                    let escape_ix = *bytes.get_unchecked(i) as usize;
                    let entry = table[escape_ix] as usize;
                    w.write_str(s.get_unchecked(mark..i))?;
                    mark = i + 1; // all escaped characters are ASCII
                    if entry == 0 {
                        // The scanner flags `"` and `'` as well, but `table`
                        // may leave them unescaped (entry 0); in that case
                        // the byte itself is written back unchanged.
                        w.write_str(s.get_unchecked(i..mark))
                    } else {
                        let replacement = super::HTML_ESCAPES[entry];
                        w.write_str(replacement)
                    }
                })?;
                w.write_str(s.get_unchecked(mark..))
            }
        } else {
            super::escape_html_scalar(w, s, table)
        }
    }

    /// Creates the lookup table for use in `compute_mask`.
    ///
    /// Each special byte is stored at the index given by its low nibble.
    /// Index 0 holds 127, and every remaining index holds 0.
    const fn create_lookup() -> [u8; 16] {
        let mut table = [0; 16];
        table[(b'<' & 0x0f) as usize] = b'<';
        table[(b'>' & 0x0f) as usize] = b'>';
        table[(b'&' & 0x0f) as usize] = b'&';
        table[(b'"' & 0x0f) as usize] = b'"';
        table[(b'\'' & 0x0f) as usize] = b'\'';
        table[0] = 0b0111_1111;
        table
    }

    #[target_feature(enable = "ssse3")]
    /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant)
    /// bits correspond to whether there is an HTML special byte (&, <, ", ', >) at the 16 bytes
    /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte
    /// at `offset + 3`. It is only safe to call this function when
    /// `bytes.len() >= offset + VECTOR_SIZE`.
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);

        let table = create_lookup();
        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
        let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;

        // Load the vector from memory.
        let vector = _mm_loadu_si128(raw_ptr);
        // We take the least significant 4 bits of every byte and use them as indices
        // to map into the lookup vector.
        // Note that shuffle maps bytes with their most significant bit set to lookup[0].
        // Bytes that share their lower nibble with an HTML special byte get mapped to that
        // corresponding special byte. Note that all HTML special bytes have distinct lower
        // nibbles. Other bytes either get mapped to 0 or 127.
        let expected = _mm_shuffle_epi8(lookup, vector);
        // We compare the original vector to the mapped output. Bytes that shared a lower
        // nibble with an HTML special byte match *only* if they are that special byte. Bytes
        // that have either a 0 lower nibble or their most significant bit set were mapped to
        // 127 and will hence never match. All other bytes have non-zero lower nibbles but
        // were mapped to 0 and will therefore also not match.
        let matches = _mm_cmpeq_epi8(expected, vector);

        // Translate matches to a bitmask, where every 1 corresponds to a HTML special character
        // and a 0 is a non-HTML byte.
        _mm_movemask_epi8(matches)
    }

    /// Calls the given function with the index of every byte in the given byteslice
    /// that is either ", ', &, <, or > and for no other byte.
    /// Make sure to only call this when `bytes.len() >= 16`, undefined behaviour may
    /// occur otherwise.
    ///
    /// Scanning starts at `offset`. The first error returned by `callback`
    /// stops the scan and is returned to the caller.
    #[target_feature(enable = "ssse3")]
    unsafe fn foreach_special_simd<E, F>(
        bytes: &[u8],
        mut offset: usize,
        mut callback: F,
    ) -> Result<(), E>
    where
        F: FnMut(usize) -> Result<(), E>,
    {
        // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16)
        // bytes at a time starting at the given offset. For each chunk, we compute
        // a bitmask indicating whether the corresponding byte is a HTML special byte.
        // We then iterate over all the 1 bits in this mask and call the callback function
        // with the corresponding index in the buffer.
        // When the number of HTML special bytes in the buffer is relatively low, this
        // allows us to quickly go through the buffer without a lookup for every
        // single byte.

        debug_assert!(bytes.len() >= VECTOR_SIZE);
        // Last offset at which a full vector can still be loaded.
        let upperbound = bytes.len() - VECTOR_SIZE;
        while offset < upperbound {
            let mut mask = compute_mask(bytes, offset);
            while mask != 0 {
                let ix = mask.trailing_zeros();
                callback(offset + ix as usize)?;
                // Clear the lowest set bit, which has just been reported.
                mask ^= mask & -mask;
            }
            offset += VECTOR_SIZE;
        }

        // Final iteration. We align the read with the end of the slice and
        // shift off the bytes at start we have already scanned.
        let mut mask = compute_mask(bytes, upperbound);
        mask >>= offset - upperbound;
        while mask != 0 {
            let ix = mask.trailing_zeros();
            callback(offset + ix as usize)?;
            // Clear the lowest set bit, which has just been reported.
            mask ^= mask & -mask;
        }
        Ok(())
    }

    #[cfg(test)]
    mod html_scan_tests {
        // A 20-byte input exercises both the chunk loop and the end-aligned
        // final read.
        #[test]
        fn multichunk() {
            let mut vec = Vec::new();
            unsafe {
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
                    #[allow(clippy::unit_arg)]
                    Ok::<_, std::fmt::Error>(vec.push(ix))
                })
                .unwrap();
            }
            assert_eq!(vec, vec![0, 9, 14, 15, 19]);
        }

        // only match these bytes, and when we match them, match them VECTOR_SIZE times
        // NOTE(review): the range `0..255u8` excludes byte 255, so that value
        // is never exercised here; consider `0..=255u8`.
        #[test]
        fn only_right_bytes_matched() {
            for b in 0..255u8 {
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"' || b == b'\'';
                let vek = vec![b; super::VECTOR_SIZE];
                let mut match_count = 0;
                unsafe {
                    super::foreach_special_simd(&vek, 0, |_| {
                        match_count += 1;
                        Ok::<_, std::fmt::Error>(())
                    })
                    .unwrap();
                }
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
                assert_eq!(
                    (match_count == super::VECTOR_SIZE),
                    right_byte,
                    "match_count: {}, byte: {:?}",
                    match_count,
                    b as char
                );
            }
        }
    }
}
445
#[cfg(test)]
mod test {
    pub use super::{escape_href, escape_html, escape_html_body_text};

    /// `&` in an href is replaced by an entity; `^` and `_` pass through.
    #[test]
    fn check_href_escape() {
        let mut escaped = String::new();
        escape_href(&mut escaped, "&^_").unwrap();
        assert_eq!(escaped, "&amp;^_");
    }

    /// Attribute escaping replaces `&` and both quote characters.
    #[test]
    fn check_attr_escape() {
        let mut escaped = String::new();
        escape_html(&mut escaped, r##"&^"'_"##).unwrap();
        assert_eq!(escaped, "&amp;^&quot;&#39;_");
    }

    /// Body-text escaping replaces `&` but leaves quote characters alone.
    #[test]
    fn check_body_escape() {
        let mut escaped = String::new();
        escape_html_body_text(&mut escaped, r##"&^"'_"##).unwrap();
        assert_eq!(escaped, r##"&amp;^"'_"##);
    }
}
471