1 | use crate::util::int::Usize; |
---|---|

2 | |

3 | /// A representation of byte oriented equivalence classes. |

4 | /// |

5 | /// This is used in finite state machines to reduce the size of the transition |

6 | /// table. This can have a particularly large impact not only on the total size |

7 | /// of an FSM, but also on FSM build times because it reduces the number of |

8 | /// transitions that need to be visited/set. |

9 | #[derive(Clone, Copy)] |

10 | pub(crate) struct ByteClasses([u8; 256]); |

11 | |

12 | impl ByteClasses { |

13 | /// Creates a new set of equivalence classes where all bytes are mapped to |

14 | /// the same class. |

15 | pub(crate) fn empty() -> ByteClasses { |

16 | ByteClasses([0; 256]) |

17 | } |

18 | |

19 | /// Creates a new set of equivalence classes where each byte belongs to |

20 | /// its own equivalence class. |

21 | pub(crate) fn singletons() -> ByteClasses { |

22 | let mut classes = ByteClasses::empty(); |

23 | for b in 0..=255 { |

24 | classes.set(b, b); |

25 | } |

26 | classes |

27 | } |

28 | |

29 | /// Set the equivalence class for the given byte. |

30 | #[inline] |

31 | pub(crate) fn set(&mut self, byte: u8, class: u8) { |

32 | self.0[usize::from(byte)] = class; |

33 | } |

34 | |

35 | /// Get the equivalence class for the given byte. |

36 | #[inline] |

37 | pub(crate) fn get(&self, byte: u8) -> u8 { |

38 | self.0[usize::from(byte)] |

39 | } |

40 | |

41 | /// Return the total number of elements in the alphabet represented by |

42 | /// these equivalence classes. Equivalently, this returns the total number |

43 | /// of equivalence classes. |

44 | #[inline] |

45 | pub(crate) fn alphabet_len(&self) -> usize { |

46 | // Add one since the number of equivalence classes is one bigger than |

47 | // the last one. |

48 | usize::from(self.0[255]) + 1 |

49 | } |

50 | |

51 | /// Returns the stride, as a base-2 exponent, required for these |

52 | /// equivalence classes. |

53 | /// |

54 | /// The stride is always the smallest power of 2 that is greater than or |

55 | /// equal to the alphabet length. This is done so that converting between |

56 | /// state IDs and indices can be done with shifts alone, which is much |

57 | /// faster than integer division. The "stride2" is the exponent. i.e., |

58 | /// `2^stride2 = stride`. |

59 | pub(crate) fn stride2(&self) -> usize { |

60 | let zeros = self.alphabet_len().next_power_of_two().trailing_zeros(); |

61 | usize::try_from(zeros).unwrap() |

62 | } |

63 | |

64 | /// Returns the stride for these equivalence classes, which corresponds |

65 | /// to the smallest power of 2 greater than or equal to the number of |

66 | /// equivalence classes. |

67 | pub(crate) fn stride(&self) -> usize { |

68 | 1 << self.stride2() |

69 | } |

70 | |

71 | /// Returns true if and only if every byte in this class maps to its own |

72 | /// equivalence class. Equivalently, there are 257 equivalence classes |

73 | /// and each class contains exactly one byte (plus the special EOI class). |

74 | #[inline] |

75 | pub(crate) fn is_singleton(&self) -> bool { |

76 | self.alphabet_len() == 256 |

77 | } |

78 | |

79 | /// Returns an iterator over all equivalence classes in this set. |

80 | pub(crate) fn iter(&self) -> ByteClassIter { |

81 | ByteClassIter { it: 0..self.alphabet_len() } |

82 | } |

83 | |

84 | /// Returns an iterator of the bytes in the given equivalence class. |

85 | pub(crate) fn elements(&self, class: u8) -> ByteClassElements { |

86 | ByteClassElements { classes: self, class, bytes: 0..=255 } |

87 | } |

88 | |

89 | /// Returns an iterator of byte ranges in the given equivalence class. |

90 | /// |

91 | /// That is, a sequence of contiguous ranges are returned. Typically, every |

92 | /// class maps to a single contiguous range. |

93 | fn element_ranges(&self, class: u8) -> ByteClassElementRanges { |

94 | ByteClassElementRanges { elements: self.elements(class), range: None } |

95 | } |

96 | } |

97 | |

98 | impl core::fmt::Debug for ByteClasses { |

99 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |

100 | if self.is_singleton() { |

101 | write!(f, "ByteClasses(<one-class-per-byte>)") |

102 | } else { |

103 | write!(f, "ByteClasses(")?; |

104 | for (i, class) in self.iter().enumerate() { |

105 | if i > 0 { |

106 | write!(f, ", ")?; |

107 | } |

108 | write!(f, "{:?} => [", class)?; |

109 | for (start, end) in self.element_ranges(class) { |

110 | if start == end { |

111 | write!(f, "{:?}", start)?; |

112 | } else { |

113 | write!(f, "{:?}-{:?}", start, end)?; |

114 | } |

115 | } |

116 | write!(f, "]")?; |

117 | } |

118 | write!(f, ")") |

119 | } |

120 | } |

121 | } |

122 | |

123 | /// An iterator over each equivalence class. |

124 | #[derive(Debug)] |

125 | pub(crate) struct ByteClassIter { |

126 | it: core::ops::Range<usize>, |

127 | } |

128 | |

129 | impl Iterator for ByteClassIter { |

130 | type Item = u8; |

131 | |

132 | fn next(&mut self) -> Option<u8> { |

133 | self.it.next().map(|class| class.as_u8()) |

134 | } |

135 | } |

136 | |

137 | /// An iterator over all elements in a specific equivalence class. |

138 | #[derive(Debug)] |

139 | pub(crate) struct ByteClassElements<'a> { |

140 | classes: &'a ByteClasses, |

141 | class: u8, |

142 | bytes: core::ops::RangeInclusive<u8>, |

143 | } |

144 | |

145 | impl<'a> Iterator for ByteClassElements<'a> { |

146 | type Item = u8; |

147 | |

148 | fn next(&mut self) -> Option<u8> { |

149 | while let Some(byte) = self.bytes.next() { |

150 | if self.class == self.classes.get(byte) { |

151 | return Some(byte); |

152 | } |

153 | } |

154 | None |

155 | } |

156 | } |

157 | |

158 | /// An iterator over all elements in an equivalence class expressed as a |

159 | /// sequence of contiguous ranges. |

160 | #[derive(Debug)] |

161 | pub(crate) struct ByteClassElementRanges<'a> { |

162 | elements: ByteClassElements<'a>, |

163 | range: Option<(u8, u8)>, |

164 | } |

165 | |

166 | impl<'a> Iterator for ByteClassElementRanges<'a> { |

167 | type Item = (u8, u8); |

168 | |

169 | fn next(&mut self) -> Option<(u8, u8)> { |

170 | loop { |

171 | let element = match self.elements.next() { |

172 | None => return self.range.take(), |

173 | Some(element) => element, |

174 | }; |

175 | match self.range.take() { |

176 | None => { |

177 | self.range = Some((element, element)); |

178 | } |

179 | Some((start, end)) => { |

180 | if usize::from(end) + 1 != usize::from(element) { |

181 | self.range = Some((element, element)); |

182 | return Some((start, end)); |

183 | } |

184 | self.range = Some((start, element)); |

185 | } |

186 | } |

187 | } |

188 | } |

189 | } |

190 | |

191 | /// A partitioning of bytes into equivalence classes. |

192 | /// |

193 | /// A byte class set keeps track of an *approximation* of equivalence classes |

194 | /// of bytes during NFA construction. That is, every byte in an equivalence |

195 | /// class cannot discriminate between a match and a non-match. |

196 | /// |

197 | /// Note that this may not compute the minimal set of equivalence classes. |

198 | /// Basically, any byte in a pattern given to the noncontiguous NFA builder |

199 | /// will automatically be treated as its own equivalence class. All other |

200 | /// bytes---any byte not in any pattern---will be treated as their own |

201 | /// equivalence classes. In theory, all bytes not in any pattern should |

202 | /// be part of a single equivalence class, but in practice, we only treat |

203 | /// contiguous ranges of bytes as an equivalence class. So the number of |

204 | /// classes computed may be bigger than necessary. This usually doesn't make |

205 | /// much of a difference, and keeps the implementation simple. |

206 | #[derive(Clone, Debug)] |

207 | pub(crate) struct ByteClassSet(ByteSet); |

208 | |

209 | impl Default for ByteClassSet { |

210 | fn default() -> ByteClassSet { |

211 | ByteClassSet::empty() |

212 | } |

213 | } |

214 | |

215 | impl ByteClassSet { |

216 | /// Create a new set of byte classes where all bytes are part of the same |

217 | /// equivalence class. |

218 | pub(crate) fn empty() -> Self { |

219 | ByteClassSet(ByteSet::empty()) |

220 | } |

221 | |

222 | /// Indicate the the range of byte given (inclusive) can discriminate a |

223 | /// match between it and all other bytes outside of the range. |

224 | pub(crate) fn set_range(&mut self, start: u8, end: u8) { |

225 | debug_assert!(start <= end); |

226 | if start > 0 { |

227 | self.0.add(start - 1); |

228 | } |

229 | self.0.add(end); |

230 | } |

231 | |

232 | /// Convert this boolean set to a map that maps all byte values to their |

233 | /// corresponding equivalence class. The last mapping indicates the largest |

234 | /// equivalence class identifier (which is never bigger than 255). |

235 | pub(crate) fn byte_classes(&self) -> ByteClasses { |

236 | let mut classes = ByteClasses::empty(); |

237 | let mut class = 0u8; |

238 | let mut b = 0u8; |

239 | loop { |

240 | classes.set(b, class); |

241 | if b == 255 { |

242 | break; |

243 | } |

244 | if self.0.contains(b) { |

245 | class = class.checked_add(1).unwrap(); |

246 | } |

247 | b = b.checked_add(1).unwrap(); |

248 | } |

249 | classes |

250 | } |

251 | } |

252 | |

/// A simple set of bytes that is reasonably cheap to copy and allocation free.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct ByteSet {
    /// One membership bit per possible byte value.
    bits: BitSet,
}

/// The representation of a byte set. Split out so that we can define a
/// convenient Debug impl for it while keeping "ByteSet" in the output.
///
/// Bytes `0..=127` live in the first word and bytes `128..=255` in the second.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
struct BitSet([u128; 2]);

impl ByteSet {
    /// Create an empty set of bytes.
    pub(crate) fn empty() -> ByteSet {
        ByteSet { bits: BitSet([0u128; 2]) }
    }

    /// Add a byte to this set.
    ///
    /// If the given byte already belongs to this set, then this is a no-op.
    pub(crate) fn add(&mut self, byte: u8) {
        let (word, mask) = ByteSet::locate(byte);
        self.bits.0[word] |= mask;
    }

    /// Return true if and only if the given byte is in this set.
    pub(crate) fn contains(&self, byte: u8) -> bool {
        let (word, mask) = ByteSet::locate(byte);
        self.bits.0[word] & mask != 0
    }

    /// Maps a byte to the index of the word holding its bit and a mask that
    /// selects that bit within the word.
    fn locate(byte: u8) -> (usize, u128) {
        // The top bit picks the word; the low seven bits pick the position.
        (usize::from(byte >> 7), 1u128 << (byte & 0x7F))
    }
}

/// Formats the set as the ascending list of bytes it contains.
impl core::fmt::Debug for BitSet {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // Wrap the raw bits so membership goes through `ByteSet::contains`.
        let set = ByteSet { bits: *self };
        f.debug_set()
            .entries((0u8..=255).filter(|&b| set.contains(b)))
            .finish()
    }
}

298 | |

/// Unit tests for the byte class set, the resulting byte classes, and the
/// per-class element iterator.
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::*;

    /// Ranges recorded in a `ByteClassSet` split the byte space into
    /// contiguous classes numbered from zero.
    #[test]
    fn byte_classes() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');

        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(b'a' - 1), 0);
        assert_eq!(classes.get(b'a'), 1);
        assert_eq!(classes.get(b'm'), 1);
        assert_eq!(classes.get(b'z'), 1);
        assert_eq!(classes.get(b'z' + 1), 2);
        assert_eq!(classes.get(254), 2);
        assert_eq!(classes.get(255), 2);

        let mut set = ByteClassSet::empty();
        set.set_range(0, 2);
        set.set_range(4, 6);
        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(3), 1);
        assert_eq!(classes.get(4), 2);
        assert_eq!(classes.get(5), 2);
        assert_eq!(classes.get(6), 2);
        assert_eq!(classes.get(7), 3);
        assert_eq!(classes.get(255), 3);
    }

    /// Marking every byte as its own range yields the full 256-class
    /// alphabet.
    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        for b in 0u8..=255 {
            set.set_range(b, b);
        }
        assert_eq!(set.byte_classes().alphabet_len(), 256);
    }

    /// `elements` lists exactly the bytes of the requested class.
    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        // class 0: \x00-a
        // class 1: b-d
        // class 2: e-f
        // class 3: g-m
        // class 4: n-y
        // class 5: z-z
        // class 6: \x7B-\xFF
        assert_eq!(classes.alphabet_len(), 7);

        let elements = classes.elements(0).collect::<Vec<_>>();
        assert_eq!(elements.len(), 98);
        assert_eq!(elements[0], b'\x00');
        assert_eq!(elements[97], b'a');

        let elements = classes.elements(1).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'b', b'c', b'd'],);

        let elements = classes.elements(2).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'e', b'f'],);

        let elements = classes.elements(3).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'g', b'h', b'i', b'j', b'k', b'l', b'm',],);

        let elements = classes.elements(4).collect::<Vec<_>>();
        assert_eq!(elements.len(), 12);
        assert_eq!(elements[0], b'n');
        assert_eq!(elements[11], b'y');

        let elements = classes.elements(5).collect::<Vec<_>>();
        assert_eq!(elements, vec![b'z']);

        let elements = classes.elements(6).collect::<Vec<_>>();
        assert_eq!(elements.len(), 133);
        assert_eq!(elements[0], b'\x7B');
        assert_eq!(elements[132], b'\xFF');
    }

    /// With singleton classes, each class contains exactly its own byte.
    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 256);

        let elements = classes.elements(b'a').collect::<Vec<_>>();
        assert_eq!(elements, vec![b'a']);
    }

    /// With the empty mapping, class 0 contains every byte.
    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 1);

        let elements = classes.elements(0).collect::<Vec<_>>();
        assert_eq!(elements.len(), 256);
        assert_eq!(elements[0], b'\x00');
        assert_eq!(elements[255], b'\xFF');
    }
}

410 |