1use std::borrow::Borrow;
2use std::collections::HashMap;
3use std::error;
4use std::fmt;
5use std::io;
6use std::result;
7
8use super::{TrieSetSlice, CHUNK_SIZE};
9
10// This implementation was pretty much cribbed from raphlinus' contribution
11// to the standard library: https://github.com/rust-lang/rust/pull/33098/files
12//
13// The fundamental principle guiding this implementation is to take advantage
14// of the fact that similar Unicode codepoints are often grouped together, and
15// that most boolean Unicode properties are quite sparse over the entire space
16// of Unicode codepoints.
17//
18// To do this, we represent sets using something like a trie (which gives us
19// prefix compression). The "final" states of the trie are embedded in leaves
20// or "chunks," where each chunk is a 64 bit integer. Each bit position of the
21// integer corresponds to whether a particular codepoint is in the set or not.
22// These chunks are not just a compact representation of the final states of
23// the trie, but are also a form of suffix compression. In particular, if
// multiple ranges of 64 contiguous codepoints have the same set membership
25// ordering, then they all map to the exact same chunk in the trie.
26//
27// We organize this structure by partitioning the space of Unicode codepoints
28// into three disjoint sets. The first set corresponds to codepoints
// [0, 0x800), the second [0x800, 0x10000) and the third [0x10000, 0x110000).
30// These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
31// encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
32// codepoints, respectively.
33//
34// Each partition has its own tree with its own root. The first partition is
35// the simplest, since the tree is completely flat. In particular, to determine
36// the set membership of a Unicode codepoint (that is less than `0x800`), we
37// do the following (where `cp` is the codepoint we're testing):
38//
39// let chunk_address = cp >> 6;
40// let chunk_bit = cp & 0b111111;
// let chunk = tree1[chunk_address];
42// let is_member = 1 == ((chunk >> chunk_bit) & 1);
43//
44// We do something similar for the second partition:
45//
46// // we subtract 0x20 since (0x800 >> 6) == 0x20.
47// let child_address = (cp >> 6) - 0x20;
48// let chunk_address = tree2_level1[child_address];
49// let chunk_bit = cp & 0b111111;
50// let chunk = tree2_level2[chunk_address];
51// let is_member = 1 == ((chunk >> chunk_bit) & 1);
52//
53// And so on for the third partition.
54//
55// Note that as a special case, if the second or third partitions are empty,
56// then the trie will store empty slices for those levels. The `contains`
57// check knows to return `false` in those cases.
58
59const CHUNKS: usize = 0x110000 / CHUNK_SIZE;
60
61/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
62pub type Result<T> = result::Result<T, Error>;
63
/// An error that can occur during construction of a trie.
#[derive(Clone, Debug)]
pub enum Error {
    /// This error is returned when an invalid codepoint is given to
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
    /// is greater than `0x10FFFF`. The offending value is carried in the
    /// variant.
    InvalidCodepoint(u32),
    /// This error is returned when a set of Unicode codepoints could not be
    /// sufficiently compressed into the trie provided by this crate. There is
    /// no work-around for this error at this time.
    GaveUp,
}
76
77impl error::Error for Error {}
78
79impl fmt::Display for Error {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 match *self {
82 Error::InvalidCodepoint(cp: u32) => write!(
83 f,
84 "could not construct trie set containing an \
85 invalid Unicode codepoint: 0x{:X}",
86 cp
87 ),
88 Error::GaveUp => {
89 write!(f, "could not compress codepoint set into a trie")
90 }
91 }
92 }
93}
94
95impl From<Error> for io::Error {
96 fn from(err: Error) -> io::Error {
97 io::Error::new(kind:io::ErrorKind::Other, error:err)
98 }
99}
100
/// An owned trie set.
///
/// Membership is stored in three independent trees, one for each partition
/// of the codepoint space: `[0, 0x800)`, `[0x800, 0x10000)` and
/// `[0x10000, 0x110000)`. The levels for the second and third partitions are
/// left empty when the partition contains no members.
#[derive(Clone)]
pub struct TrieSetOwned {
    /// Leaf bitmaps for `[0, 0x800)`, one `u64` per chunk of codepoints.
    tree1_level1: Vec<u64>,
    /// For each chunk of `[0x800, 0x10000)`, the index of its bitmap in
    /// `tree2_level2`.
    tree2_level1: Vec<u8>,
    /// Deduplicated leaf bitmaps for `[0x800, 0x10000)`.
    tree2_level2: Vec<u64>,
    /// For each group of 64 chunks of `[0x10000, 0x110000)`, the index of
    /// its block of leaf indices in `tree3_level2`.
    tree3_level1: Vec<u8>,
    /// Deduplicated blocks of 64 leaf indices stored back to back; each
    /// entry is an index into `tree3_level3`.
    tree3_level2: Vec<u8>,
    /// Deduplicated leaf bitmaps for `[0x10000, 0x110000)`.
    tree3_level3: Vec<u64>,
}
111
112impl fmt::Debug for TrieSetOwned {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 write!(f, "TrieSetOwned(...)")
115 }
116}
117
118impl TrieSetOwned {
119 fn new(all: &[bool]) -> Result<TrieSetOwned> {
120 let mut bitvectors = Vec::with_capacity(CHUNKS);
121 for i in 0..CHUNKS {
122 let mut bitvector = 0u64;
123 for j in 0..CHUNK_SIZE {
124 if all[i * CHUNK_SIZE + j] {
125 bitvector |= 1 << j;
126 }
127 }
128 bitvectors.push(bitvector);
129 }
130
131 let tree1_level1 =
132 bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();
133
134 let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
135 &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
136 )?;
137 if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
138 tree2_level1.clear();
139 tree2_level2.clear();
140 }
141
142 let (mid, mut tree3_level3) = compress_postfix_leaves(
143 &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
144 )?;
145 let (mut tree3_level1, mut tree3_level2) =
146 compress_postfix_mid(&mid, 64)?;
147 if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
148 tree3_level1.clear();
149 tree3_level2.clear();
150 tree3_level3.clear();
151 }
152
153 Ok(TrieSetOwned {
154 tree1_level1,
155 tree2_level1,
156 tree2_level2,
157 tree3_level1,
158 tree3_level2,
159 tree3_level3,
160 })
161 }
162
163 /// Create a new trie set from a set of Unicode scalar values.
164 ///
165 /// This returns an error if a set could not be sufficiently compressed to
166 /// fit into a trie.
167 pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
168 where
169 I: IntoIterator<Item = C>,
170 C: Borrow<char>,
171 {
172 let mut all = vec![false; 0x110000];
173 for s in scalars {
174 all[*s.borrow() as usize] = true;
175 }
176 TrieSetOwned::new(&all)
177 }
178
179 /// Create a new trie set from a set of Unicode scalar values.
180 ///
181 /// This returns an error if a set could not be sufficiently compressed to
182 /// fit into a trie. This also returns an error if any of the given
183 /// codepoints are greater than `0x10FFFF`.
184 pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
185 where
186 I: IntoIterator<Item = C>,
187 C: Borrow<u32>,
188 {
189 let mut all = vec![false; 0x110000];
190 for cp in codepoints {
191 let cp = *cp.borrow();
192 if cp > 0x10FFFF {
193 return Err(Error::InvalidCodepoint(cp));
194 }
195 all[cp as usize] = true;
196 }
197 TrieSetOwned::new(&all)
198 }
199
200 /// Return this set as a slice.
201 #[inline(always)]
202 pub fn as_slice(&self) -> TrieSetSlice<'_> {
203 TrieSetSlice {
204 tree1_level1: &self.tree1_level1,
205 tree2_level1: &self.tree2_level1,
206 tree2_level2: &self.tree2_level2,
207 tree3_level1: &self.tree3_level1,
208 tree3_level2: &self.tree3_level2,
209 tree3_level3: &self.tree3_level3,
210 }
211 }
212
213 /// Returns true if and only if the given Unicode scalar value is in this
214 /// set.
215 pub fn contains_char(&self, c: char) -> bool {
216 self.as_slice().contains_char(c)
217 }
218
219 /// Returns true if and only if the given codepoint is in this set.
220 ///
221 /// If the given value exceeds the codepoint range (i.e., it's greater
222 /// than `0x10FFFF`), then this returns false.
223 pub fn contains_u32(&self, cp: u32) -> bool {
224 self.as_slice().contains_u32(cp)
225 }
226}
227
228fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
229 let mut root: Vec = vec![];
230 let mut children: Vec = vec![];
231 let mut bychild: HashMap = HashMap::new();
232 for &chunk: u64 in chunks {
233 if !bychild.contains_key(&chunk) {
234 let start: usize = bychild.len();
235 if start > ::std::u8::MAX as usize {
236 return Err(Error::GaveUp);
237 }
238 bychild.insert(k:chunk, v:start as u8);
239 children.push(chunk);
240 }
241 root.push(bychild[&chunk]);
242 }
243 Ok((root, children))
244}
245
246fn compress_postfix_mid(
247 chunks: &[u8],
248 chunk_size: usize,
249) -> Result<(Vec<u8>, Vec<u8>)> {
250 let mut root: Vec = vec![];
251 let mut children: Vec = vec![];
252 let mut bychild: HashMap<&[u8], u8> = HashMap::new();
253 for i: usize in 0..(chunks.len() / chunk_size) {
254 let chunk: &[u8] = &chunks[i * chunk_size..(i + 1) * chunk_size];
255 if !bychild.contains_key(chunk) {
256 let start: usize = bychild.len();
257 if start > ::std::u8::MAX as usize {
258 return Err(Error::GaveUp);
259 }
260 bychild.insert(k:chunk, v:start as u8);
261 children.extend(iter:chunk);
262 }
263 root.push(bychild[chunk]);
264 }
265 Ok((root, children))
266}
267
#[cfg(test)]
mod tests {
    use super::TrieSetOwned;
    use crate::general_category;
    use std::collections::HashSet;

    /// Builds a trie set from the given scalar values, panicking if the set
    /// cannot be constructed.
    fn mk(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    /// Expands inclusive `(start, end)` ranges into the list of every
    /// codepoint they cover.
    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        let mut set = vec![];
        for &(start, end) in ranges {
            set.extend(start..=end);
        }
        set
    }

    #[test]
    fn set1() {
        let set = mk(&['a']);
        assert!(set.contains_char('a'));
        assert!(!set.contains_char('b'));
        assert!(!set.contains_char('β'));
        assert!(!set.contains_char('☃'));
        assert!(!set.contains_char('😼'));
    }

    #[test]
    fn set_combined() {
        let set = mk(&['a', 'b', 'β', '☃', '😼']);
        assert!(set.contains_char('a'));
        assert!(set.contains_char('b'));
        assert!(set.contains_char('β'));
        assert!(set.contains_char('☃'));
        assert!(set.contains_char('😼'));

        assert!(!set.contains_char('c'));
        assert!(!set.contains_char('θ'));
        assert!(!set.contains_char('⛇'));
        assert!(!set.contains_char('🐲'));
    }

    // Basic tests on all of the general category sets. We check that
    // membership is correct on every Unicode codepoint... because we can.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let set = ranges_to_set(general_category::$ranges);
                let hashset: HashSet<u32> = set.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&set).unwrap();
                for cp in 0..0x110000 {
                    // Report the offending codepoint on failure.
                    assert_eq!(
                        trie.contains_u32(cp),
                        hashset.contains(&cp),
                        "membership mismatch for codepoint 0x{:X}",
                        cp
                    );
                }
                // Test that an invalid codepoint is treated correctly.
                assert!(!trie.contains_u32(0x110000));
                assert!(!hashset.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}
372