Skip to main content

splinter_rs/
splinter.rs

1use std::{fmt::Debug, ops::RangeBounds};
2
3use bytes::Bytes;
4
5use crate::{
6    Encodable, Optimizable, SplinterRef,
7    codec::{encoder::Encoder, footer::Footer},
8    level::High,
9    partition::Partition,
10    traits::{PartitionRead, PartitionWrite},
11    util::RangeExt,
12};
13
14/// A compressed bitmap optimized for small, sparse sets of 32-bit unsigned integers.
15///
16/// `Splinter` is the main owned data structure that can be built incrementally by inserting
17/// values and then optimized for size and query performance. It uses a 256-way tree structure
18/// by decomposing integers into big-endian component bytes, with nodes optimized into four
19/// different storage classes: tree, vec, bitmap, and run.
20///
21/// For zero-copy querying of serialized data, see [`SplinterRef`].
22/// For a clone-on-write wrapper, see [`crate::CowSplinter`].
23///
24/// # Examples
25///
26/// Basic usage:
27///
28/// ```
29/// use splinter_rs::{Splinter, PartitionWrite, PartitionRead, Optimizable};
30///
31/// let mut splinter = Splinter::from_iter([1024, 2048, 123]);
32///
33/// // Check membership
34/// assert!(splinter.contains(1024));
35/// assert!(!splinter.contains(999));
36///
37/// // Get cardinality
38/// assert_eq!(splinter.cardinality(), 3);
39///
40/// // Optimize for better compression, recommended before encoding to bytes.
41/// splinter.optimize();
42/// ```
43///
44/// Building from iterator:
45///
46/// ```
47/// use splinter_rs::{Splinter, PartitionRead};
48///
49/// let values = vec![100, 200, 300, 400];
50/// let splinter: Splinter = values.into_iter().collect();
51///
52/// assert_eq!(splinter.cardinality(), 4);
53/// assert!(splinter.contains(200));
54/// ```
55#[derive(Clone, PartialEq, Eq, Default, Debug)]
56pub struct Splinter(Partition<High>);
57
58static_assertions::const_assert_eq!(std::mem::size_of::<Splinter>(), 40);
59
60impl Splinter {
61    /// An empty Splinter, suitable for usage in a const context.
62    pub const EMPTY: Self = Splinter(Partition::EMPTY);
63
64    /// A full Splinter, suitable for usage in a const context.
65    pub const FULL: Self = Splinter(Partition::Full);
66
67    /// Encodes this splinter into a [`SplinterRef`] for zero-copy querying.
68    ///
69    /// This method serializes the splinter data and returns a [`SplinterRef<Bytes>`]
70    /// that can be used for efficient read-only operations without deserializing
71    /// the underlying data structure.
72    ///
73    /// # Examples
74    ///
75    /// ```
76    /// use splinter_rs::{Splinter, PartitionWrite, PartitionRead};
77    ///
78    /// let mut splinter = Splinter::from_iter([42, 1337]);
79    ///
80    /// let splinter_ref = splinter.encode_to_splinter_ref();
81    /// assert_eq!(splinter_ref.cardinality(), 2);
82    /// assert!(splinter_ref.contains(42));
83    /// ```
84    pub fn encode_to_splinter_ref(&self) -> SplinterRef<Bytes> {
85        SplinterRef { data: self.encode_to_bytes() }
86    }
87
88    #[inline(always)]
89    pub(crate) fn new(inner: Partition<High>) -> Self {
90        Self(inner)
91    }
92
93    #[inline(always)]
94    pub(crate) fn inner(&self) -> &Partition<High> {
95        &self.0
96    }
97
98    #[inline(always)]
99    pub(crate) fn inner_mut(&mut self) -> &mut Partition<High> {
100        &mut self.0
101    }
102}
103
104impl FromIterator<u32> for Splinter {
105    fn from_iter<I: IntoIterator<Item = u32>>(iter: I) -> Self {
106        Self(Partition::<High>::from_iter(iter))
107    }
108}
109
110impl<R: RangeBounds<u32>> From<R> for Splinter {
111    fn from(range: R) -> Self {
112        if let Some(range) = range.try_into_inclusive() {
113            if range.start() == &u32::MIN && range.end() == &u32::MAX {
114                Self::FULL
115            } else {
116                Self(Partition::<High>::from(range))
117            }
118        } else {
119            // range is empty
120            Self::EMPTY
121        }
122    }
123}
124
125impl PartitionRead<High> for Splinter {
126    /// Returns the total number of elements in this splinter.
127    ///
128    /// # Examples
129    ///
130    /// ```
131    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
132    ///
133    /// let mut splinter = Splinter::EMPTY;
134    /// assert_eq!(splinter.cardinality(), 0);
135    ///
136    /// let splinter = Splinter::from_iter([100, 200, 300]);
137    /// assert_eq!(splinter.cardinality(), 3);
138    /// ```
139    #[inline]
140    fn cardinality(&self) -> usize {
141        self.0.cardinality()
142    }
143
144    /// Returns `true` if this splinter contains no elements.
145    ///
146    /// # Examples
147    ///
148    /// ```
149    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
150    ///
151    /// let mut splinter = Splinter::EMPTY;
152    /// assert!(splinter.is_empty());
153    ///
154    /// let splinter = Splinter::from_iter([42]);
155    /// assert!(!splinter.is_empty());
156    /// ```
157    #[inline]
158    fn is_empty(&self) -> bool {
159        self.0.is_empty()
160    }
161
162    /// Returns `true` if this splinter contains the specified value.
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
168    ///
169    /// let splinter = Splinter::from_iter([42, 1337]);
170    ///
171    /// assert!(splinter.contains(42));
172    /// assert!(splinter.contains(1337));
173    /// assert!(!splinter.contains(999));
174    /// ```
175    #[inline]
176    fn contains(&self, value: u32) -> bool {
177        self.0.contains(value)
178    }
179
180    /// Returns the 0-based position of the value in this splinter if it exists.
181    ///
182    /// This method searches for the given value in the splinter and returns its position
183    /// in the sorted sequence of all elements. If the value doesn't exist, returns `None`.
184    ///
185    /// # Examples
186    ///
187    /// ```
188    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
189    ///
190    /// let splinter = Splinter::from_iter([10, 20, 30]);
191    ///
192    /// assert_eq!(splinter.position(10), Some(0));
193    /// assert_eq!(splinter.position(20), Some(1));
194    /// assert_eq!(splinter.position(30), Some(2));
195    /// assert_eq!(splinter.position(25), None); // doesn't exist
196    /// ```
197    #[inline]
198    fn position(&self, value: u32) -> Option<usize> {
199        self.0.position(value)
200    }
201
202    /// Returns the number of elements in this splinter that are less than or equal to the given value.
203    ///
204    /// This is also known as the "rank" of the value in the sorted sequence of all elements.
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
210    ///
211    /// let splinter = Splinter::from_iter([10, 20, 30]);
212    ///
213    /// assert_eq!(splinter.rank(5), 0);   // No elements <= 5
214    /// assert_eq!(splinter.rank(10), 1);  // One element <= 10
215    /// assert_eq!(splinter.rank(25), 2);  // Two elements <= 25
216    /// assert_eq!(splinter.rank(30), 3);  // Three elements <= 30
217    /// assert_eq!(splinter.rank(50), 3);  // Three elements <= 50
218    /// ```
219    #[inline]
220    fn rank(&self, value: u32) -> usize {
221        self.0.rank(value)
222    }
223
224    /// Returns the element at the given index in the sorted sequence, or `None` if the index is out of bounds.
225    ///
226    /// The index is 0-based, so `select(0)` returns the smallest element.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
232    ///
233    /// let splinter = Splinter::from_iter([100, 50, 200]);
234    ///
235    /// assert_eq!(splinter.select(0), Some(50));   // Smallest element
236    /// assert_eq!(splinter.select(1), Some(100));  // Second smallest
237    /// assert_eq!(splinter.select(2), Some(200));  // Largest element
238    /// assert_eq!(splinter.select(3), None);       // Out of bounds
239    /// ```
240    #[inline]
241    fn select(&self, idx: usize) -> Option<u32> {
242        self.0.select(idx)
243    }
244
245    /// Returns the largest element in this splinter, or `None` if it's empty.
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
251    ///
252    /// let mut splinter = Splinter::EMPTY;
253    /// assert_eq!(splinter.last(), None);
254    ///
255    /// let splinter = Splinter::from_iter([100, 50, 200]);
256    ///
257    /// assert_eq!(splinter.last(), Some(200));
258    /// ```
259    #[inline]
260    fn last(&self) -> Option<u32> {
261        self.0.last()
262    }
263
264    /// Returns an iterator over all elements in ascending order.
265    ///
266    /// # Examples
267    ///
268    /// ```
269    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
270    ///
271    /// let splinter = Splinter::from_iter([300, 100, 200]);
272    ///
273    /// let values: Vec<u32> = splinter.iter().collect();
274    /// assert_eq!(values, vec![100, 200, 300]);
275    /// ```
276    #[inline]
277    fn iter(&self) -> impl Iterator<Item = u32> {
278        self.0.iter()
279    }
280
281    /// Returns `true` if this splinter contains all values in the specified range.
282    ///
283    /// This method checks whether every value within the given range bounds is present
284    /// in the splinter. An empty range is trivially contained and returns `true`.
285    ///
286    /// # Examples
287    ///
288    /// ```
289    /// use splinter_rs::{Splinter, PartitionRead};
290    ///
291    /// let splinter = Splinter::from_iter([10, 11, 12, 13, 14, 15, 100]);
292    ///
293    /// // Check if range is fully contained
294    /// assert!(splinter.contains_all(10..=15));
295    /// assert!(splinter.contains_all(11..=14));
296    ///
297    /// // Missing values mean the range is not fully contained
298    /// assert!(!splinter.contains_all(10..=16));  // 16 is missing
299    /// assert!(!splinter.contains_all(9..=15));   // 9 is missing
300    ///
301    /// // Empty ranges are trivially contained
302    /// assert!(splinter.contains_all(50..50));
303    /// ```
304    #[inline]
305    fn contains_all<R: RangeBounds<u32>>(&self, values: R) -> bool {
306        self.0.contains_all(values)
307    }
308
309    /// Returns `true` if this splinter has a non-empty intersection with the specified range.
310    ///
311    /// This method checks whether any value within the given range is present
312    /// in the splinter. Returns `false` for empty ranges.
313    ///
314    /// # Examples
315    ///
316    /// ```
317    /// use splinter_rs::{Splinter, PartitionRead};
318    ///
319    /// let splinter = Splinter::from_iter([10, 20, 30]);
320    ///
321    /// // Check for any overlap
322    /// assert!(splinter.contains_any(10..=15));   // Contains 10
323    /// assert!(splinter.contains_any(5..=10));    // Contains 10
324    /// assert!(splinter.contains_any(25..=35));   // Contains 30
325    ///
326    /// // No overlap
327    /// assert!(!splinter.contains_any(0..=9));    // No values in range
328    /// assert!(!splinter.contains_any(40..=50));  // No values in range
329    ///
330    /// // Empty ranges have no intersection
331    /// assert!(!splinter.contains_any(50..50));
332    /// ```
333    #[inline]
334    fn contains_any<R: RangeBounds<u32>>(&self, values: R) -> bool {
335        self.0.contains_any(values)
336    }
337}
338
339impl PartitionWrite<High> for Splinter {
340    /// Inserts a value into this splinter.
341    ///
342    /// Returns `true` if the value was newly inserted, or `false` if it was already present.
343    ///
344    /// # Examples
345    ///
346    /// ```
347    /// use splinter_rs::{Splinter, PartitionWrite, PartitionRead};
348    ///
349    /// let mut splinter = Splinter::EMPTY;
350    ///
351    /// // First insertion returns true
352    /// assert!(splinter.insert(42));
353    /// assert_eq!(splinter.cardinality(), 1);
354    ///
355    /// // Second insertion of same value returns false
356    /// assert!(!splinter.insert(42));
357    /// assert_eq!(splinter.cardinality(), 1);
358    ///
359    /// // Different value returns true
360    /// assert!(splinter.insert(100));
361    /// assert_eq!(splinter.cardinality(), 2);
362    /// ```
363    #[inline]
364    fn insert(&mut self, value: u32) -> bool {
365        self.0.insert(value)
366    }
367
368    /// Removes a value from this splinter.
369    ///
370    /// Returns `true` if the value was present and removed, or `false` if it was not present.
371    ///
372    /// # Examples
373    ///
374    /// ```
375    /// use splinter_rs::{Splinter, PartitionWrite, PartitionRead};
376    ///
377    /// let mut splinter = Splinter::from_iter([42, 100]);
378    /// assert_eq!(splinter.cardinality(), 2);
379    ///
380    /// // Remove existing value
381    /// assert!(splinter.remove(42));
382    /// assert_eq!(splinter.cardinality(), 1);
383    /// assert!(!splinter.contains(42));
384    /// assert!(splinter.contains(100));
385    ///
386    /// // Remove non-existent value
387    /// assert!(!splinter.remove(999));
388    /// assert_eq!(splinter.cardinality(), 1);
389    /// ```
390    #[inline]
391    fn remove(&mut self, value: u32) -> bool {
392        self.0.remove(value)
393    }
394
395    /// Removes a range of values from this splinter.
396    ///
397    /// This method removes all values that fall within the specified range bounds.
398    /// The range can be inclusive, exclusive, or half-bounded using standard Rust range syntax.
399    ///
400    /// # Examples
401    ///
402    /// ```
403    /// use splinter_rs::{Splinter, PartitionRead, PartitionWrite};
404    ///
405    /// let mut splinter = Splinter::from_iter(1..=10);
406    ///
407    /// // Remove values 3 through 7 (inclusive)
408    /// splinter.remove_range(3..=7);
409    /// assert!(!splinter.contains(5));
410    /// assert!(splinter.contains(2));
411    /// assert!(splinter.contains(8));
412    ///
413    /// // Remove from 9 onwards
414    /// splinter.remove_range(9..);
415    /// assert!(!splinter.contains(9));
416    /// assert!(!splinter.contains(10));
417    /// assert!(splinter.contains(8));
418    /// ```
419    #[inline]
420    fn remove_range<R: RangeBounds<u32>>(&mut self, values: R) {
421        self.0.remove_range(values);
422    }
423}
424
425impl Encodable for Splinter {
426    fn encoded_size(&self) -> usize {
427        self.0.encoded_size() + std::mem::size_of::<Footer>()
428    }
429
430    fn encode<B: bytes::BufMut>(&self, encoder: &mut Encoder<B>) {
431        self.0.encode(encoder);
432        encoder.write_footer();
433    }
434}
435
436impl Optimizable for Splinter {
437    #[inline]
438    fn optimize(&mut self) {
439        self.0.optimize();
440    }
441}
442
443impl Extend<u32> for Splinter {
444    #[inline]
445    fn extend<T: IntoIterator<Item = u32>>(&mut self, iter: T) {
446        self.0.extend(iter);
447    }
448}
449
450#[cfg(test)]
451mod tests {
452    use std::ops::Bound;
453
454    use super::*;
455    use crate::{
456        codec::Encodable,
457        level::{Level, Low},
458        testutil::{SetGen, mksplinter, ratio_to_marks, test_partition_read, test_partition_write},
459        traits::Optimizable,
460    };
461    use itertools::{Itertools, assert_equal};
462    use proptest::{
463        collection::{hash_set, vec},
464        proptest,
465    };
466    use rand::{SeedableRng, seq::index};
467    use roaring::RoaringBitmap;
468
469    #[test]
470    fn test_sanity() {
471        let mut splinter = Splinter::EMPTY;
472
473        assert!(splinter.insert(1));
474        assert!(!splinter.insert(1));
475        assert!(splinter.contains(1));
476
477        let values = [1024, 123, 16384];
478        for v in values {
479            assert!(splinter.insert(v));
480            assert!(splinter.contains(v));
481            assert!(!splinter.contains(v + 1));
482        }
483
484        for i in 0..8192 + 10 {
485            splinter.insert(i);
486        }
487
488        splinter.optimize();
489
490        dbg!(&splinter);
491
492        let expected = splinter.iter().collect_vec();
493        test_partition_read(&splinter, &expected);
494        test_partition_write(&mut splinter);
495    }
496
497    #[test]
498    fn test_wat() {
499        let mut set_gen = SetGen::new(0xDEAD_BEEF);
500        let set = set_gen.random_max(64, 4096);
501        let baseline_size = set.len() * 4;
502
503        let mut splinter = Splinter::from_iter(set.iter().copied());
504        splinter.optimize();
505
506        dbg!(&splinter, splinter.encoded_size(), baseline_size, set.len());
507        itertools::assert_equal(splinter.iter(), set);
508    }
509
510    #[test]
511    fn test_splinter_write() {
512        let mut splinter = Splinter::from_iter(0u32..16384);
513        test_partition_write(&mut splinter);
514    }
515
516    #[test]
517    fn test_splinter_optimize_growth() {
518        let mut splinter = Splinter::EMPTY;
519        let mut rng = rand::rngs::StdRng::seed_from_u64(0xdeadbeef);
520        let set = index::sample(&mut rng, Low::MAX_LEN, 8);
521        dbg!(&splinter);
522        for i in set {
523            splinter.insert(i as u32);
524            dbg!(&splinter);
525        }
526    }
527
528    #[test]
529    fn test_splinter_from_range() {
530        let splinter = Splinter::from(..);
531        assert_eq!(splinter.cardinality(), (u32::MAX as usize) + 1);
532
533        let mut splinter = Splinter::from(1..);
534        assert_eq!(splinter.cardinality(), u32::MAX as usize);
535
536        splinter.remove(1024);
537        assert_eq!(splinter.cardinality(), (u32::MAX as usize) - 1);
538
539        let mut count = 1;
540        for i in (2048..=256000).step_by(1024) {
541            splinter.remove(i);
542            count += 1
543        }
544        assert_eq!(splinter.cardinality(), (u32::MAX as usize) - count);
545    }
546
547    proptest! {
548        #[test]
549        fn test_splinter_read_proptest(set in hash_set(0u32..16384, 0..1024)) {
550            let expected = set.iter().copied().sorted().collect_vec();
551            test_partition_read(&Splinter::from_iter(set), &expected);
552        }
553
554
555        #[test]
556        fn test_splinter_proptest(set in vec(0u32..16384, 0..1024)) {
557            let splinter = mksplinter(&set);
558            if set.is_empty() {
559                assert!(!splinter.contains(123));
560            } else {
561                let lookup = set[set.len() / 3];
562                assert!(splinter.contains(lookup));
563            }
564        }
565
566        #[test]
567        fn test_splinter_opt_proptest(set in vec(0u32..16384, 0..1024))  {
568            let mut splinter = mksplinter(&set);
569            splinter.optimize();
570            if set.is_empty() {
571                assert!(!splinter.contains(123));
572            } else {
573                let lookup = set[set.len() / 3];
574                assert!(splinter.contains(lookup));
575            }
576        }
577
578        #[test]
579        fn test_splinter_eq_proptest(set in vec(0u32..16384, 0..1024)) {
580            let a = mksplinter(&set);
581            assert_eq!(a, a.clone());
582        }
583
584        #[test]
585        fn test_splinter_opt_eq_proptest(set in vec(0u32..16384, 0..1024)) {
586            let mut a = mksplinter(&set);
587            let b = mksplinter(&set);
588            a.optimize();
589            assert_eq!(a, b);
590        }
591
592        #[test]
593        fn test_splinter_remove_range_proptest(set in hash_set(0u32..16384, 0..1024)) {
594            let expected = set.iter().copied().sorted().collect_vec();
595            let mut splinter = mksplinter(&expected);
596            if let Some(last) = expected.last() {
597                splinter.remove_range((Bound::Excluded(last), Bound::Unbounded));
598                assert_equal(splinter.iter(), expected);
599            }
600        }
601    }
602
603    // -- Hegel property-based tests --
604
605    use hegel::generators;
606
607    /// Iter always produces sorted, deduplicated output.
608    #[hegel::test]
609    fn test_iter_sorted_and_deduped(tc: hegel::TestCase) {
610        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
611        let splinter = Splinter::from_iter(values);
612        let items: Vec<u32> = splinter.iter().collect();
613        for window in items.windows(2) {
614            assert!(
615                window[0] < window[1],
616                "iter not strictly sorted: {window:?}"
617            );
618        }
619    }
620
621    /// Cardinality equals the number of items yielded by iter.
622    #[hegel::test]
623    fn test_cardinality_equals_iter_count(tc: hegel::TestCase) {
624        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
625        let splinter = Splinter::from_iter(values);
626        assert_eq!(splinter.cardinality(), splinter.iter().count());
627    }
628
629    /// Every inserted value is contained; every iterated value is contained.
630    #[hegel::test]
631    fn test_contains_all_inserted_values(tc: hegel::TestCase) {
632        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
633        let splinter = Splinter::from_iter(values.iter().copied());
634        for &v in &values {
635            assert!(splinter.contains(v), "missing value {v}");
636        }
637    }
638
639    /// Insert returns true for new values, false for duplicates.
640    #[hegel::test]
641    fn test_insert_returns_correct_bool(tc: hegel::TestCase) {
642        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
643        let mut splinter = Splinter::EMPTY;
644        let mut seen = std::collections::HashSet::new();
645        for v in values {
646            let was_new = seen.insert(v);
647            assert_eq!(splinter.insert(v), was_new);
648        }
649    }
650
651    /// Remove returns true when value was present, false otherwise.
652    #[hegel::test]
653    fn test_remove_returns_correct_bool(tc: hegel::TestCase) {
654        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
655        let mut splinter = Splinter::from_iter(values.iter().copied());
656        let to_remove: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
657        let mut present: std::collections::HashSet<u32> = values.into_iter().collect();
658        for v in to_remove {
659            let was_present = present.remove(&v);
660            assert_eq!(splinter.remove(v), was_present);
661        }
662    }
663
664    /// Optimize preserves the set of elements.
665    #[hegel::test]
666    fn test_optimize_preserves_elements(tc: hegel::TestCase) {
667        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
668        let mut splinter = Splinter::from_iter(values.iter().copied());
669        let before: Vec<u32> = splinter.iter().collect();
670        splinter.optimize();
671        let after: Vec<u32> = splinter.iter().collect();
672        assert_eq!(before, after);
673    }
674
675    /// Optimize is idempotent: optimizing twice gives the same result.
676    #[hegel::test]
677    fn test_optimize_idempotent(tc: hegel::TestCase) {
678        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
679        let mut splinter = Splinter::from_iter(values);
680        splinter.optimize();
681        let after_first = splinter.clone();
682        splinter.optimize();
683        assert_eq!(splinter, after_first);
684    }
685
686    /// Select and position are inverses: select(position(v)) == v and position(select(i)) == i.
687    #[hegel::test]
688    fn test_select_position_inverse(tc: hegel::TestCase) {
689        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()).min_size(1));
690        let splinter = Splinter::from_iter(values);
691        let cardinality = splinter.cardinality();
692        let idx = tc.draw(generators::integers::<usize>().max_value(cardinality - 1));
693        let value = splinter.select(idx).unwrap();
694        assert_eq!(splinter.position(value), Some(idx));
695    }
696
697    /// Rank is consistent: rank(v) == number of elements <= v.
698    #[hegel::test]
699    fn test_rank_consistency(tc: hegel::TestCase) {
700        let values: Vec<u32> =
701            tc.draw(generators::vecs(generators::integers::<u32>().max_value(65535)).min_size(1));
702        let splinter = Splinter::from_iter(values);
703        let query = tc.draw(generators::integers::<u32>().max_value(65535));
704        let rank = splinter.rank(query);
705        let count_leq = splinter.iter().filter(|&v| v <= query).count();
706        assert_eq!(rank, count_leq);
707    }
708
709    /// Encode/decode roundtrip: encode → `SplinterRef` → decode recovers the same set.
710    #[hegel::test]
711    fn test_encode_decode_roundtrip(tc: hegel::TestCase) {
712        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
713        let mut splinter = Splinter::from_iter(values);
714        splinter.optimize();
715        let encoded = splinter.encode_to_bytes();
716        let splinter_ref = SplinterRef::from_bytes(encoded).unwrap();
717        let decoded = splinter_ref.decode_to_splinter();
718        assert_eq!(splinter, decoded);
719    }
720
721    /// `encoded_size` matches actual encoded byte length.
722    #[hegel::test]
723    fn test_encoded_size_matches(tc: hegel::TestCase) {
724        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
725        let mut splinter = Splinter::from_iter(values);
726        splinter.optimize();
727        let declared_size = splinter.encoded_size();
728        let actual_bytes = splinter.encode_to_bytes();
729        assert_eq!(declared_size, actual_bytes.len());
730    }
731
732    /// Splinter from a range contains exactly the values in that range.
733    #[hegel::test]
734    fn test_from_range_contains_all(tc: hegel::TestCase) {
735        let mut a = tc.draw(generators::integers::<u16>());
736        let mut b = tc.draw(generators::integers::<u16>());
737        if a > b {
738            (a, b) = (b, a);
739        }
740        let start = a as u32;
741        let end = b as u32;
742        let splinter = Splinter::from(start..=end);
743        assert_eq!(splinter.cardinality(), (end - start + 1) as usize);
744        assert!(splinter.contains(start));
745        assert!(splinter.contains(end));
746        if start > 0 {
747            assert!(!splinter.contains(start - 1));
748        }
749        if end < u32::MAX {
750            assert!(!splinter.contains(end + 1));
751        }
752    }
753
754    #[test]
755    fn test_expected_compression() {
756        fn to_roaring(set: impl Iterator<Item = u32>) -> Vec<u8> {
757            let mut buf = std::io::Cursor::new(Vec::new());
758            let mut bmp = RoaringBitmap::from_sorted_iter(set).unwrap();
759            bmp.optimize();
760            bmp.serialize_into(&mut buf).unwrap();
761            buf.into_inner()
762        }
763
764        struct Report {
765            name: String,
766            baseline: usize,
767            //        (actual, expected)
768            splinter: (usize, usize),
769            roaring: (usize, usize),
770
771            splinter_lz4: usize,
772            roaring_lz4: usize,
773        }
774
775        let mut reports = vec![];
776
777        let mut run_test = |name: &str,
778                            set: Vec<u32>,
779                            expected_set_size: usize,
780                            expected_splinter: usize,
781                            expected_roaring: usize| {
782            assert_eq!(set.len(), expected_set_size, "Set size mismatch");
783
784            let mut splinter = Splinter::from_iter(set.clone());
785            splinter.optimize();
786            itertools::assert_equal(splinter.iter(), set.iter().copied());
787
788            test_partition_read(&splinter, &set);
789
790            let expected_size = splinter.encoded_size();
791            let splinter = splinter.encode_to_bytes();
792
793            assert_eq!(
794                splinter.len(),
795                expected_size,
796                "actual encoded size does not match declared encoded size"
797            );
798
799            let roaring = to_roaring(set.iter().copied());
800
801            let splinter_lz4 = lz4::block::compress(&splinter, None, false).unwrap();
802            let roaring_lz4 = lz4::block::compress(&roaring, None, false).unwrap();
803
804            // verify round trip
805            assert_eq!(
806                splinter,
807                lz4::block::decompress(&splinter_lz4, Some(splinter.len() as i32)).unwrap()
808            );
809            assert_eq!(
810                roaring,
811                lz4::block::decompress(&roaring_lz4, Some(roaring.len() as i32)).unwrap()
812            );
813
814            reports.push(Report {
815                name: name.to_owned(),
816                baseline: set.len() * std::mem::size_of::<u32>(),
817                splinter: (splinter.len(), expected_splinter),
818                roaring: (roaring.len(), expected_roaring),
819
820                splinter_lz4: splinter_lz4.len(),
821                roaring_lz4: roaring_lz4.len(),
822            });
823        };
824
825        let mut set_gen = SetGen::new(0xDEAD_BEEF);
826
827        // empty splinter
828        run_test("empty", vec![], 0, 13, 8);
829
830        // 1 element in set
831        let set = set_gen.distributed(1, 1, 1, 1);
832        run_test("1 element", set, 1, 21, 18);
833
834        // 1 fully dense block
835        let set = set_gen.distributed(1, 1, 1, 256);
836        run_test("1 dense block", set, 256, 25, 15);
837
838        // 1 half full block
839        let set = set_gen.distributed(1, 1, 1, 128);
840        run_test("1 half full block", set, 128, 72, 255);
841
842        // 1 sparse block
843        let set = set_gen.distributed(1, 1, 1, 16);
844        run_test("1 sparse block", set, 16, 57, 48);
845
846        // 8 half full blocks
847        let set = set_gen.distributed(1, 1, 8, 128);
848        run_test("8 half full blocks", set, 1024, 338, 2003);
849
850        // 8 sparse blocks
851        let set = set_gen.distributed(1, 1, 8, 2);
852        run_test("8 sparse blocks", set, 16, 67, 48);
853
854        // 64 half full blocks
855        let set = set_gen.distributed(4, 4, 4, 128);
856        run_test("64 half full blocks", set, 8192, 2634, 16452);
857
858        // 64 sparse blocks
859        let set = set_gen.distributed(4, 4, 4, 2);
860        run_test("64 sparse blocks", set, 128, 450, 392);
861
862        // 256 half full blocks
863        let set = set_gen.distributed(4, 8, 8, 128);
864        run_test("256 half full blocks", set, 32768, 10074, 65580);
865
866        // 256 sparse blocks
867        let set = set_gen.distributed(4, 8, 8, 2);
868        run_test("256 sparse blocks", set, 512, 1402, 1288);
869
870        // 512 half full blocks
871        let set = set_gen.distributed(8, 8, 8, 128);
872        run_test("512 half full blocks", set, 65536, 20134, 130810);
873
874        // 512 sparse blocks
875        let set = set_gen.distributed(8, 8, 8, 2);
876        run_test("512 sparse blocks", set, 1024, 2790, 2568);
877
878        // the rest of the compression tests use 4k elements
879        let elements = 4096;
880
881        // fully dense splinter
882        let set = set_gen.distributed(1, 1, 16, 256);
883        run_test("fully dense", set, elements, 87, 63);
884
885        // 128 elements per block; dense partitions
886        let set = set_gen.distributed(1, 1, 32, 128);
887        run_test("128/block; dense", set, elements, 1250, 8208);
888
889        // 32 elements per block; dense partitions
890        let set = set_gen.distributed(1, 1, 128, 32);
891        run_test("32/block; dense", set, elements, 4802, 8208);
892
893        // 16 elements per block; dense low partitions
894        let set = set_gen.distributed(1, 1, 256, 16);
895        run_test("16/block; dense", set, elements, 5666, 8208);
896
897        // 128 elements per block; sparse mid partitions
898        let set = set_gen.distributed(1, 32, 1, 128);
899        run_test("128/block; sparse mid", set, elements, 1529, 8282);
900
901        // 128 elements per block; sparse high partitions
902        let set = set_gen.distributed(32, 1, 1, 128);
903        run_test("128/block; sparse high", set, elements, 1870, 8224);
904
905        // 1 element per block; sparse mid partitions
906        let set = set_gen.distributed(1, 256, 16, 1);
907        run_test("1/block; sparse mid", set, elements, 10521, 10248);
908
909        // 1 element per block; sparse high partitions
910        let set = set_gen.distributed(256, 16, 1, 1);
911        run_test("1/block; sparse high", set, elements, 15374, 40968);
912
913        // 1/block; spread low
914        let set = set_gen.dense(1, 16, 256, 1);
915        run_test("1/block; spread low", set, elements, 8377, 8328);
916
917        // each partition is dense
918        let set = set_gen.dense(8, 8, 8, 8);
919        run_test("dense throughout", set, elements, 2790, 2700);
920
921        // the lowest partitions are dense
922        let set = set_gen.dense(1, 1, 64, 64);
923        run_test("dense low", set, elements, 291, 267);
924
925        // the mid and low partitions are dense
926        let set = set_gen.dense(1, 32, 16, 8);
927        run_test("dense mid/low", set, elements, 2393, 2376);
928
929        let random_cases = [
930            // random sets drawing from the entire u32 range
931            (32, High::MAX_LEN, 145, 328),
932            (256, High::MAX_LEN, 1041, 2544),
933            (1024, High::MAX_LEN, 4113, 10168),
934            (4096, High::MAX_LEN, 15374, 40056),
935            (16384, High::MAX_LEN, 52238, 148656),
936            (65536, High::MAX_LEN, 199694, 461288),
937            // random sets with values < 65536
938            (32, 65536, 99, 80),
939            (256, 65536, 547, 528),
940            (1024, 65536, 2083, 2064),
941            (4096, 65536, 5666, 8208),
942            (65536, 65536, 25, 15),
943            // small sets with values < 1024
944            (8, 1024, 49, 32),
945            (16, 1024, 67, 48),
946            (32, 1024, 94, 80),
947            (64, 1024, 126, 144),
948            (128, 1024, 183, 272),
949        ];
950
951        for (count, max, expected_splinter, expected_roaring) in random_cases {
952            let name = if max == High::MAX_LEN {
953                format!("random/{count}")
954            } else {
955                format!("random/{count}/{max}")
956            };
957            run_test(
958                &name,
959                set_gen.random_max(count, max),
960                count,
961                expected_splinter,
962                expected_roaring,
963            );
964        }
965
966        let mut fail_test = false;
967
968        println!("{}", "-".repeat(83));
969        println!(
970            "{:30} {:12} {:>6} {:>10} {:>10} {:>10}",
971            "test", "bitmap", "size", "expected", "relative", "ok"
972        );
973        for report in &reports {
974            println!(
975                "{:30} {:12} {:6} {:10} {:>10} {:>10}",
976                report.name,
977                "Splinter",
978                report.splinter.0,
979                report.splinter.1,
980                "1.00",
981                if report.splinter.0 == report.splinter.1 {
982                    "ok"
983                } else {
984                    fail_test = true;
985                    "FAIL"
986                }
987            );
988
989            let diff = report.roaring.0 as f64 / report.splinter.0 as f64;
990            let ok_status = if report.roaring.0 != report.roaring.1 {
991                fail_test = true;
992                "FAIL".into()
993            } else {
994                ratio_to_marks(diff)
995            };
996            println!(
997                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
998                "", "Roaring", report.roaring.0, report.roaring.1, diff, ok_status
999            );
1000
1001            let diff = report.splinter_lz4 as f64 / report.splinter.0 as f64;
1002            println!(
1003                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
1004                "",
1005                "Splinter LZ4",
1006                report.splinter_lz4,
1007                report.splinter_lz4,
1008                diff,
1009                ratio_to_marks(diff)
1010            );
1011
1012            let diff = report.roaring_lz4 as f64 / report.splinter_lz4 as f64;
1013            println!(
1014                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
1015                "",
1016                "Roaring LZ4",
1017                report.roaring_lz4,
1018                report.roaring_lz4,
1019                diff,
1020                ratio_to_marks(diff)
1021            );
1022
1023            let diff = report.baseline as f64 / report.splinter.0 as f64;
1024            println!(
1025                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
1026                "",
1027                "Baseline",
1028                report.baseline,
1029                report.baseline,
1030                diff,
1031                ratio_to_marks(diff)
1032            );
1033        }
1034
1035        // calculate average compression ratio (splinter_lz4 / splinter)
1036        let avg_ratio = reports
1037            .iter()
1038            .map(|r| r.splinter_lz4 as f64 / r.splinter.0 as f64)
1039            .sum::<f64>()
1040            / reports.len() as f64;
1041
1042        println!("average compression ratio (splinter_lz4 / splinter): {avg_ratio:.2}");
1043
1044        assert!(!fail_test, "compression test failed");
1045    }
1046}