1 | use std::borrow::Borrow; |
2 | use std::collections::HashMap; |
3 | use std::error; |
4 | use std::fmt; |
5 | use std::io; |
6 | use std::result; |
7 | |
8 | use super::{TrieSetSlice, CHUNK_SIZE}; |
9 | |
10 | // This implementation was pretty much cribbed from raphlinus' contribution |
11 | // to the standard library: https://github.com/rust-lang/rust/pull/33098/files |
12 | // |
13 | // The fundamental principle guiding this implementation is to take advantage |
14 | // of the fact that similar Unicode codepoints are often grouped together, and |
15 | // that most boolean Unicode properties are quite sparse over the entire space |
16 | // of Unicode codepoints. |
17 | // |
18 | // To do this, we represent sets using something like a trie (which gives us |
19 | // prefix compression). The "final" states of the trie are embedded in leaves |
20 | // or "chunks," where each chunk is a 64 bit integer. Each bit position of the |
21 | // integer corresponds to whether a particular codepoint is in the set or not. |
22 | // These chunks are not just a compact representation of the final states of |
23 | // the trie, but are also a form of suffix compression. In particular, if |
// multiple ranges of 64 contiguous codepoints have the same set membership,
// then they all map to the exact same chunk in the trie.
26 | // |
27 | // We organize this structure by partitioning the space of Unicode codepoints |
28 | // into three disjoint sets. The first set corresponds to codepoints |
// [0, 0x800), the second [0x800, 0x10000) and the third [0x10000, 0x110000).
30 | // These partitions conveniently correspond to the space of 1 or 2 byte UTF-8 |
31 | // encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded |
32 | // codepoints, respectively. |
33 | // |
34 | // Each partition has its own tree with its own root. The first partition is |
35 | // the simplest, since the tree is completely flat. In particular, to determine |
36 | // the set membership of a Unicode codepoint (that is less than `0x800`), we |
37 | // do the following (where `cp` is the codepoint we're testing): |
38 | // |
39 | // let chunk_address = cp >> 6; |
40 | // let chunk_bit = cp & 0b111111; |
// let chunk = tree1[chunk_address];
42 | // let is_member = 1 == ((chunk >> chunk_bit) & 1); |
43 | // |
44 | // We do something similar for the second partition: |
45 | // |
46 | // // we subtract 0x20 since (0x800 >> 6) == 0x20. |
47 | // let child_address = (cp >> 6) - 0x20; |
48 | // let chunk_address = tree2_level1[child_address]; |
49 | // let chunk_bit = cp & 0b111111; |
50 | // let chunk = tree2_level2[chunk_address]; |
51 | // let is_member = 1 == ((chunk >> chunk_bit) & 1); |
52 | // |
53 | // And so on for the third partition. |
54 | // |
55 | // Note that as a special case, if the second or third partitions are empty, |
56 | // then the trie will store empty slices for those levels. The `contains` |
57 | // check knows to return `false` in those cases. |
58 | |
// Total number of 64-bit chunks required to cover the entire space of
// Unicode codepoints, [0, 0x110000).
const CHUNKS: usize = 0x110000 / CHUNK_SIZE;

/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
pub type Result<T> = result::Result<T, Error>;
63 | |
/// An error that can occur during construction of a trie.
#[derive(Clone, Debug)]
pub enum Error {
    /// This error is returned when an invalid codepoint is given to
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
    /// is greater than `0x10FFFF`.
    InvalidCodepoint(u32),
    /// This error is returned when a set of Unicode codepoints could not be
    /// sufficiently compressed into the trie provided by this crate. There is
    /// no work-around for this error at this time.
    GaveUp,
}
76 | |
77 | impl error::Error for Error {} |
78 | |
79 | impl fmt::Display for Error { |
80 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
81 | match *self { |
82 | Error::InvalidCodepoint(cp: u32) => write!( |
83 | f, |
84 | "could not construct trie set containing an \ |
85 | invalid Unicode codepoint: 0x {:X}" , |
86 | cp |
87 | ), |
88 | Error::GaveUp => { |
89 | write!(f, "could not compress codepoint set into a trie" ) |
90 | } |
91 | } |
92 | } |
93 | } |
94 | |
95 | impl From<Error> for io::Error { |
96 | fn from(err: Error) -> io::Error { |
97 | io::Error::new(kind:io::ErrorKind::Other, error:err) |
98 | } |
99 | } |
100 | |
/// An owned trie set.
#[derive(Clone)]
pub struct TrieSetOwned {
    // Partition [0, 0x800): flat leaf chunks, indexed directly by `cp >> 6`.
    tree1_level1: Vec<u64>,
    // Partition [0x800, 0x10000): level 1 maps `(cp >> 6) - 0x20` to an
    // index into the deduplicated leaf chunks in level 2.
    tree2_level1: Vec<u8>,
    tree2_level2: Vec<u64>,
    // Partition [0x10000, 0x110000): two levels of indirection before the
    // leaf chunks in level 3. Empty vectors mean the partition is empty.
    tree3_level1: Vec<u8>,
    tree3_level2: Vec<u8>,
    tree3_level3: Vec<u64>,
}
111 | |
112 | impl fmt::Debug for TrieSetOwned { |
113 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
114 | write!(f, "TrieSetOwned(...)" ) |
115 | } |
116 | } |
117 | |
impl TrieSetOwned {
    /// Build a trie from a dense membership table: `all[cp]` is `true` iff
    /// codepoint `cp` is in the set. Callers pass a table of length
    /// 0x110000; shorter input would panic on indexing below.
    ///
    /// Returns `Error::GaveUp` if either multi-level partition needs more
    /// than 256 distinct child entries (indices must fit in a `u8`).
    fn new(all: &[bool]) -> Result<TrieSetOwned> {
        // Pack the boolean table into one 64-bit chunk per CHUNK_SIZE
        // codepoints; bit `j` of chunk `i` is codepoint `i * CHUNK_SIZE + j`.
        let mut bitvectors = Vec::with_capacity(CHUNKS);
        for i in 0..CHUNKS {
            let mut bitvector = 0u64;
            for j in 0..CHUNK_SIZE {
                if all[i * CHUNK_SIZE + j] {
                    bitvector |= 1 << j;
                }
            }
            bitvectors.push(bitvector);
        }

        // Partition [0, 0x800): stored flat with no compression.
        let tree1_level1 =
            bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();

        // Partition [0x800, 0x10000): one level of indirection over
        // deduplicated leaf chunks.
        let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
            &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
        )?;
        // Special case: if the partition is entirely empty, store empty
        // slices. `contains` knows to return `false` for empty levels.
        if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
            tree2_level1.clear();
            tree2_level2.clear();
        }

        // Partition [0x10000, 0x110000): two levels of indirection. The
        // leaves are compressed first, then the resulting mid level is
        // itself compressed in groups of 64.
        let (mid, mut tree3_level3) = compress_postfix_leaves(
            &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
        )?;
        let (mut tree3_level1, mut tree3_level2) =
            compress_postfix_mid(&mid, 64)?;
        // Same empty-partition special case as above.
        if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
            tree3_level1.clear();
            tree3_level2.clear();
            tree3_level3.clear();
        }

        Ok(TrieSetOwned {
            tree1_level1,
            tree2_level1,
            tree2_level2,
            tree3_level1,
            tree3_level2,
            tree3_level3,
        })
    }

    /// Create a new trie set from a set of Unicode scalar values.
    ///
    /// This returns an error if a set could not be sufficiently compressed to
    /// fit into a trie.
    pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
    where
        I: IntoIterator<Item = C>,
        C: Borrow<char>,
    {
        // `char` is always a valid scalar value, so no range check is needed
        // here (unlike `from_codepoints`).
        let mut all = vec![false; 0x110000];
        for s in scalars {
            all[*s.borrow() as usize] = true;
        }
        TrieSetOwned::new(&all)
    }

    /// Create a new trie set from a set of Unicode scalar values.
    ///
    /// This returns an error if a set could not be sufficiently compressed to
    /// fit into a trie. This also returns an error if any of the given
    /// codepoints are greater than `0x10FFFF`.
    pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
    where
        I: IntoIterator<Item = C>,
        C: Borrow<u32>,
    {
        let mut all = vec![false; 0x110000];
        for cp in codepoints {
            let cp = *cp.borrow();
            if cp > 0x10FFFF {
                return Err(Error::InvalidCodepoint(cp));
            }
            all[cp as usize] = true;
        }
        TrieSetOwned::new(&all)
    }

    /// Return this set as a slice.
    #[inline(always)]
    pub fn as_slice(&self) -> TrieSetSlice<'_> {
        TrieSetSlice {
            tree1_level1: &self.tree1_level1,
            tree2_level1: &self.tree2_level1,
            tree2_level2: &self.tree2_level2,
            tree3_level1: &self.tree3_level1,
            tree3_level2: &self.tree3_level2,
            tree3_level3: &self.tree3_level3,
        }
    }

    /// Returns true if and only if the given Unicode scalar value is in this
    /// set.
    pub fn contains_char(&self, c: char) -> bool {
        self.as_slice().contains_char(c)
    }

    /// Returns true if and only if the given codepoint is in this set.
    ///
    /// If the given value exceeds the codepoint range (i.e., it's greater
    /// than `0x10FFFF`), then this returns false.
    pub fn contains_u32(&self, cp: u32) -> bool {
        self.as_slice().contains_u32(cp)
    }
}
227 | |
228 | fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> { |
229 | let mut root: Vec = vec![]; |
230 | let mut children: Vec = vec![]; |
231 | let mut bychild: HashMap = HashMap::new(); |
232 | for &chunk: u64 in chunks { |
233 | if !bychild.contains_key(&chunk) { |
234 | let start: usize = bychild.len(); |
235 | if start > ::std::u8::MAX as usize { |
236 | return Err(Error::GaveUp); |
237 | } |
238 | bychild.insert(k:chunk, v:start as u8); |
239 | children.push(chunk); |
240 | } |
241 | root.push(bychild[&chunk]); |
242 | } |
243 | Ok((root, children)) |
244 | } |
245 | |
246 | fn compress_postfix_mid( |
247 | chunks: &[u8], |
248 | chunk_size: usize, |
249 | ) -> Result<(Vec<u8>, Vec<u8>)> { |
250 | let mut root: Vec = vec![]; |
251 | let mut children: Vec = vec![]; |
252 | let mut bychild: HashMap<&[u8], u8> = HashMap::new(); |
253 | for i: usize in 0..(chunks.len() / chunk_size) { |
254 | let chunk: &[u8] = &chunks[i * chunk_size..(i + 1) * chunk_size]; |
255 | if !bychild.contains_key(chunk) { |
256 | let start: usize = bychild.len(); |
257 | if start > ::std::u8::MAX as usize { |
258 | return Err(Error::GaveUp); |
259 | } |
260 | bychild.insert(k:chunk, v:start as u8); |
261 | children.extend(iter:chunk); |
262 | } |
263 | root.push(bychild[chunk]); |
264 | } |
265 | Ok((root, children)) |
266 | } |
267 | |
#[cfg(test)]
mod tests {
    use super::TrieSetOwned;
    use crate::general_category;
    use std::collections::HashSet;

    // Convenience constructor: builds a trie from chars, panicking on
    // failure (tests only use sets known to compress).
    fn mk(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    // Expands inclusive `(start, end)` ranges into a flat codepoint list.
    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        let mut set = vec![];
        for &(start, end) in ranges {
            for cp in start..end + 1 {
                set.push(cp);
            }
        }
        set
    }

    #[test]
    fn set1() {
        // Single member in the first partition; negative probes cover one
        // codepoint from each of the three partitions.
        let set = mk(&['a']);
        assert!(set.contains_char('a'));
        assert!(!set.contains_char('b'));
        assert!(!set.contains_char('β'));
        assert!(!set.contains_char('☃'));
        assert!(!set.contains_char('😼'));
    }

    #[test]
    fn set_combined() {
        // Members spanning all three partitions (1/2-byte, 3-byte and
        // 4-byte UTF-8 ranges).
        let set = mk(&['a', 'b', 'β', '☃', '😼']);
        assert!(set.contains_char('a'));
        assert!(set.contains_char('b'));
        assert!(set.contains_char('β'));
        assert!(set.contains_char('☃'));
        assert!(set.contains_char('😼'));

        assert!(!set.contains_char('c'));
        assert!(!set.contains_char('θ'));
        assert!(!set.contains_char('⛇'));
        assert!(!set.contains_char('🐲'));
    }

    // Basic tests on all of the general category sets. We check that
    // membership is correct on every Unicode codepoint... because we can.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let set = ranges_to_set(general_category::$ranges);
                let hashset: HashSet<u32> = set.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&set).unwrap();
                for cp in 0..0x110000 {
                    assert!(trie.contains_u32(cp) == hashset.contains(&cp));
                }
                // Test that an invalid codepoint is treated correctly.
                assert!(!trie.contains_u32(0x110000));
                assert!(!hashset.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}
372 | |