1//===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// A bitvector that uses an IntervalMap to coalesce adjacent elements
11/// into intervals.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_ADT_COALESCINGBITVECTOR_H
16#define LLVM_ADT_COALESCINGBITVECTOR_H
17
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <initializer_list>
26
27namespace llvm {
28
29/// A bitvector that, under the hood, relies on an IntervalMap to coalesce
30/// elements into intervals. Good for representing sets which predominantly
31/// contain contiguous ranges. Bad for representing sets with lots of gaps
32/// between elements.
33///
34/// Compared to SparseBitVector, CoalescingBitVector offers more predictable
35/// performance for non-sequential find() operations.
36///
37/// \tparam IndexT - The type of the index into the bitvector.
38template <typename IndexT> class CoalescingBitVector {
39 static_assert(std::is_unsigned<IndexT>::value,
40 "Index must be an unsigned integer.");
41
42 using ThisT = CoalescingBitVector<IndexT>;
43
44 /// An interval map for closed integer ranges. The mapped values are unused.
45 using MapT = IntervalMap<IndexT, char>;
46
47 using UnderlyingIterator = typename MapT::const_iterator;
48
49 using IntervalT = std::pair<IndexT, IndexT>;
50
51public:
52 using Allocator = typename MapT::Allocator;
53
54 /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
55 /// reference.
56 CoalescingBitVector(Allocator &Alloc)
57 : Alloc(&Alloc), Intervals(Alloc) {}
58
59 /// \name Copy/move constructors and assignment operators.
60 /// @{
61
62 CoalescingBitVector(const ThisT &Other)
63 : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
64 set(Other);
65 }
66
67 ThisT &operator=(const ThisT &Other) {
68 clear();
69 set(Other);
70 return *this;
71 }
72
73 CoalescingBitVector(ThisT &&Other) = delete;
74 ThisT &operator=(ThisT &&Other) = delete;
75
76 /// @}
77
78 /// Clear all the bits.
79 void clear() { Intervals.clear(); }
80
81 /// Check whether no bits are set.
82 bool empty() const { return Intervals.empty(); }
83
84 /// Count the number of set bits.
85 unsigned count() const {
86 unsigned Bits = 0;
87 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
88 Bits += 1 + It.stop() - It.start();
89 return Bits;
90 }
91
92 /// Set the bit at \p Index.
93 ///
94 /// This method does /not/ support setting a bit that has already been set,
95 /// for efficiency reasons. If possible, restructure your code to not set the
96 /// same bit multiple times, or use \ref test_and_set.
97 void set(IndexT Index) {
98 assert(!test(Index) && "Setting already-set bits not supported/efficient, "
99 "IntervalMap will assert");
100 insert(Start: Index, End: Index);
101 }
102
103 /// Set the bits set in \p Other.
104 ///
105 /// This method does /not/ support setting already-set bits, see \ref set
106 /// for the rationale. For a safe set union operation, use \ref operator|=.
107 void set(const ThisT &Other) {
108 for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
109 It != End; ++It)
110 insert(Start: It.start(), End: It.stop());
111 }
112
113 /// Set the bits at \p Indices. Used for testing, primarily.
114 void set(std::initializer_list<IndexT> Indices) {
115 for (IndexT Index : Indices)
116 set(Index);
117 }
118
119 /// Check whether the bit at \p Index is set.
120 bool test(IndexT Index) const {
121 const auto It = Intervals.find(Index);
122 if (It == Intervals.end())
123 return false;
124 assert(It.stop() >= Index && "Interval must end after Index");
125 return It.start() <= Index;
126 }
127
128 /// Set the bit at \p Index. Supports setting an already-set bit.
129 void test_and_set(IndexT Index) {
130 if (!test(Index))
131 set(Index);
132 }
133
134 /// Reset the bit at \p Index. Supports resetting an already-unset bit.
135 void reset(IndexT Index) {
136 auto It = Intervals.find(Index);
137 if (It == Intervals.end())
138 return;
139
140 // Split the interval containing Index into up to two parts: one from
141 // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
142 // either Start or Stop, we create one new interval. If Index is equal to
143 // both Start and Stop, we simply erase the existing interval.
144 IndexT Start = It.start();
145 if (Index < Start)
146 // The index was not set.
147 return;
148 IndexT Stop = It.stop();
149 assert(Index <= Stop && "Wrong interval for index");
150 It.erase();
151 if (Start < Index)
152 insert(Start, End: Index - 1);
153 if (Index < Stop)
154 insert(Start: Index + 1, End: Stop);
155 }
156
157 /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
158 /// be a faster alternative.
159 void operator|=(const ThisT &RHS) {
160 // Get the overlaps between the two interval maps.
161 SmallVector<IntervalT, 8> Overlaps;
162 getOverlaps(Other: RHS, Overlaps);
163
164 // Insert the non-overlapping parts of all the intervals from RHS.
165 for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
166 It != End; ++It) {
167 IndexT Start = It.start();
168 IndexT Stop = It.stop();
169 SmallVector<IntervalT, 8> NonOverlappingParts;
170 getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
171 for (IntervalT AdditivePortion : NonOverlappingParts)
172 insert(Start: AdditivePortion.first, End: AdditivePortion.second);
173 }
174 }
175
176 /// Set intersection.
177 void operator&=(const ThisT &RHS) {
178 // Get the overlaps between the two interval maps (i.e. the intersection).
179 SmallVector<IntervalT, 8> Overlaps;
180 getOverlaps(Other: RHS, Overlaps);
181 // Rebuild the interval map, including only the overlaps.
182 clear();
183 for (IntervalT Overlap : Overlaps)
184 insert(Start: Overlap.first, End: Overlap.second);
185 }
186
187 /// Reset all bits present in \p Other.
188 void intersectWithComplement(const ThisT &Other) {
189 SmallVector<IntervalT, 8> Overlaps;
190 if (!getOverlaps(Other, Overlaps)) {
191 // If there is no overlap with Other, the intersection is empty.
192 return;
193 }
194
195 // Delete the overlapping intervals. Split up intervals that only partially
196 // intersect an overlap.
197 for (IntervalT Overlap : Overlaps) {
198 IndexT OlapStart, OlapStop;
199 std::tie(OlapStart, OlapStop) = Overlap;
200
201 auto It = Intervals.find(OlapStart);
202 IndexT CurrStart = It.start();
203 IndexT CurrStop = It.stop();
204 assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
205 "Expected some intersection!");
206
207 // Split the overlap interval into up to two parts: one from [CurrStart,
208 // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
209 // equal to CurrStart, the first split interval is unnecessary. Ditto for
210 // when OlapStop is equal to CurrStop, we omit the second split interval.
211 It.erase();
212 if (CurrStart < OlapStart)
213 insert(Start: CurrStart, End: OlapStart - 1);
214 if (OlapStop < CurrStop)
215 insert(Start: OlapStop + 1, End: CurrStop);
216 }
217 }
218
219 bool operator==(const ThisT &RHS) const {
220 // We cannot just use std::equal because it checks the dereferenced values
221 // of an iterator pair for equality, not the iterators themselves. In our
222 // case that results in comparison of the (unused) IntervalMap values.
223 auto ItL = Intervals.begin();
224 auto ItR = RHS.Intervals.begin();
225 while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
226 ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
227 ++ItL;
228 ++ItR;
229 }
230 return ItL == Intervals.end() && ItR == RHS.Intervals.end();
231 }
232
233 bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
234
235 class const_iterator {
236 friend class CoalescingBitVector;
237
238 public:
239 using iterator_category = std::forward_iterator_tag;
240 using value_type = IndexT;
241 using difference_type = std::ptrdiff_t;
242 using pointer = value_type *;
243 using reference = value_type &;
244
245 private:
246 // For performance reasons, make the offset at the end different than the
247 // one used in \ref begin, to optimize the common `It == end()` pattern.
248 static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
249
250 UnderlyingIterator MapIterator;
251 unsigned OffsetIntoMapIterator = 0;
252
253 // Querying the start/stop of an IntervalMap iterator can be very expensive.
254 // Cache these values for performance reasons.
255 IndexT CachedStart = IndexT();
256 IndexT CachedStop = IndexT();
257
258 void setToEnd() {
259 OffsetIntoMapIterator = kIteratorAtTheEndOffset;
260 CachedStart = IndexT();
261 CachedStop = IndexT();
262 }
263
264 /// MapIterator has just changed, reset the cached state to point to the
265 /// start of the new underlying iterator.
266 void resetCache() {
267 if (MapIterator.valid()) {
268 OffsetIntoMapIterator = 0;
269 CachedStart = MapIterator.start();
270 CachedStop = MapIterator.stop();
271 } else {
272 setToEnd();
273 }
274 }
275
276 /// Advance the iterator to \p Index, if it is contained within the current
277 /// interval. The public-facing method which supports advancing past the
278 /// current interval is \ref advanceToLowerBound.
279 void advanceTo(IndexT Index) {
280 assert(Index <= CachedStop && "Cannot advance to OOB index");
281 if (Index < CachedStart)
282 // We're already past this index.
283 return;
284 OffsetIntoMapIterator = Index - CachedStart;
285 }
286
287 const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
288 resetCache();
289 }
290
291 public:
292 const_iterator() { setToEnd(); }
293
294 bool operator==(const const_iterator &RHS) const {
295 // Do /not/ compare MapIterator for equality, as this is very expensive.
296 // The cached start/stop values make that check unnecessary.
297 return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
298 std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
299 RHS.CachedStop);
300 }
301
302 bool operator!=(const const_iterator &RHS) const {
303 return !operator==(RHS);
304 }
305
306 IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
307
308 const_iterator &operator++() { // Pre-increment (++It).
309 if (CachedStart + OffsetIntoMapIterator < CachedStop) {
310 // Keep going within the current interval.
311 ++OffsetIntoMapIterator;
312 } else {
313 // We reached the end of the current interval: advance.
314 ++MapIterator;
315 resetCache();
316 }
317 return *this;
318 }
319
320 const_iterator operator++(int) { // Post-increment (It++).
321 const_iterator tmp = *this;
322 operator++();
323 return tmp;
324 }
325
326 /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
327 /// no such set bit exists, advance to end(). This is like std::lower_bound.
328 /// This is useful if \p Index is close to the current iterator position.
329 /// However, unlike \ref find(), this has worst-case O(n) performance.
330 void advanceToLowerBound(IndexT Index) {
331 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
332 return;
333
334 // Advance to the first interval containing (or past) Index, or to end().
335 while (Index > CachedStop) {
336 ++MapIterator;
337 resetCache();
338 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
339 return;
340 }
341
342 advanceTo(Index);
343 }
344 };
345
346 const_iterator begin() const { return const_iterator(Intervals.begin()); }
347
348 const_iterator end() const { return const_iterator(); }
349
350 /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
351 /// If no such set bit exists, return end(). This is like std::lower_bound.
352 /// This has worst-case logarithmic performance (roughly O(log(gaps between
353 /// contiguous ranges))).
354 const_iterator find(IndexT Index) const {
355 auto UnderlyingIt = Intervals.find(Index);
356 if (UnderlyingIt == Intervals.end())
357 return end();
358 auto It = const_iterator(UnderlyingIt);
359 It.advanceTo(Index);
360 return It;
361 }
362
363 /// Return a range iterator which iterates over all of the set bits in the
364 /// half-open range [Start, End).
365 iterator_range<const_iterator> half_open_range(IndexT Start,
366 IndexT End) const {
367 assert(Start < End && "Not a valid range");
368 auto StartIt = find(Index: Start);
369 if (StartIt == end() || *StartIt >= End)
370 return {end(), end()};
371 auto EndIt = StartIt;
372 EndIt.advanceToLowerBound(End);
373 return {StartIt, EndIt};
374 }
375
376 void print(raw_ostream &OS) const {
377 OS << "{";
378 for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
379 ++It) {
380 OS << "[" << It.start();
381 if (It.start() != It.stop())
382 OS << ", " << It.stop();
383 OS << "]";
384 }
385 OS << "}";
386 }
387
388#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
389 LLVM_DUMP_METHOD void dump() const {
390 // LLDB swallows the first line of output after callling dump(). Add
391 // newlines before/after the braces to work around this.
392 dbgs() << "\n";
393 print(OS&: dbgs());
394 dbgs() << "\n";
395 }
396#endif
397
398private:
399 void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
400
401 /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
402 /// true if there is any overlap.
403 bool getOverlaps(const ThisT &Other,
404 SmallVectorImpl<IntervalT> &Overlaps) const {
405 for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
406 I.valid(); ++I)
407 Overlaps.emplace_back(I.start(), I.stop());
408 assert(llvm::is_sorted(Overlaps,
409 [](IntervalT LHS, IntervalT RHS) {
410 return LHS.second < RHS.first;
411 }) &&
412 "Overlaps must be sorted");
413 return !Overlaps.empty();
414 }
415
416 /// Given the set of overlaps between this and some other bitvector, and an
417 /// interval [Start, Stop] from that bitvector, determine the portions of the
418 /// interval which do not overlap with this.
419 void getNonOverlappingParts(IndexT Start, IndexT Stop,
420 const SmallVectorImpl<IntervalT> &Overlaps,
421 SmallVectorImpl<IntervalT> &NonOverlappingParts) {
422 IndexT NextUncoveredBit = Start;
423 for (IntervalT Overlap : Overlaps) {
424 IndexT OlapStart, OlapStop;
425 std::tie(OlapStart, OlapStop) = Overlap;
426
427 // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
428 // and Start <= OlapStop.
429 bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
430 if (!DoesOverlap)
431 continue;
432
433 // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
434 // the next uncovered range at OlapStop+1.
435 if (NextUncoveredBit < OlapStart)
436 NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
437 NextUncoveredBit = OlapStop + 1;
438 if (NextUncoveredBit > Stop)
439 break;
440 }
441 if (NextUncoveredBit <= Stop)
442 NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
443 }
444
445 Allocator *Alloc;
446 MapT Intervals;
447};
448
449} // namespace llvm
450
451#endif // LLVM_ADT_COALESCINGBITVECTOR_H
452

// source code of llvm/include/llvm/ADT/CoalescingBitVector.h