1#![cfg_attr(not(any(feature = "json", test)), no_std)]
2#![deny(elided_lifetimes_in_paths)]
3#![deny(unreachable_pub)]
4
5use core::fmt::{self, Display, Formatter, Write};
6use core::str;
7
8#[derive(Debug)]
9pub struct MarkupDisplay<E, T>
10where
11 E: Escaper,
12 T: Display,
13{
14 value: DisplayValue<T>,
15 escaper: E,
16}
17
18impl<E, T> MarkupDisplay<E, T>
19where
20 E: Escaper,
21 T: Display,
22{
23 pub fn new_unsafe(value: T, escaper: E) -> Self {
24 Self {
25 value: DisplayValue::Unsafe(value),
26 escaper,
27 }
28 }
29
30 pub fn new_safe(value: T, escaper: E) -> Self {
31 Self {
32 value: DisplayValue::Safe(value),
33 escaper,
34 }
35 }
36
37 #[must_use]
38 pub fn mark_safe(mut self) -> MarkupDisplay<E, T> {
39 self.value = match self.value {
40 DisplayValue::Unsafe(t: T) => DisplayValue::Safe(t),
41 _ => self.value,
42 };
43 self
44 }
45}
46
47impl<E, T> Display for MarkupDisplay<E, T>
48where
49 E: Escaper,
50 T: Display,
51{
52 fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
53 match self.value {
54 DisplayValue::Unsafe(ref t: &T) => write!(
55 EscapeWriter {
56 fmt,
57 escaper: &self.escaper
58 },
59 "{}",
60 t
61 ),
62 DisplayValue::Safe(ref t: &T) => t.fmt(fmt),
63 }
64 }
65}
66
/// Adapter whose `fmt::Write` implementation hands every string written to it
/// to `escaper.write_escaped`, targeting the wrapped writer `fmt`.
#[derive(Debug)]
pub struct EscapeWriter<'a, E, W> {
    // Destination writer that receives the escaped output.
    fmt: W,
    // Borrowed escaper applied to each chunk written.
    escaper: &'a E,
}
72
73impl<E, W> Write for EscapeWriter<'_, E, W>
74where
75 W: Write,
76 E: Escaper,
77{
78 fn write_str(&mut self, s: &str) -> fmt::Result {
79 self.escaper.write_escaped(&mut self.fmt, string:s)
80 }
81}
82
83pub fn escape<E>(string: &str, escaper: E) -> Escaped<'_, E>
84where
85 E: Escaper,
86{
87 Escaped { string, escaper }
88}
89
90#[derive(Debug)]
91pub struct Escaped<'a, E>
92where
93 E: Escaper,
94{
95 string: &'a str,
96 escaper: E,
97}
98
99impl<E> Display for Escaped<'_, E>
100where
101 E: Escaper,
102{
103 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
104 self.escaper.write_escaped(fmt, self.string)
105 }
106}
107
/// Escaper for HTML text and attribute values.
///
/// Replaces `<`, `>`, `&`, `"` and `'` with `&lt;`, `&gt;`, `&amp;`,
/// `&quot;` and `&#x27;`; every other byte is passed through unchanged.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Html;
109
// Flushes the pending run of unescaped bytes `$src[$from..$at]` to `$out`,
// writes the replacement `$entity`, then moves `$from` past byte `$at`.
//
// Expands to `?`-propagating writes, so it must be invoked inside a function
// whose return type accepts a `fmt::Error` through `?`.
macro_rules! escaping_body {
    ($from:ident, $at:ident, $out:ident, $src:ident, $entity:expr) => {{
        if $at > $from {
            // SAFETY: callers must pass the bytes of a valid `&str` as `$src`
            // and invoke this only at the index of an ASCII byte. `$from` is
            // then 0 or one past an earlier ASCII byte and `$at` indexes an
            // ASCII byte, so both are char boundaries and the slice is UTF-8.
            let pending = unsafe { str::from_utf8_unchecked(&$src[$from..$at]) };
            $out.write_str(pending)?;
        }
        $out.write_str($entity)?;
        $from = $at + 1;
    }};
}
119
120impl Escaper for Html {
121 fn write_escaped<W>(&self, mut fmt: W, string: &str) -> fmt::Result
122 where
123 W: Write,
124 {
125 let bytes = string.as_bytes();
126 let mut start = 0;
127 for (i, b) in bytes.iter().enumerate() {
128 if b.wrapping_sub(b'"') <= FLAG {
129 match *b {
130 b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
131 b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
132 b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
133 b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
134 b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
135 _ => (),
136 }
137 }
138 }
139 if start < bytes.len() {
140 fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })
141 } else {
142 Ok(())
143 }
144 }
145}
146
/// Pass-through escaper: writes its input unchanged.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Text;
148
149impl Escaper for Text {
150 fn write_escaped<W>(&self, mut fmt: W, string: &str) -> fmt::Result
151 where
152 W: Write,
153 {
154 fmt.write_str(string)
155 }
156}
157
/// Tags a wrapped value as trusted (`Safe`, written verbatim) or untrusted
/// (`Unsafe`, routed through an escaper when displayed).
#[derive(Debug, PartialEq)]
enum DisplayValue<T: Display> {
    /// Trusted value; its `Display` output is used as-is.
    Safe(T),
    /// Untrusted value; its `Display` output must be escaped.
    Unsafe(T),
}
166
/// Strategy for writing a string to a [`Write`] sink with some characters
/// substituted (possibly none at all).
pub trait Escaper {
    /// Writes `string` to `fmt`, applying this escaper's substitutions.
    ///
    /// Implementations are expected to propagate any error returned by `fmt`.
    fn write_escaped<W: Write>(&self, fmt: W, string: &str) -> fmt::Result;
}
172
/// Width of the byte range `b'"'..=b'>'`, which contains every byte the HTML
/// escaper replaces (`"`, `&`, `'`, `<`, `>`). `Html::write_escaped` uses it as
/// a cheap pre-filter: `byte.wrapping_sub(b'"') <= FLAG`.
const FLAG: u8 = b'>' - b'"';
174
/// Buffer that escapes chevrons, ampersands and apostrophes for use in JSON.
///
/// Bytes written through its `std::io::Write` implementation are appended
/// with `&`, `'`, `<` and `>` replaced by `\u0026`, `\u0027`, `\u003c` and
/// `\u003e`; call [`JsonEscapeBuffer::finish`] to obtain the result.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Default)]
pub struct JsonEscapeBuffer(Vec<u8>);

#[cfg(feature = "json")]
impl JsonEscapeBuffer {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self(Vec::new())
    }

    /// Consumes the buffer and returns its contents as a `String`.
    ///
    /// `std::io::Write` accepts arbitrary bytes, so the buffer is not
    /// guaranteed to hold UTF-8 (e.g. raw binary was written, or a multi-byte
    /// character was split and never completed). Treating such bytes as a
    /// `String` unchecked would be undefined behavior reachable from safe
    /// code, so invalid sequences are replaced with U+FFFD instead. Valid
    /// UTF-8 is returned without copying.
    pub fn finish(self) -> String {
        match String::from_utf8(self.0) {
            Ok(text) => text,
            Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
        }
    }
}
190
#[cfg(feature = "json")]
impl std::io::Write for JsonEscapeBuffer {
    /// Appends `bytes` to the buffer, replacing `&`, `'`, `<` and `>` with
    /// their `\uXXXX` JSON escapes. Always reports the whole input as written.
    fn write(&mut self, bytes: &[u8]) -> std::io::Result<usize> {
        self.0.reserve(bytes.len());
        // Start of the run of input bytes not yet copied into the buffer.
        let mut copied_up_to = 0;
        for (pos, &byte) in bytes.iter().enumerate() {
            let escape: &[u8] = match byte {
                b'&' => br#"\u0026"#,
                b'\'' => br#"\u0027"#,
                b'<' => br#"\u003c"#,
                b'>' => br#"\u003e"#,
                _ => continue,
            };
            // Copying an empty run is a no-op, so no length guard is needed.
            self.0.extend_from_slice(&bytes[copied_up_to..pos]);
            self.0.extend_from_slice(escape);
            copied_up_to = pos + 1;
        }
        self.0.extend_from_slice(&bytes[copied_up_to..]);
        Ok(bytes.len())
    }

    /// Output lives only in the in-memory `Vec`, so flushing does nothing.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::string::ToString;

    #[test]
    fn test_escape() {
        assert_eq!(escape("", Html).to_string(), "");
        assert_eq!(escape("<&>", Html).to_string(), "&lt;&amp;&gt;");
        assert_eq!(escape("bla&", Html).to_string(), "bla&amp;");
        assert_eq!(escape("<foo", Html).to_string(), "&lt;foo");
        assert_eq!(escape("bla&h", Html).to_string(), "bla&amp;h");
    }

    #[test]
    fn test_escape_quotes() {
        assert_eq!(escape("\"'", Html).to_string(), "&quot;&#x27;");
        assert_eq!(
            escape("<a href='x'>", Html).to_string(),
            "&lt;a href=&#x27;x&#x27;&gt;"
        );
    }

    #[test]
    fn test_escape_passthrough() {
        // These bytes fall inside the FLAG pre-filter range but are not escaped.
        assert_eq!(
            escape("#$%()*+,-./09:;=", Html).to_string(),
            "#$%()*+,-./09:;="
        );
        // Multi-byte characters around an escaped byte are preserved.
        assert_eq!(escape("é<ü", Html).to_string(), "é&lt;ü");
        // The Text escaper never alters its input.
        assert_eq!(escape("<&>\"'", Text).to_string(), "<&>\"'");
    }

    #[test]
    fn test_markup_display() {
        assert_eq!(
            MarkupDisplay::new_unsafe("<b>", Html).to_string(),
            "&lt;b&gt;"
        );
        assert_eq!(MarkupDisplay::new_safe("<b>", Html).to_string(), "<b>");
        assert_eq!(
            MarkupDisplay::new_unsafe("<b>", Html).mark_safe().to_string(),
            "<b>"
        );
    }

    #[cfg(feature = "json")]
    #[test]
    fn test_json_escape_buffer() {
        let mut buf = JsonEscapeBuffer::new();
        std::io::Write::write_all(&mut buf, b"<a href='x'>&</a>").unwrap();
        assert_eq!(
            buf.finish(),
            r#"\u003ca href=\u0027x\u0027\u003e\u0026\u003c/a\u003e"#
        );
    }
}
240