1use std::{fmt::Debug, ops::RangeBounds};
2
3use bytes::Bytes;
4
5use crate::{
6 Encodable, Optimizable, SplinterRef,
7 codec::{encoder::Encoder, footer::Footer},
8 level::High,
9 partition::Partition,
10 traits::{PartitionRead, PartitionWrite},
11 util::RangeExt,
12};
13
14#[derive(Clone, PartialEq, Eq, Default, Debug)]
56pub struct Splinter(Partition<High>);
57
58static_assertions::const_assert_eq!(std::mem::size_of::<Splinter>(), 40);
59
60impl Splinter {
61 pub const EMPTY: Self = Splinter(Partition::EMPTY);
63
64 pub const FULL: Self = Splinter(Partition::Full);
66
67 pub fn encode_to_splinter_ref(&self) -> SplinterRef<Bytes> {
85 SplinterRef { data: self.encode_to_bytes() }
86 }
87
88 #[inline(always)]
89 pub(crate) fn new(inner: Partition<High>) -> Self {
90 Self(inner)
91 }
92
93 #[inline(always)]
94 pub(crate) fn inner(&self) -> &Partition<High> {
95 &self.0
96 }
97
98 #[inline(always)]
99 pub(crate) fn inner_mut(&mut self) -> &mut Partition<High> {
100 &mut self.0
101 }
102}
103
104impl FromIterator<u32> for Splinter {
105 fn from_iter<I: IntoIterator<Item = u32>>(iter: I) -> Self {
106 Self(Partition::<High>::from_iter(iter))
107 }
108}
109
110impl<R: RangeBounds<u32>> From<R> for Splinter {
111 fn from(range: R) -> Self {
112 if let Some(range) = range.try_into_inclusive() {
113 if range.start() == &u32::MIN && range.end() == &u32::MAX {
114 Self::FULL
115 } else {
116 Self(Partition::<High>::from(range))
117 }
118 } else {
119 Self::EMPTY
121 }
122 }
123}
124
125impl PartitionRead<High> for Splinter {
126 #[inline]
140 fn cardinality(&self) -> usize {
141 self.0.cardinality()
142 }
143
144 #[inline]
158 fn is_empty(&self) -> bool {
159 self.0.is_empty()
160 }
161
162 #[inline]
176 fn contains(&self, value: u32) -> bool {
177 self.0.contains(value)
178 }
179
180 #[inline]
198 fn position(&self, value: u32) -> Option<usize> {
199 self.0.position(value)
200 }
201
202 #[inline]
220 fn rank(&self, value: u32) -> usize {
221 self.0.rank(value)
222 }
223
224 #[inline]
241 fn select(&self, idx: usize) -> Option<u32> {
242 self.0.select(idx)
243 }
244
245 #[inline]
260 fn last(&self) -> Option<u32> {
261 self.0.last()
262 }
263
264 #[inline]
277 fn iter(&self) -> impl Iterator<Item = u32> {
278 self.0.iter()
279 }
280}
281
282impl PartitionWrite<High> for Splinter {
283 #[inline]
307 fn insert(&mut self, value: u32) -> bool {
308 self.0.insert(value)
309 }
310
311 #[inline]
334 fn remove(&mut self, value: u32) -> bool {
335 self.0.remove(value)
336 }
337
338 #[inline]
363 fn remove_range<R: RangeBounds<u32>>(&mut self, values: R) {
364 self.0.remove_range(values);
365 }
366}
367
368impl Encodable for Splinter {
369 fn encoded_size(&self) -> usize {
370 self.0.encoded_size() + std::mem::size_of::<Footer>()
371 }
372
373 fn encode<B: bytes::BufMut>(&self, encoder: &mut Encoder<B>) {
374 self.0.encode(encoder);
375 encoder.write_footer();
376 }
377}
378
379impl Optimizable for Splinter {
380 #[inline]
381 fn optimize(&mut self) {
382 self.0.optimize();
383 }
384}
385
386impl Extend<u32> for Splinter {
387 #[inline]
388 fn extend<T: IntoIterator<Item = u32>>(&mut self, iter: T) {
389 self.0.extend(iter);
390 }
391}
392
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use super::*;
    use crate::{
        codec::Encodable,
        level::{Level, Low},
        testutil::{SetGen, mksplinter, ratio_to_marks, test_partition_read, test_partition_write},
        traits::Optimizable,
    };
    use itertools::{Itertools, assert_equal};
    use proptest::{
        collection::{hash_set, vec},
        proptest,
    };
    use rand::{SeedableRng, seq::index};
    use roaring::RoaringBitmap;

    #[test]
    fn test_sanity() {
        let mut splinter = Splinter::EMPTY;

        assert!(splinter.insert(1));
        assert!(!splinter.insert(1));
        assert!(splinter.contains(1));

        let values = [1024, 123, 16384];
        for v in values {
            assert!(splinter.insert(v));
            assert!(splinter.contains(v));
            assert!(!splinter.contains(v + 1));
        }

        for i in 0..8192 + 10 {
            splinter.insert(i);
        }

        splinter.optimize();

        dbg!(&splinter);

        let expected = splinter.iter().collect_vec();
        test_partition_read(&splinter, &expected);
        test_partition_write(&mut splinter);
    }

    #[test]
    fn test_wat() {
        let mut set_gen = SetGen::new(0xDEAD_BEEF);
        let set = set_gen.random_max(64, 4096);
        let baseline_size = set.len() * 4;

        let mut splinter = Splinter::from_iter(set.iter().copied());
        splinter.optimize();

        dbg!(&splinter, splinter.encoded_size(), baseline_size, set.len());
        itertools::assert_equal(splinter.iter(), set.into_iter());
    }

    #[test]
    fn test_splinter_write() {
        let mut splinter = Splinter::from_iter(0u32..16384);
        test_partition_write(&mut splinter);
    }

    /// Inserts 8 distinct random values below `Low::MAX_LEN` one at a time,
    /// then optimizes. Membership, cardinality and contents are checked both
    /// while the splinter grows and after `optimize`.
    #[test]
    fn test_splinter_optimize_growth() {
        let mut splinter = Splinter::EMPTY;
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xdeadbeef);
        // `index::sample` yields distinct indices, so every insert must be new.
        let set = index::sample(&mut rng, Low::MAX_LEN, 8);
        dbg!(&splinter);
        for i in set.iter() {
            let value = i as u32;
            assert!(splinter.insert(value));
            assert!(splinter.contains(value));
            dbg!(&splinter);
        }
        assert_eq!(splinter.cardinality(), set.len());

        // Optimizing may change the representation but must not change contents.
        splinter.optimize();
        assert_eq!(splinter.cardinality(), set.len());
        assert_equal(splinter.iter(), set.iter().map(|i| i as u32).sorted());
    }

    #[test]
    fn test_splinter_from_range() {
        let splinter = Splinter::from(..);
        assert_eq!(splinter.cardinality(), (u32::MAX as usize) + 1);

        let mut splinter = Splinter::from(1..);
        assert_eq!(splinter.cardinality(), u32::MAX as usize);

        splinter.remove(1024);
        assert_eq!(splinter.cardinality(), (u32::MAX as usize) - 1);

        let mut count = 1;
        for i in (2048..=256000).step_by(1024) {
            splinter.remove(i);
            count += 1
        }
        assert_eq!(splinter.cardinality(), (u32::MAX as usize) - count);
    }

    proptest! {
        #[test]
        fn test_splinter_read_proptest(set in hash_set(0u32..16384, 0..1024)) {
            let expected = set.iter().copied().sorted().collect_vec();
            test_partition_read(&Splinter::from_iter(set), &expected);
        }


        #[test]
        fn test_splinter_proptest(set in vec(0u32..16384, 0..1024)) {
            let splinter = mksplinter(&set);
            if set.is_empty() {
                assert!(!splinter.contains(123));
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup));
            }
        }

        #[test]
        fn test_splinter_opt_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut splinter = mksplinter(&set);
            splinter.optimize();
            if set.is_empty() {
                assert!(!splinter.contains(123));
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup));
            }
        }

        #[test]
        fn test_splinter_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let a = mksplinter(&set);
            assert_eq!(a, a.clone());
        }

        #[test]
        fn test_splinter_opt_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut a = mksplinter(&set);
            let b = mksplinter(&set);
            a.optimize();
            assert_eq!(a, b);
        }

        #[test]
        fn test_splinter_remove_range_proptest(set in hash_set(0u32..16384, 0..1024)) {
            let expected = set.iter().copied().sorted().collect_vec();
            let mut splinter = mksplinter(&expected);
            if let Some(last) = expected.last() {
                // Removing everything strictly above the maximum must be a no-op.
                splinter.remove_range((Bound::Excluded(last), Bound::Unbounded));
                assert_equal(splinter.iter(), expected);
            }
        }
    }

    #[test]
    fn test_expected_compression() {
        fn to_roaring(set: impl Iterator<Item = u32>) -> Vec<u8> {
            let mut buf = std::io::Cursor::new(Vec::new());
            let mut bmp = RoaringBitmap::from_sorted_iter(set).unwrap();
            bmp.optimize();
            bmp.serialize_into(&mut buf).unwrap();
            buf.into_inner()
        }

        struct Report {
            name: String,
            baseline: usize,
            // (actual encoded size, expected encoded size)
            splinter: (usize, usize),
            roaring: (usize, usize),

            splinter_lz4: usize,
            roaring_lz4: usize,
        }

        let mut reports = vec![];

        let mut run_test = |name: &str,
                            set: Vec<u32>,
                            expected_set_size: usize,
                            expected_splinter: usize,
                            expected_roaring: usize| {
            assert_eq!(set.len(), expected_set_size, "Set size mismatch");

            let mut splinter = Splinter::from_iter(set.clone());
            splinter.optimize();
            itertools::assert_equal(splinter.iter(), set.iter().copied());

            test_partition_read(&splinter, &set);

            let expected_size = splinter.encoded_size();
            let splinter = splinter.encode_to_bytes();

            assert_eq!(
                splinter.len(),
                expected_size,
                "actual encoded size does not match declared encoded size"
            );

            let roaring = to_roaring(set.iter().copied());

            let splinter_lz4 = lz4::block::compress(&splinter, None, false).unwrap();
            let roaring_lz4 = lz4::block::compress(&roaring, None, false).unwrap();

            // Sanity-check that both LZ4 payloads round-trip.
            assert_eq!(
                splinter,
                lz4::block::decompress(&splinter_lz4, Some(splinter.len() as i32)).unwrap()
            );
            assert_eq!(
                roaring,
                lz4::block::decompress(&roaring_lz4, Some(roaring.len() as i32)).unwrap()
            );

            reports.push(Report {
                name: name.to_owned(),
                baseline: set.len() * std::mem::size_of::<u32>(),
                splinter: (splinter.len(), expected_splinter),
                roaring: (roaring.len(), expected_roaring),

                splinter_lz4: splinter_lz4.len(),
                roaring_lz4: roaring_lz4.len(),
            });
        };

        let mut set_gen = SetGen::new(0xDEAD_BEEF);

        run_test("empty", vec![], 0, 13, 8);

        let set = set_gen.distributed(1, 1, 1, 1);
        run_test("1 element", set, 1, 21, 18);

        let set = set_gen.distributed(1, 1, 1, 256);
        run_test("1 dense block", set, 256, 25, 15);

        let set = set_gen.distributed(1, 1, 1, 128);
        run_test("1 half full block", set, 128, 63, 255);

        let set = set_gen.distributed(1, 1, 1, 16);
        run_test("1 sparse block", set, 16, 48, 48);

        let set = set_gen.distributed(1, 1, 8, 128);
        run_test("8 half full blocks", set, 1024, 315, 2003);

        let set = set_gen.distributed(1, 1, 8, 2);
        run_test("8 sparse blocks", set, 16, 60, 48);

        let set = set_gen.distributed(4, 4, 4, 128);
        run_test("64 half full blocks", set, 8192, 2442, 16452);

        let set = set_gen.distributed(4, 4, 4, 2);
        run_test("64 sparse blocks", set, 128, 410, 392);

        let set = set_gen.distributed(4, 8, 8, 128);
        run_test("256 half full blocks", set, 32768, 9450, 65580);

        let set = set_gen.distributed(4, 8, 8, 2);
        run_test("256 sparse blocks", set, 512, 1290, 1288);

        let set = set_gen.distributed(8, 8, 8, 128);
        run_test("512 half full blocks", set, 65536, 18886, 130810);

        let set = set_gen.distributed(8, 8, 8, 2);
        run_test("512 sparse blocks", set, 1024, 2566, 2568);

        let elements = 4096;

        let set = set_gen.distributed(1, 1, 16, 256);
        run_test("fully dense", set, elements, 80, 63);

        let set = set_gen.distributed(1, 1, 32, 128);
        run_test("128/block; dense", set, elements, 1179, 8208);

        let set = set_gen.distributed(1, 1, 128, 32);
        run_test("32/block; dense", set, elements, 4539, 8208);

        let set = set_gen.distributed(1, 1, 256, 16);
        run_test("16/block; dense", set, elements, 5147, 8208);

        let set = set_gen.distributed(1, 32, 1, 128);
        run_test("128/block; sparse mid", set, elements, 1365, 8282);

        let set = set_gen.distributed(32, 1, 1, 128);
        run_test("128/block; sparse high", set, elements, 1582, 8224);

        let set = set_gen.distributed(1, 256, 16, 1);
        run_test("1/block; sparse mid", set, elements, 9749, 10248);

        let set = set_gen.distributed(256, 16, 1, 1);
        run_test("1/block; sparse high", set, elements, 14350, 40968);

        let set = set_gen.dense(1, 16, 256, 1);
        run_test("1/block; spread low", set, elements, 8325, 8328);

        let set = set_gen.dense(8, 8, 8, 8);
        run_test("dense throughout", set, elements, 4113, 2700);

        let set = set_gen.dense(1, 1, 64, 64);
        run_test("dense low", set, elements, 529, 267);

        let set = set_gen.dense(1, 32, 16, 8);
        run_test("dense mid/low", set, elements, 4113, 2376);

        let random_cases = [
            (32, High::MAX_LEN, 145, 328),
            (256, High::MAX_LEN, 1041, 2544),
            (1024, High::MAX_LEN, 4113, 10168),
            (4096, High::MAX_LEN, 14350, 40056),
            (16384, High::MAX_LEN, 51214, 148656),
            (65536, High::MAX_LEN, 198670, 461288),
            (32, 65536, 92, 80),
            (256, 65536, 540, 528),
            (1024, 65536, 2071, 2064),
            (4096, 65536, 5147, 8208),
            (65536, 65536, 25, 15),
            (8, 1024, 44, 32),
            (16, 1024, 60, 48),
            (32, 1024, 79, 80),
            (64, 1024, 111, 144),
            (128, 1024, 168, 272),
        ];

        for (count, max, expected_splinter, expected_roaring) in random_cases {
            let name = if max == High::MAX_LEN {
                format!("random/{count}")
            } else {
                format!("random/{count}/{max}")
            };
            run_test(
                &name,
                set_gen.random_max(count, max),
                count,
                expected_splinter,
                expected_roaring,
            );
        }

        let mut fail_test = false;

        println!("{}", "-".repeat(83));
        println!(
            "{:30} {:12} {:>6} {:>10} {:>10} {:>10}",
            "test", "bitmap", "size", "expected", "relative", "ok"
        );
        for report in &reports {
            println!(
                "{:30} {:12} {:6} {:10} {:>10} {:>10}",
                report.name,
                "Splinter",
                report.splinter.0,
                report.splinter.1,
                "1.00",
                if report.splinter.0 == report.splinter.1 {
                    "ok"
                } else {
                    fail_test = true;
                    "FAIL"
                }
            );

            let diff = report.roaring.0 as f64 / report.splinter.0 as f64;
            let ok_status = if report.roaring.0 != report.roaring.1 {
                fail_test = true;
                "FAIL".into()
            } else {
                ratio_to_marks(diff)
            };
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "", "Roaring", report.roaring.0, report.roaring.1, diff, ok_status
            );

            let diff = report.splinter_lz4 as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Splinter LZ4",
                report.splinter_lz4,
                report.splinter_lz4,
                diff,
                ratio_to_marks(diff)
            );

            let diff = report.roaring_lz4 as f64 / report.splinter_lz4 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Roaring LZ4",
                report.roaring_lz4,
                report.roaring_lz4,
                diff,
                ratio_to_marks(diff)
            );

            let diff = report.baseline as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Baseline",
                report.baseline,
                report.baseline,
                diff,
                ratio_to_marks(diff)
            );
        }

        let avg_ratio = reports
            .iter()
            .map(|r| r.splinter_lz4 as f64 / r.splinter.0 as f64)
            .sum::<f64>()
            / reports.len() as f64;

        println!("average compression ratio (splinter_lz4 / splinter): {avg_ratio:.2}");

        assert!(!fail_test, "compression test failed");
    }
}