1use serde::{Deserialize, Serialize};
44
/// Outcome of a key lookup against a learned index.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LookupResult {
    /// The key's exact position is known (served from the correction table).
    Exact(usize),
    /// The key, if present, lies somewhere in positions `low..=high`.
    Range { low: usize, high: usize },
    /// The key is outside the indexed key range, or the index is empty.
    NotFound,
}
55
/// A learned sparse index: a single linear model (`slope`/`intercept`) fit
/// over normalized keys, plus an exact-position correction table for keys the
/// model predicts poorly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedSparseIndex {
    // Linear model: predicted position = slope * normalized_key + intercept.
    slope: f64,
    intercept: f64,
    // Largest |actual - predicted| observed; bounds the search range on lookup.
    max_error: usize,
    // (key, exact position) pairs, sorted by key, for badly-predicted keys.
    corrections: Vec<(u64, usize)>,
    // Smallest and largest indexed key; lookups outside are NotFound.
    min_key: u64,
    max_key: u64,
    // (max_key - min_key) as f64; 0.0 when empty or all keys are equal.
    key_range: f64,
    num_keys: usize,
    // Prediction error above which a key is added to `corrections`.
    correction_threshold: usize,
}
81
impl LearnedSparseIndex {
    /// Default prediction-error threshold (in positions) above which a key
    /// receives an exact correction entry during `build`.
    const DEFAULT_CORRECTION_THRESHOLD: usize = 64;

86 pub fn empty() -> Self {
88 Self {
89 slope: 0.0,
90 intercept: 0.0,
91 max_error: 0,
92 corrections: Vec::new(),
93 min_key: 0,
94 max_key: 0,
95 key_range: 0.0,
96 num_keys: 0,
97 correction_threshold: Self::DEFAULT_CORRECTION_THRESHOLD,
98 }
99 }
100
101 #[inline]
106 fn normalize_key(&self, key: u64) -> f64 {
107 if self.key_range == 0.0 {
108 return 0.0;
109 }
110 let offset = (key as u128).saturating_sub(self.min_key as u128) as f64;
112 (offset / self.key_range) * (self.num_keys - 1) as f64
113 }
114
    /// Build an index over `keys` (must be sorted ascending) with the
    /// default correction threshold.
    pub fn build(keys: &[u64]) -> Self {
        Self::build_with_threshold(keys, Self::DEFAULT_CORRECTION_THRESHOLD)
    }
121
122 pub fn build_with_threshold(keys: &[u64], correction_threshold: usize) -> Self {
127 let n = keys.len();
128 if n == 0 {
129 return Self::empty();
130 }
131
132 if n == 1 {
133 return Self {
134 slope: 0.0,
135 intercept: 0.0,
136 max_error: 0,
137 corrections: Vec::new(),
138 min_key: keys[0],
139 max_key: keys[0],
140 key_range: 0.0,
141 num_keys: 1,
142 correction_threshold,
143 };
144 }
145
146 let min_key = keys[0];
147 let max_key = keys[n - 1];
148 let key_range = (max_key as u128 - min_key as u128) as f64;
150
151 let (slope, intercept) = Self::linear_regression_normalized(keys, min_key, key_range, n);
153
154 let mut max_error = 0usize;
156 let mut corrections = Vec::new();
157
158 for (actual_pos, &key) in keys.iter().enumerate() {
159 let normalized = if key_range == 0.0 {
161 0.0
162 } else {
163 let offset = (key as u128 - min_key as u128) as f64;
164 (offset / key_range) * (n - 1) as f64
165 };
166 let predicted = slope * normalized + intercept;
167 let predicted_pos = predicted.round() as isize;
168 let error = (actual_pos as isize - predicted_pos).unsigned_abs();
169
170 if error > max_error {
171 max_error = error;
172 }
173
174 if error > correction_threshold {
176 corrections.push((key, actual_pos));
177 }
178 }
179
180 Self {
181 slope,
182 intercept,
183 max_error,
184 corrections,
185 min_key,
186 max_key,
187 key_range,
188 num_keys: n,
189 correction_threshold,
190 }
191 }
192
193 pub fn lookup(&self, key: u64) -> LookupResult {
195 if self.num_keys == 0 {
196 return LookupResult::NotFound;
197 }
198
199 if key < self.min_key || key > self.max_key {
201 return LookupResult::NotFound;
202 }
203
204 if let Ok(idx) = self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
207 return LookupResult::Exact(self.corrections[idx].1);
208 }
209
210 let normalized = self.normalize_key(key);
212 let predicted = self.slope * normalized + self.intercept;
213 let predicted_pos = predicted.round() as isize;
214
215 let low = (predicted_pos - self.max_error as isize).max(0) as usize;
217 let high =
218 (predicted_pos + self.max_error as isize).min(self.num_keys as isize - 1) as usize;
219
220 LookupResult::Range { low, high }
221 }
222
223 pub fn lookup_with_error(&self, key: u64, max_error: usize) -> LookupResult {
225 if self.num_keys == 0 {
226 return LookupResult::NotFound;
227 }
228
229 if key < self.min_key || key > self.max_key {
230 return LookupResult::NotFound;
231 }
232
233 if let Ok(idx) = self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
234 return LookupResult::Exact(self.corrections[idx].1);
235 }
236
237 let normalized = self.normalize_key(key);
239 let predicted = self.slope * normalized + self.intercept;
240 let predicted_pos = predicted.round() as isize;
241
242 let low = (predicted_pos - max_error as isize).max(0) as usize;
243 let high = (predicted_pos + max_error as isize).min(self.num_keys as isize - 1) as usize;
244
245 LookupResult::Range { low, high }
246 }
247
248 pub fn stats(&self) -> LearnedIndexStats {
250 LearnedIndexStats {
251 num_keys: self.num_keys,
252 max_error: self.max_error,
253 num_corrections: self.corrections.len(),
254 slope: self.slope,
255 intercept: self.intercept,
256 correction_ratio: if self.num_keys > 0 {
257 self.corrections.len() as f64 / self.num_keys as f64
258 } else {
259 0.0
260 },
261 }
262 }
263
264 pub fn is_efficient(&self) -> bool {
267 let low_error = self.max_error <= 128;
271 let low_corrections =
272 self.num_keys == 0 || (self.corrections.len() as f64 / self.num_keys as f64) < 0.05;
273 low_error && low_corrections
274 }
275
276 pub fn memory_bytes(&self) -> usize {
278 std::mem::size_of::<Self>()
279 + self.corrections.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<usize>())
280 }
281
    /// Least-squares fit of sorted position (`y = 0..n-1`) against the
    /// normalized key (`x`), using the same normalization as `normalize_key`
    /// so lookup-time predictions match the fitted model.
    ///
    /// Returns `(slope, intercept)`; falls back to a horizontal line through
    /// the mean position when the x spread is numerically zero.
    fn linear_regression_normalized(
        keys: &[u64],
        min_key: u64,
        key_range: f64,
        n: usize,
    ) -> (f64, f64) {
        let n_f64 = n as f64;

        // Running sums for the closed-form least-squares solution.
        let mut sum_x: f64 = 0.0;
        let mut sum_y: f64 = 0.0;
        let mut sum_xy: f64 = 0.0;
        let mut sum_xx: f64 = 0.0;

        for (i, &key) in keys.iter().enumerate() {
            // Normalize exactly as `normalize_key` does; u128 avoids overflow
            // for keys near u64::MAX.
            let x = if key_range == 0.0 {
                0.0
            } else {
                let offset = (key as u128 - min_key as u128) as f64;
                (offset / key_range) * (n - 1) as f64
            };
            let y = i as f64;

            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }

        let denominator = n_f64 * sum_xx - sum_x * sum_x;

        // Degenerate x spread (all keys normalize equal): horizontal line
        // through the mean position.
        if denominator.abs() < f64::EPSILON {
            return (0.0, sum_y / n_f64);
        }

        let slope = (n_f64 * sum_xy - sum_x * sum_y) / denominator;
        let intercept = (sum_y - slope * sum_x) / n_f64;

        (slope, intercept)
    }
328
329 #[allow(dead_code)]
331 fn linear_regression(keys: &[u64]) -> (f64, f64) {
332 let n = keys.len() as f64;
333
334 let mut sum_x: f64 = 0.0;
336 let mut sum_y: f64 = 0.0;
337 let mut sum_xy: f64 = 0.0;
338 let mut sum_xx: f64 = 0.0;
339
340 for (i, &key) in keys.iter().enumerate() {
341 let x = key as f64;
342 let y = i as f64;
343 sum_x += x;
344 sum_y += y;
345 sum_xy += x * y;
346 sum_xx += x * x;
347 }
348
349 let denominator = n * sum_xx - sum_x * sum_x;
350
351 if denominator.abs() < f64::EPSILON {
353 return (0.0, sum_y / n);
354 }
355
356 let slope = (n * sum_xy - sum_x * sum_y) / denominator;
357 let intercept = (sum_y - slope * sum_x) / n;
358
359 (slope, intercept)
360 }
361
362 pub fn insert(&mut self, key: u64, position: usize, keys: &[u64]) -> bool {
365 let normalized = self.normalize_key(key);
367 let predicted = self.slope * normalized + self.intercept;
368 let predicted_pos = predicted.round() as isize;
369 let error = (position as isize - predicted_pos).unsigned_abs();
370
371 self.min_key = self.min_key.min(key);
373 self.max_key = self.max_key.max(key);
374 self.key_range = (self.max_key as u128 - self.min_key as u128) as f64;
376 self.num_keys += 1;
377
378 if error > self.max_error {
379 self.max_error = error;
380 }
381
382 if error > self.correction_threshold {
383 match self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
385 Ok(idx) => self.corrections[idx] = (key, position),
386 Err(idx) => self.corrections.insert(idx, (key, position)),
387 }
388 }
389
390 if self.corrections.len() > self.num_keys / 10 {
392 *self = Self::build_with_threshold(keys, self.correction_threshold);
393 return true;
394 }
395
396 false
397 }
398}
399
/// Read-only snapshot of a `LearnedSparseIndex`'s quality metrics.
#[derive(Debug, Clone)]
pub struct LearnedIndexStats {
    pub num_keys: usize,
    /// Worst observed |actual - predicted| position at build time.
    pub max_error: usize,
    /// Number of exact-position correction entries.
    pub num_corrections: usize,
    pub slope: f64,
    pub intercept: f64,
    /// num_corrections / num_keys (0.0 when the index is empty).
    pub correction_ratio: f64,
}
416
/// Independent learned indexes over contiguous key ranges, for key
/// distributions that a single linear model fits poorly.
#[derive(Debug, Clone)]
pub struct PiecewiseLearnedIndex {
    // First key of each segment, sorted ascending; parallel to `segments`.
    boundaries: Vec<u64>,
    segments: Vec<LearnedSparseIndex>,
}
427
impl PiecewiseLearnedIndex {
    /// Split sorted `keys` into at most `max_segments` contiguous chunks and
    /// fit an independent learned index per chunk. Returns an empty index
    /// when `keys` is empty or `max_segments` is 0.
    pub fn build(keys: &[u64], max_segments: usize) -> Self {
        if keys.is_empty() || max_segments == 0 {
            return Self {
                boundaries: vec![],
                segments: vec![],
            };
        }

        // div_ceil so every key lands in a chunk without exceeding
        // max_segments chunks. `chunks` never yields an empty slice, so the
        // per-chunk emptiness check the original carried was dead code and
        // `chunk[0]` cannot panic.
        let segment_size = keys.len().div_ceil(max_segments);
        let (boundaries, segments) = keys
            .chunks(segment_size)
            .map(|chunk| (chunk[0], LearnedSparseIndex::build(chunk)))
            .unzip();

        Self {
            boundaries,
            segments,
        }
    }
455
456 fn find_segment(&self, key: u64) -> Option<usize> {
458 if self.boundaries.is_empty() {
459 return None;
460 }
461
462 match self.boundaries.binary_search(&key) {
464 Ok(i) => Some(i),
465 Err(i) => {
466 if i == 0 {
467 None
468 } else {
469 Some(i - 1)
470 }
471 }
472 }
473 }
474
475 pub fn lookup(&self, key: u64) -> LookupResult {
477 match self.find_segment(key) {
478 Some(seg_idx) => self.segments[seg_idx].lookup(key),
479 None => LookupResult::NotFound,
480 }
481 }
482
483 pub fn stats(&self) -> PiecewiseStats {
485 let segment_stats: Vec<_> = self.segments.iter().map(|s| s.stats()).collect();
486 let total_keys: usize = segment_stats.iter().map(|s| s.num_keys).sum();
487 let max_error = segment_stats.iter().map(|s| s.max_error).max().unwrap_or(0);
488 let total_corrections: usize = segment_stats.iter().map(|s| s.num_corrections).sum();
489
490 PiecewiseStats {
491 num_segments: self.segments.len(),
492 total_keys,
493 max_error,
494 total_corrections,
495 avg_segment_size: if self.segments.is_empty() {
496 0.0
497 } else {
498 total_keys as f64 / self.segments.len() as f64
499 },
500 }
501 }
502}
503
/// Aggregated statistics over all segments of a `PiecewiseLearnedIndex`.
#[derive(Debug, Clone)]
pub struct PiecewiseStats {
    pub num_segments: usize,
    pub total_keys: usize,
    /// Largest per-segment max_error.
    pub max_error: usize,
    pub total_corrections: usize,
    pub avg_segment_size: f64,
}
513
#[cfg(test)]
mod tests {
    use super::*;

    // An empty index must report NotFound for any key.
    #[test]
    fn test_empty_index() {
        let index = LearnedSparseIndex::build(&[]);
        assert_eq!(index.lookup(42), LookupResult::NotFound);
        assert_eq!(index.stats().num_keys, 0);
    }

    // A single key yields the degenerate Range {0, 0}; others miss.
    #[test]
    fn test_single_key() {
        let index = LearnedSparseIndex::build(&[100]);
        assert!(matches!(
            index.lookup(100),
            LookupResult::Range { low: 0, high: 0 }
        ));
        assert_eq!(index.lookup(50), LookupResult::NotFound);
        assert_eq!(index.lookup(150), LookupResult::NotFound);
    }

    // Perfectly linear data: model near-exact, no corrections expected.
    #[test]
    fn test_sequential_keys() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = LearnedSparseIndex::build(&keys);

        let stats = index.stats();
        assert!(
            stats.max_error <= 1,
            "Sequential keys should have near-zero error"
        );
        assert!(
            stats.num_corrections == 0,
            "No corrections needed for linear data"
        );

        if let LookupResult::Range { low, high } = index.lookup(500) {
            assert!(low <= 500 && high >= 500, "Key 500 should be in range");
            assert!(high - low <= 2, "Range should be very tight");
        }
    }

    // Monotonic timestamp-like keys with a mildly irregular stride.
    #[test]
    fn test_timestamp_like_keys() {
        let mut keys: Vec<u64> = Vec::new();
        let mut ts: u64 = 1704067200;
        for _ in 0..10000 {
            keys.push(ts);
            // Variable (but roughly constant) stride.
            ts += 1 + (ts % 10);
        }

        let index = LearnedSparseIndex::build(&keys);

        assert!(
            index.is_efficient(),
            "Timestamp data should be efficiently indexable"
        );

        for &key in keys.iter().take(100) {
            let result = index.lookup(key);
            assert!(
                !matches!(result, LookupResult::NotFound),
                "Existing key should be found"
            );
        }
    }

    // Exponentially spread keys: every key must still land in its range.
    #[test]
    fn test_sparse_keys() {
        let keys: Vec<u64> = vec![1, 100, 10000, 1000000, 100000000];
        let index = LearnedSparseIndex::build(&keys);

        for (i, &key) in keys.iter().enumerate() {
            match index.lookup(key) {
                LookupResult::Exact(pos) => assert_eq!(pos, i),
                LookupResult::Range { low, high } => {
                    assert!(
                        low <= i && i <= high,
                        "Key {} should be in range [{}, {}]",
                        key,
                        low,
                        high
                    );
                }
                LookupResult::NotFound => panic!("Key {} should be found", key),
            }
        }
    }

    // Keys outside [min_key, max_key] must miss.
    #[test]
    fn test_out_of_bounds() {
        let keys: Vec<u64> = (100..200).collect();
        let index = LearnedSparseIndex::build(&keys);

        assert_eq!(index.lookup(50), LookupResult::NotFound);
        assert_eq!(index.lookup(250), LookupResult::NotFound);
    }

    // Three widely separated clusters, expected to map to one segment each.
    #[test]
    fn test_piecewise_index() {
        let mut keys: Vec<u64> = Vec::new();
        keys.extend(0..1000);
        keys.extend((100000..101000).step_by(10));
        keys.extend(1000000..1001000);

        let piecewise = PiecewiseLearnedIndex::build(&keys, 3);
        let stats = piecewise.stats();

        assert_eq!(stats.num_segments, 3);

        // One probe inside each cluster.
        assert!(!matches!(piecewise.lookup(500), LookupResult::NotFound));
        assert!(!matches!(piecewise.lookup(100500), LookupResult::NotFound));
        assert!(!matches!(piecewise.lookup(1000500), LookupResult::NotFound));
    }

    // The learned index should be far smaller than storing all keys.
    #[test]
    fn test_memory_efficiency() {
        let keys: Vec<u64> = (0..100000).collect();
        let index = LearnedSparseIndex::build(&keys);

        let lsi_bytes = index.memory_bytes();
        let btree_bytes = keys.len() * std::mem::size_of::<u64>();
        assert!(
            lsi_bytes < btree_bytes,
            "LSI ({} bytes) should use less memory than keys alone ({} bytes)",
            lsi_bytes,
            btree_bytes
        );
    }

    // A lower threshold must never produce fewer corrections.
    #[test]
    fn test_correction_threshold() {
        let mut keys: Vec<u64> = (0..100).map(|x| x * 10).collect();
        // One outlier to give the model something to mispredict.
        keys.push(5000);
        keys.sort();

        let low_thresh = LearnedSparseIndex::build_with_threshold(&keys, 10);

        let high_thresh = LearnedSparseIndex::build_with_threshold(&keys, 1000);

        assert!(
            low_thresh.stats().num_corrections >= high_thresh.stats().num_corrections,
            "Lower threshold should produce more or equal corrections"
        );
    }

    // Keys near u64::MAX exercise the u128 normalization path.
    #[test]
    fn test_large_key_normalization() {
        let base = u64::MAX - 1000;
        let keys: Vec<u64> = (0..100).map(|i| base + i * 10).collect();

        let index = LearnedSparseIndex::build(&keys);

        assert!(
            index.max_error < 10,
            "Error should be small for linear data"
        );

        for (i, &key) in keys.iter().enumerate() {
            let result = index.lookup(key);
            match result {
                LookupResult::Range { low, high } => {
                    assert!(
                        low <= i && i <= high,
                        "Key {} at position {} should be in range [{}, {}]",
                        key,
                        i,
                        low,
                        high
                    );
                }
                LookupResult::Exact(pos) => {
                    assert_eq!(pos, i, "Exact position should match");
                }
                LookupResult::NotFound => {
                    panic!("Key {} should be found", key);
                }
            }
        }
    }

    // Keys spanning nearly the full u64 domain.
    #[test]
    fn test_full_range_keys() {
        let keys: Vec<u64> = vec![
            0,
            1_000_000,
            1_000_000_000,
            1_000_000_000_000,
            1_000_000_000_000_000,
            u64::MAX / 2,
            u64::MAX - 1000,
            u64::MAX - 100,
            u64::MAX - 10,
            u64::MAX - 1,
        ];

        let index = LearnedSparseIndex::build(&keys);

        for (i, &key) in keys.iter().enumerate() {
            let result = index.lookup(key);
            match result {
                LookupResult::Range { low, high } => {
                    assert!(
                        low <= i && i <= high,
                        "Key {} at position {} should be in range [{}, {}]",
                        key,
                        i,
                        low,
                        high
                    );
                }
                LookupResult::Exact(pos) => {
                    assert_eq!(pos, i, "Exact position should match");
                }
                LookupResult::NotFound => {
                    panic!("Key {} should be found", key);
                }
            }
        }
    }

    // Microsecond-scale timestamps at a fixed stride.
    #[test]
    fn test_timestamp_keys() {
        let base_ts: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = (0..1000).map(|i| base_ts + i * 1000).collect();

        let index = LearnedSparseIndex::build(&keys);

        assert!(
            index.max_error <= 1,
            "Error for sequential timestamps should be ≤ 1, got {}",
            index.max_error
        );

        assert!(
            index.is_efficient(),
            "Sequential timestamp data should be efficient"
        );
    }

    // normalize_key on a hand-built index: endpoints map exactly, midpoint
    // maps close to its position.
    #[test]
    fn test_normalization_precision() {
        let index = LearnedSparseIndex {
            slope: 1.0,
            intercept: 0.0,
            max_error: 0,
            corrections: Vec::new(),
            min_key: 0,
            max_key: 99,
            key_range: 99.0,
            num_keys: 100,
            correction_threshold: 64,
        };

        assert!((index.normalize_key(0) - 0.0).abs() < f64::EPSILON);

        assert!((index.normalize_key(99) - 99.0).abs() < f64::EPSILON);

        assert!((index.normalize_key(49) - 49.0).abs() < 0.5);
    }
}