| 1 | use std::borrow::Borrow; |
| 2 | use std::collections::HashMap; |
| 3 | use std::error; |
| 4 | use std::fmt; |
| 5 | use std::io; |
| 6 | use std::result; |
| 7 | |
| 8 | use super::{TrieSetSlice, CHUNK_SIZE}; |
| 9 | |
| 10 | // This implementation was pretty much cribbed from raphlinus' contribution |
| 11 | // to the standard library: https://github.com/rust-lang/rust/pull/33098/files |
| 12 | // |
| 13 | // The fundamental principle guiding this implementation is to take advantage |
| 14 | // of the fact that similar Unicode codepoints are often grouped together, and |
| 15 | // that most boolean Unicode properties are quite sparse over the entire space |
| 16 | // of Unicode codepoints. |
| 17 | // |
| 18 | // To do this, we represent sets using something like a trie (which gives us |
| 19 | // prefix compression). The "final" states of the trie are embedded in leaves |
| 20 | // or "chunks," where each chunk is a 64 bit integer. Each bit position of the |
| 21 | // integer corresponds to whether a particular codepoint is in the set or not. |
| 22 | // These chunks are not just a compact representation of the final states of |
| 23 | // the trie, but are also a form of suffix compression. In particular, if |
| 24 | // multiple ranges of 64 contiguous codepoints have the same set membership |
| 25 | // ordering, then they all map to the exact same chunk in the trie. |
| 26 | // |
| 27 | // We organize this structure by partitioning the space of Unicode codepoints |
| 28 | // into three disjoint sets. The first set corresponds to codepoints |
| 29 | // [0, 0x800), the second [0x800, 0x10000) and the third [0x10000, 0x110000). |
| 30 | // These partitions conveniently correspond to the space of 1 or 2 byte UTF-8 |
| 31 | // encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded |
| 32 | // codepoints, respectively. |
| 33 | // |
| 34 | // Each partition has its own tree with its own root. The first partition is |
| 35 | // the simplest, since the tree is completely flat. In particular, to determine |
| 36 | // the set membership of a Unicode codepoint (that is less than `0x800`), we |
| 37 | // do the following (where `cp` is the codepoint we're testing): |
| 38 | // |
| 39 | // let chunk_address = cp >> 6; |
| 40 | // let chunk_bit = cp & 0b111111; |
| 41 | // let chunk = tree1[chunk_address]; |
| 42 | // let is_member = 1 == ((chunk >> chunk_bit) & 1); |
| 43 | // |
| 44 | // We do something similar for the second partition: |
| 45 | // |
| 46 | // // we subtract 0x20 since (0x800 >> 6) == 0x20. |
| 47 | // let child_address = (cp >> 6) - 0x20; |
| 48 | // let chunk_address = tree2_level1[child_address]; |
| 49 | // let chunk_bit = cp & 0b111111; |
| 50 | // let chunk = tree2_level2[chunk_address]; |
| 51 | // let is_member = 1 == ((chunk >> chunk_bit) & 1); |
| 52 | // |
| 53 | // And so on for the third partition. |
| 54 | // |
| 55 | // Note that as a special case, if the second or third partitions are empty, |
| 56 | // then the trie will store empty slices for those levels. The `contains` |
| 57 | // check knows to return `false` in those cases. |
| 58 | |
| 59 | const CHUNKS: usize = 0x110000 / CHUNK_SIZE; |
| 60 | |
/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
pub type Result<T> = result::Result<T, Error>;

/// An error that can occur during construction of a trie.
#[derive(Clone, Debug)]
pub enum Error {
    /// This error is returned when an invalid codepoint is given to
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
    /// is greater than `0x10FFFF`. The offending value is carried in the
    /// variant.
    InvalidCodepoint(u32),
    /// This error is returned when a set of Unicode codepoints could not be
    /// sufficiently compressed into the trie provided by this crate. There is
    /// no work-around for this error at this time.
    GaveUp,
}

impl error::Error for Error {}

impl fmt::Display for Error {
    /// Writes a human readable description of the error. An invalid
    /// codepoint is rendered in upper-case hexadecimal directly after the
    /// `0x` prefix (e.g. `0x110000`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Error::InvalidCodepoint(cp) => write!(
                f,
                "could not construct trie set containing an \
                 invalid Unicode codepoint: 0x{:X}",
                cp
            ),
            Error::GaveUp => {
                write!(f, "could not compress codepoint set into a trie")
            }
        }
    }
}

impl From<Error> for io::Error {
    /// Wraps the trie error in an `io::Error` of kind `Other`, so it can be
    /// propagated with `?` from functions that return `io::Result`.
    fn from(err: Error) -> io::Error {
        io::Error::new(io::ErrorKind::Other, err)
    }
}
| 100 | |
/// A trie set that owns the storage for all of its levels.
///
/// The codepoint space is split into three partitions, each with its own
/// tree: `[0, 0x800)`, `[0x800, 0x10000)` and `[0x10000, 0x110000)`.
#[derive(Clone)]
pub struct TrieSetOwned {
    /// First partition: one 64-bit membership chunk per 64 codepoints.
    tree1_level1: Vec<u64>,
    /// Second partition: for each 64-codepoint range, an index into
    /// `tree2_level2`.
    tree2_level1: Vec<u8>,
    /// Second partition: the distinct membership chunks.
    tree2_level2: Vec<u64>,
    /// Third partition: root level, indexing groups stored in
    /// `tree3_level2`.
    tree3_level1: Vec<u8>,
    /// Third partition: middle level, holding indices into `tree3_level3`.
    tree3_level2: Vec<u8>,
    /// Third partition: the distinct membership chunks.
    tree3_level3: Vec<u64>,
}
| 111 | |
| 112 | impl fmt::Debug for TrieSetOwned { |
| 113 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 114 | write!(f, "TrieSetOwned(...)" ) |
| 115 | } |
| 116 | } |
| 117 | |
| 118 | impl TrieSetOwned { |
| 119 | fn new(all: &[bool]) -> Result<TrieSetOwned> { |
| 120 | let mut bitvectors = Vec::with_capacity(CHUNKS); |
| 121 | for i in 0..CHUNKS { |
| 122 | let mut bitvector = 0u64; |
| 123 | for j in 0..CHUNK_SIZE { |
| 124 | if all[i * CHUNK_SIZE + j] { |
| 125 | bitvector |= 1 << j; |
| 126 | } |
| 127 | } |
| 128 | bitvectors.push(bitvector); |
| 129 | } |
| 130 | |
| 131 | let tree1_level1 = |
| 132 | bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect(); |
| 133 | |
| 134 | let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves( |
| 135 | &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE], |
| 136 | )?; |
| 137 | if tree2_level2.len() == 1 && tree2_level2[0] == 0 { |
| 138 | tree2_level1.clear(); |
| 139 | tree2_level2.clear(); |
| 140 | } |
| 141 | |
| 142 | let (mid, mut tree3_level3) = compress_postfix_leaves( |
| 143 | &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE], |
| 144 | )?; |
| 145 | let (mut tree3_level1, mut tree3_level2) = |
| 146 | compress_postfix_mid(&mid, 64)?; |
| 147 | if tree3_level3.len() == 1 && tree3_level3[0] == 0 { |
| 148 | tree3_level1.clear(); |
| 149 | tree3_level2.clear(); |
| 150 | tree3_level3.clear(); |
| 151 | } |
| 152 | |
| 153 | Ok(TrieSetOwned { |
| 154 | tree1_level1, |
| 155 | tree2_level1, |
| 156 | tree2_level2, |
| 157 | tree3_level1, |
| 158 | tree3_level2, |
| 159 | tree3_level3, |
| 160 | }) |
| 161 | } |
| 162 | |
| 163 | /// Create a new trie set from a set of Unicode scalar values. |
| 164 | /// |
| 165 | /// This returns an error if a set could not be sufficiently compressed to |
| 166 | /// fit into a trie. |
| 167 | pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned> |
| 168 | where |
| 169 | I: IntoIterator<Item = C>, |
| 170 | C: Borrow<char>, |
| 171 | { |
| 172 | let mut all = vec![false; 0x110000]; |
| 173 | for s in scalars { |
| 174 | all[*s.borrow() as usize] = true; |
| 175 | } |
| 176 | TrieSetOwned::new(&all) |
| 177 | } |
| 178 | |
| 179 | /// Create a new trie set from a set of Unicode scalar values. |
| 180 | /// |
| 181 | /// This returns an error if a set could not be sufficiently compressed to |
| 182 | /// fit into a trie. This also returns an error if any of the given |
| 183 | /// codepoints are greater than `0x10FFFF`. |
| 184 | pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned> |
| 185 | where |
| 186 | I: IntoIterator<Item = C>, |
| 187 | C: Borrow<u32>, |
| 188 | { |
| 189 | let mut all = vec![false; 0x110000]; |
| 190 | for cp in codepoints { |
| 191 | let cp = *cp.borrow(); |
| 192 | if cp > 0x10FFFF { |
| 193 | return Err(Error::InvalidCodepoint(cp)); |
| 194 | } |
| 195 | all[cp as usize] = true; |
| 196 | } |
| 197 | TrieSetOwned::new(&all) |
| 198 | } |
| 199 | |
| 200 | /// Return this set as a slice. |
| 201 | #[inline (always)] |
| 202 | pub fn as_slice(&self) -> TrieSetSlice<'_> { |
| 203 | TrieSetSlice { |
| 204 | tree1_level1: &self.tree1_level1, |
| 205 | tree2_level1: &self.tree2_level1, |
| 206 | tree2_level2: &self.tree2_level2, |
| 207 | tree3_level1: &self.tree3_level1, |
| 208 | tree3_level2: &self.tree3_level2, |
| 209 | tree3_level3: &self.tree3_level3, |
| 210 | } |
| 211 | } |
| 212 | |
| 213 | /// Returns true if and only if the given Unicode scalar value is in this |
| 214 | /// set. |
| 215 | pub fn contains_char(&self, c: char) -> bool { |
| 216 | self.as_slice().contains_char(c) |
| 217 | } |
| 218 | |
| 219 | /// Returns true if and only if the given codepoint is in this set. |
| 220 | /// |
| 221 | /// If the given value exceeds the codepoint range (i.e., it's greater |
| 222 | /// than `0x10FFFF`), then this returns false. |
| 223 | pub fn contains_u32(&self, cp: u32) -> bool { |
| 224 | self.as_slice().contains_u32(cp) |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> { |
| 229 | let mut root: Vec = vec![]; |
| 230 | let mut children: Vec = vec![]; |
| 231 | let mut bychild: HashMap = HashMap::new(); |
| 232 | for &chunk: u64 in chunks { |
| 233 | if !bychild.contains_key(&chunk) { |
| 234 | let start: usize = bychild.len(); |
| 235 | if start > ::std::u8::MAX as usize { |
| 236 | return Err(Error::GaveUp); |
| 237 | } |
| 238 | bychild.insert(k:chunk, v:start as u8); |
| 239 | children.push(chunk); |
| 240 | } |
| 241 | root.push(bychild[&chunk]); |
| 242 | } |
| 243 | Ok((root, children)) |
| 244 | } |
| 245 | |
| 246 | fn compress_postfix_mid( |
| 247 | chunks: &[u8], |
| 248 | chunk_size: usize, |
| 249 | ) -> Result<(Vec<u8>, Vec<u8>)> { |
| 250 | let mut root: Vec = vec![]; |
| 251 | let mut children: Vec = vec![]; |
| 252 | let mut bychild: HashMap<&[u8], u8> = HashMap::new(); |
| 253 | for i: usize in 0..(chunks.len() / chunk_size) { |
| 254 | let chunk: &[u8] = &chunks[i * chunk_size..(i + 1) * chunk_size]; |
| 255 | if !bychild.contains_key(chunk) { |
| 256 | let start: usize = bychild.len(); |
| 257 | if start > ::std::u8::MAX as usize { |
| 258 | return Err(Error::GaveUp); |
| 259 | } |
| 260 | bychild.insert(k:chunk, v:start as u8); |
| 261 | children.extend(iter:chunk); |
| 262 | } |
| 263 | root.push(bychild[chunk]); |
| 264 | } |
| 265 | Ok((root, children)) |
| 266 | } |
| 267 | |
#[cfg(test)]
mod tests {
    use super::TrieSetOwned;
    use crate::general_category;
    use std::collections::HashSet;

    /// Builds a trie set from the given scalar values, panicking if
    /// construction fails.
    fn trie_of(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    /// Expands inclusive `(start, end)` ranges into the list of codepoints
    /// they cover, in order.
    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        ranges.iter().flat_map(|&(start, end)| start..end + 1).collect()
    }

    #[test]
    fn set1() {
        let set = trie_of(&['a']);
        assert!(set.contains_char('a'));
        for &absent in &['b', 'β', '☃', '😼'] {
            assert!(!set.contains_char(absent));
        }
    }

    #[test]
    fn set_combined() {
        let set = trie_of(&['a', 'b', 'β', '☃', '😼']);
        for &present in &['a', 'b', 'β', '☃', '😼'] {
            assert!(set.contains_char(present));
        }
        for &absent in &['c', 'θ', '⛇', '🐲'] {
            assert!(!set.contains_char(absent));
        }
    }

    // Exhaustive checks against every general category table: for each one,
    // trie membership is compared with a plain hash set over the whole
    // codepoint space.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let members = ranges_to_set(general_category::$ranges);
                let expected: HashSet<u32> =
                    members.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&members).unwrap();
                for cp in 0..0x110000 {
                    assert_eq!(
                        trie.contains_u32(cp),
                        expected.contains(&cp)
                    );
                }
                // A value one past the last codepoint must never be
                // reported as a member.
                assert!(!trie.contains_u32(0x110000));
                assert!(!expected.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}
| 372 | |