use std::{fmt::Debug, ops::RangeBounds};

use bytes::Bytes;

use crate::{
    Encodable, Optimizable, SplinterRef,
    codec::{encoder::Encoder, footer::Footer},
    level::High,
    partition::Partition,
    traits::{PartitionRead, PartitionWrite},
    util::RangeExt,
};

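/// An owned, mutable set of `u32` values backed by a `Partition` tree rooted at
/// the `High` level.
///
/// # Example
///
/// A minimal usage sketch (imports omitted; the behavior shown mirrors
/// `test_sanity` in the tests module below):
///
/// ```ignore
/// let mut s = Splinter::EMPTY;
/// assert!(s.insert(42));   // inserting a new value returns `true`
/// assert!(!s.insert(42));  // re-inserting an existing value returns `false`
/// assert!(s.contains(42));
/// assert_eq!(s.cardinality(), 1);
/// ```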
#[derive(Clone, PartialEq, Eq, Default, Debug)]
pub struct Splinter(Partition<High>);

static_assertions::const_assert_eq!(std::mem::size_of::<Splinter>(), 40);

impl Splinter {
    /// A `Splinter` containing no values.
    pub const EMPTY: Self = Splinter(Partition::EMPTY);

    /// A `Splinter` containing every `u32` value.
    pub const FULL: Self = Splinter(Partition::Full);

    /// Encodes this `Splinter` into an owned [`SplinterRef`] backed by [`Bytes`].
    pub fn encode_to_splinter_ref(&self) -> SplinterRef<Bytes> {
        SplinterRef { data: self.encode_to_bytes() }
    }

    #[inline(always)]
    pub(crate) fn new(inner: Partition<High>) -> Self {
        Self(inner)
    }

    #[inline(always)]
    pub(crate) fn inner(&self) -> &Partition<High> {
        &self.0
    }

    #[inline(always)]
    pub(crate) fn inner_mut(&mut self) -> &mut Partition<High> {
        &mut self.0
    }
}

impl FromIterator<u32> for Splinter {
    fn from_iter<I: IntoIterator<Item = u32>>(iter: I) -> Self {
        Self(Partition::<High>::from_iter(iter))
    }
}

impl<R: RangeBounds<u32>> From<R> for Splinter {
    fn from(range: R) -> Self {
        if let Some(range) = range.try_into_inclusive() {
            if range.start() == &u32::MIN && range.end() == &u32::MAX {
                Self::FULL
            } else {
                Self(Partition::<High>::from(range))
            }
        } else {
            Self::EMPTY
        }
    }
}

impl PartitionRead<High> for Splinter {
    #[inline]
    fn cardinality(&self) -> usize {
        self.0.cardinality()
    }

    #[inline]
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    #[inline]
    fn contains(&self, value: u32) -> bool {
        self.0.contains(value)
    }

    #[inline]
    fn position(&self, value: u32) -> Option<usize> {
        self.0.position(value)
    }

    #[inline]
    fn rank(&self, value: u32) -> usize {
        self.0.rank(value)
    }

    #[inline]
    fn select(&self, idx: usize) -> Option<u32> {
        self.0.select(idx)
    }

    #[inline]
    fn last(&self) -> Option<u32> {
        self.0.last()
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = u32> {
        self.0.iter()
    }
}

impl PartitionWrite<High> for Splinter {
    #[inline]
    fn insert(&mut self, value: u32) -> bool {
        self.0.insert(value)
    }

    #[inline]
    fn remove(&mut self, value: u32) -> bool {
        self.0.remove(value)
    }

    #[inline]
    fn remove_range<R: RangeBounds<u32>>(&mut self, values: R) {
        self.0.remove_range(values);
    }
}

impl Encodable for Splinter {
    fn encoded_size(&self) -> usize {
        self.0.encoded_size() + std::mem::size_of::<Footer>()
    }

    fn encode<B: bytes::BufMut>(&self, encoder: &mut Encoder<B>) {
        self.0.encode(encoder);
        encoder.write_footer();
    }
}

impl Optimizable for Splinter {
    #[inline]
    fn optimize(&mut self) {
        self.0.optimize();
    }
}

impl Extend<u32> for Splinter {
    #[inline]
    fn extend<T: IntoIterator<Item = u32>>(&mut self, iter: T) {
        self.0.extend(iter);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::Encodable,
        level::{Level, Low},
        testutil::{SetGen, mksplinter, ratio_to_marks, test_partition_read, test_partition_write},
        traits::Optimizable,
    };
    use itertools::Itertools;
    use proptest::{
        collection::{hash_set, vec},
        proptest,
    };
    use rand::{SeedableRng, seq::index};
    use roaring::RoaringBitmap;

    #[test]
    fn test_sanity() {
        let mut splinter = Splinter::EMPTY;

        assert!(splinter.insert(1));
        assert!(!splinter.insert(1));
        assert!(splinter.contains(1));

        let values = [1024, 123, 16384];
        for v in values {
            assert!(splinter.insert(v));
            assert!(splinter.contains(v));
            assert!(!splinter.contains(v + 1));
        }

        for i in 0..8192 + 10 {
            splinter.insert(i);
        }

        splinter.optimize();

        dbg!(&splinter);

        let expected = splinter.iter().collect_vec();
        test_partition_read(&splinter, &expected);
        test_partition_write(&mut splinter);
    }
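
    // Sketch: spot-checks the ordered-index queries on a small, sorted set.
    // Assumes the conventional semantics exercised by `test_partition_read`:
    // `select(i)` returns the i-th smallest value, `position(v)` is its inverse
    // for present values, and `last` returns the maximum.
    #[test]
    fn test_ordered_queries_sketch() {
        let values = [3u32, 7, 1024, 70_000];
        let splinter = Splinter::from_iter(values);

        for (i, v) in values.iter().copied().enumerate() {
            assert_eq!(splinter.select(i), Some(v));
            assert_eq!(splinter.position(v), Some(i));
        }
        assert_eq!(splinter.last(), Some(70_000));
        assert_eq!(splinter.position(5), None);
    }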

    #[test]
    fn test_wat() {
        let mut set_gen = SetGen::new(0xDEAD_BEEF);
        let set = set_gen.random_max(64, 4096);
        let baseline_size = set.len() * 4;

        let mut splinter = Splinter::from_iter(set.iter().copied());
        splinter.optimize();

        dbg!(&splinter, splinter.encoded_size(), baseline_size, set.len());
        itertools::assert_equal(splinter.iter(), set.into_iter());
    }
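
    // Sketch: exercises the `From<R: RangeBounds<u32>>` impl above. The full
    // `u32` domain maps to `Splinter::FULL`, while an empty range yields an
    // empty splinter (assuming `try_into_inclusive` rejects empty ranges).
    #[test]
    fn test_splinter_from_range_edges_sketch() {
        assert_eq!(Splinter::from(..), Splinter::FULL);
        assert_eq!(Splinter::from(0..=u32::MAX), Splinter::FULL);
        assert!(Splinter::from(1..1).is_empty());
        assert_eq!(Splinter::from(1..1).cardinality(), 0);
    }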

    #[test]
    fn test_splinter_write() {
        let mut splinter = Splinter::from_iter(0u32..16384);
        test_partition_write(&mut splinter);
    }
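
    // Sketch: `Extend<u32>` should behave like repeated `insert`s, with
    // duplicates deduplicated (matching `insert` returning `false` for repeats
    // in `test_sanity`). Contents are compared via iteration rather than `==`
    // to stay independent of internal representation.
    #[test]
    fn test_splinter_extend_sketch() {
        let mut extended = Splinter::EMPTY;
        extended.extend(0u32..100);
        extended.extend([5u32, 5, 5, 99]);

        let collected = Splinter::from_iter(0u32..100);
        itertools::assert_equal(extended.iter(), collected.iter());
        assert_eq!(extended.cardinality(), 100);
    }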

    #[test]
    fn test_splinter_optimize_growth() {
        let mut splinter = Splinter::EMPTY;
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xdeadbeef);
        let set = index::sample(&mut rng, Low::MAX_LEN, 8);
        dbg!(&splinter);
        for i in set {
            splinter.insert(i as u32);
            dbg!(&splinter);
        }
    }
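
    // Sketch: `encode_to_splinter_ref` wraps the same bytes produced by
    // `encode_to_bytes`, and `encoded_size` must match the number of bytes the
    // encoder actually emits (the invariant also checked in
    // `test_expected_compression`).
    #[test]
    fn test_encode_to_splinter_ref_sketch() {
        let mut splinter = Splinter::from_iter(0u32..4096);
        splinter.optimize();

        let bytes = splinter.encode_to_bytes();
        assert_eq!(bytes.len(), splinter.encoded_size());

        let splinter_ref = splinter.encode_to_splinter_ref();
        assert_eq!(splinter_ref.data, bytes);
    }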

    #[test]
    fn test_splinter_from_range() {
        let splinter = Splinter::from(..);
        assert_eq!(splinter.cardinality(), (u32::MAX as usize) + 1);

        let mut splinter = Splinter::from(1..);
        assert_eq!(splinter.cardinality(), u32::MAX as usize);

        splinter.remove(1024);
        assert_eq!(splinter.cardinality(), (u32::MAX as usize) - 1);

        let mut count = 1;
        for i in (2048..=256000).step_by(1024) {
            splinter.remove(i);
            count += 1;
        }
        assert_eq!(splinter.cardinality(), (u32::MAX as usize) - count);
    }
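
    // Sketch: `remove_range` is expected to clear every value covered by the
    // given bounds and leave the rest untouched (the assumed semantics of
    // `PartitionWrite::remove_range`).
    #[test]
    fn test_splinter_remove_range_sketch() {
        let mut splinter = Splinter::from_iter(0u32..100);
        splinter.remove_range(10..20);

        assert_eq!(splinter.cardinality(), 90);
        assert!(splinter.contains(9));
        assert!(!splinter.contains(10));
        assert!(!splinter.contains(19));
        assert!(splinter.contains(20));
    }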

    proptest! {
        #[test]
        fn test_splinter_read_proptest(set in hash_set(0u32..16384, 0..1024)) {
            let expected = set.iter().copied().sorted().collect_vec();
            test_partition_read(&Splinter::from_iter(set), &expected);
        }

        #[test]
        fn test_splinter_proptest(set in vec(0u32..16384, 0..1024)) {
            let splinter = mksplinter(&set);
            if set.is_empty() {
                assert!(!splinter.contains(123));
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup));
            }
        }

        #[test]
        fn test_splinter_opt_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut splinter = mksplinter(&set);
            splinter.optimize();
            if set.is_empty() {
                assert!(!splinter.contains(123));
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup));
            }
        }

        #[test]
        fn test_splinter_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let a = mksplinter(&set);
            assert_eq!(a, a.clone());
        }

        #[test]
        fn test_splinter_opt_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut a = mksplinter(&set);
            let b = mksplinter(&set);
            a.optimize();
            assert_eq!(a, b);
        }
    }

    #[test]
    fn test_expected_compression() {
        fn to_roaring(set: impl Iterator<Item = u32>) -> Vec<u8> {
            let mut buf = std::io::Cursor::new(Vec::new());
            let mut bmp = RoaringBitmap::from_sorted_iter(set).unwrap();
            bmp.optimize();
            bmp.serialize_into(&mut buf).unwrap();
            buf.into_inner()
        }

        struct Report {
            name: String,
            baseline: usize,
            splinter: (usize, usize),
            roaring: (usize, usize),

            splinter_lz4: usize,
            roaring_lz4: usize,
        }

        let mut reports = vec![];

        let mut run_test = |name: &str,
                            set: Vec<u32>,
                            expected_set_size: usize,
                            expected_splinter: usize,
                            expected_roaring: usize| {
            assert_eq!(set.len(), expected_set_size, "Set size mismatch");

            let mut splinter = Splinter::from_iter(set.clone());
            splinter.optimize();
            itertools::assert_equal(splinter.iter(), set.iter().copied());

            test_partition_read(&splinter, &set);

            let expected_size = splinter.encoded_size();
            let splinter = splinter.encode_to_bytes();

            assert_eq!(
                splinter.len(),
                expected_size,
                "actual encoded size does not match declared encoded size"
            );

            let roaring = to_roaring(set.iter().copied());

            let splinter_lz4 = lz4::block::compress(&splinter, None, false).unwrap();
            let roaring_lz4 = lz4::block::compress(&roaring, None, false).unwrap();

            assert_eq!(
                splinter,
                lz4::block::decompress(&splinter_lz4, Some(splinter.len() as i32)).unwrap()
            );
            assert_eq!(
                roaring,
                lz4::block::decompress(&roaring_lz4, Some(roaring.len() as i32)).unwrap()
            );

            reports.push(Report {
                name: name.to_owned(),
                baseline: set.len() * std::mem::size_of::<u32>(),
                splinter: (splinter.len(), expected_splinter),
                roaring: (roaring.len(), expected_roaring),

                splinter_lz4: splinter_lz4.len(),
                roaring_lz4: roaring_lz4.len(),
            });
        };

        let mut set_gen = SetGen::new(0xDEAD_BEEF);

        run_test("empty", vec![], 0, 13, 8);

        let set = set_gen.distributed(1, 1, 1, 1);
        run_test("1 element", set, 1, 21, 18);

        let set = set_gen.distributed(1, 1, 1, 256);
        run_test("1 dense block", set, 256, 25, 15);

        let set = set_gen.distributed(1, 1, 1, 128);
        run_test("1 half full block", set, 128, 63, 255);

        let set = set_gen.distributed(1, 1, 1, 16);
        run_test("1 sparse block", set, 16, 48, 48);

        let set = set_gen.distributed(1, 1, 8, 128);
        run_test("8 half full blocks", set, 1024, 315, 2003);

        let set = set_gen.distributed(1, 1, 8, 2);
        run_test("8 sparse blocks", set, 16, 60, 48);

        let set = set_gen.distributed(4, 4, 4, 128);
        run_test("64 half full blocks", set, 8192, 2442, 16452);

        let set = set_gen.distributed(4, 4, 4, 2);
        run_test("64 sparse blocks", set, 128, 410, 392);

        let set = set_gen.distributed(4, 8, 8, 128);
        run_test("256 half full blocks", set, 32768, 9450, 65580);

        let set = set_gen.distributed(4, 8, 8, 2);
        run_test("256 sparse blocks", set, 512, 1290, 1288);

        let set = set_gen.distributed(8, 8, 8, 128);
        run_test("512 half full blocks", set, 65536, 18886, 130810);

        let set = set_gen.distributed(8, 8, 8, 2);
        run_test("512 sparse blocks", set, 1024, 2566, 2568);

        let elements = 4096;

        let set = set_gen.distributed(1, 1, 16, 256);
        run_test("fully dense", set, elements, 80, 63);

        let set = set_gen.distributed(1, 1, 32, 128);
        run_test("128/block; dense", set, elements, 1179, 8208);

        let set = set_gen.distributed(1, 1, 128, 32);
        run_test("32/block; dense", set, elements, 4539, 8208);

        let set = set_gen.distributed(1, 1, 256, 16);
        run_test("16/block; dense", set, elements, 5147, 8208);

        let set = set_gen.distributed(1, 32, 1, 128);
        run_test("128/block; sparse mid", set, elements, 1365, 8282);

        let set = set_gen.distributed(32, 1, 1, 128);
        run_test("128/block; sparse high", set, elements, 1582, 8224);

        let set = set_gen.distributed(1, 256, 16, 1);
        run_test("1/block; sparse mid", set, elements, 9749, 10248);

        let set = set_gen.distributed(256, 16, 1, 1);
        run_test("1/block; sparse high", set, elements, 14350, 40968);

        let set = set_gen.dense(1, 16, 256, 1);
        run_test("1/block; spread low", set, elements, 8325, 8328);

        let set = set_gen.dense(8, 8, 8, 8);
        run_test("dense throughout", set, elements, 4113, 2700);

        let set = set_gen.dense(1, 1, 64, 64);
        run_test("dense low", set, elements, 529, 267);

        let set = set_gen.dense(1, 32, 16, 8);
        run_test("dense mid/low", set, elements, 4113, 2376);

        let random_cases = [
            (32, High::MAX_LEN, 145, 328),
            (256, High::MAX_LEN, 1041, 2544),
            (1024, High::MAX_LEN, 4113, 10168),
            (4096, High::MAX_LEN, 14350, 40056),
            (16384, High::MAX_LEN, 51214, 148656),
            (65536, High::MAX_LEN, 198670, 461288),
            (32, 65536, 92, 80),
            (256, 65536, 540, 528),
            (1024, 65536, 2071, 2064),
            (4096, 65536, 5147, 8208),
            (65536, 65536, 25, 15),
            (8, 1024, 44, 32),
            (16, 1024, 60, 48),
            (32, 1024, 79, 80),
            (64, 1024, 111, 144),
            (128, 1024, 168, 272),
        ];

        for (count, max, expected_splinter, expected_roaring) in random_cases {
            let name = if max == High::MAX_LEN {
                format!("random/{count}")
            } else {
                format!("random/{count}/{max}")
            };
            run_test(
                &name,
                set_gen.random_max(count, max),
                count,
                expected_splinter,
                expected_roaring,
            );
        }

        let mut fail_test = false;

        println!("{}", "-".repeat(83));
        println!(
            "{:30} {:12} {:>6} {:>10} {:>10} {:>10}",
            "test", "bitmap", "size", "expected", "relative", "ok"
        );
        for report in &reports {
            println!(
                "{:30} {:12} {:6} {:10} {:>10} {:>10}",
                report.name,
                "Splinter",
                report.splinter.0,
                report.splinter.1,
                "1.00",
                if report.splinter.0 == report.splinter.1 {
                    "ok"
                } else {
                    fail_test = true;
                    "FAIL"
                }
            );

            let diff = report.roaring.0 as f64 / report.splinter.0 as f64;
            let ok_status = if report.roaring.0 != report.roaring.1 {
                fail_test = true;
                "FAIL".into()
            } else {
                ratio_to_marks(diff)
            };
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "", "Roaring", report.roaring.0, report.roaring.1, diff, ok_status
            );

            let diff = report.splinter_lz4 as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Splinter LZ4",
                report.splinter_lz4,
                report.splinter_lz4,
                diff,
                ratio_to_marks(diff)
            );

            let diff = report.roaring_lz4 as f64 / report.splinter_lz4 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Roaring LZ4",
                report.roaring_lz4,
                report.roaring_lz4,
                diff,
                ratio_to_marks(diff)
            );

            let diff = report.baseline as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Baseline",
                report.baseline,
                report.baseline,
                diff,
                ratio_to_marks(diff)
            );
        }

        let avg_ratio = reports
            .iter()
            .map(|r| r.splinter_lz4 as f64 / r.splinter.0 as f64)
            .sum::<f64>()
            / reports.len() as f64;

        println!("average compression ratio (splinter_lz4 / splinter): {avg_ratio:.2}");

        assert!(!fail_test, "compression test failed");
    }
}