1use serde::{Deserialize, Serialize};
41
/// Result of probing a learned index for a key.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LookupResult {
    /// The exact position is known (the key had a correction entry).
    Exact(usize),
    /// The key, if present, lies within positions `low..=high` (inclusive).
    Range { low: usize, high: usize },
    /// The key is outside the indexed key range, or the index is empty.
    NotFound,
}
52
/// A learned sparse index: a single linear model that maps keys to
/// approximate positions in a sorted array, plus an explicit correction
/// table for keys the model predicts poorly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedSparseIndex {
    // Linear model: predicted_position = slope * normalize(key) + intercept.
    slope: f64,
    intercept: f64,
    // Largest |actual - predicted| position error observed at build time;
    // used as the default search-window half-width in `lookup`.
    max_error: usize,
    // Sorted (key, position) pairs for keys whose prediction error exceeded
    // `correction_threshold`; consulted by binary search before the model.
    corrections: Vec<(u64, usize)>,
    // Smallest indexed key; lookups below it short-circuit to NotFound.
    min_key: u64,
    // Largest indexed key; lookups above it short-circuit to NotFound.
    max_key: u64,
    // (max_key - min_key) as f64; 0.0 for empty or single-key indexes.
    key_range: f64,
    // Number of keys the index was built over (valid positions are 0..num_keys).
    num_keys: usize,
    // Error threshold above which a key receives a correction entry.
    correction_threshold: usize,
}
78
79impl LearnedSparseIndex {
80 const DEFAULT_CORRECTION_THRESHOLD: usize = 64;
82
83 pub fn empty() -> Self {
85 Self {
86 slope: 0.0,
87 intercept: 0.0,
88 max_error: 0,
89 corrections: Vec::new(),
90 min_key: 0,
91 max_key: 0,
92 key_range: 0.0,
93 num_keys: 0,
94 correction_threshold: Self::DEFAULT_CORRECTION_THRESHOLD,
95 }
96 }
97
98 #[inline]
103 fn normalize_key(&self, key: u64) -> f64 {
104 if self.key_range == 0.0 {
105 return 0.0;
106 }
107 let offset = (key as u128).saturating_sub(self.min_key as u128) as f64;
109 (offset / self.key_range) * (self.num_keys - 1) as f64
110 }
111
112 pub fn build(keys: &[u64]) -> Self {
116 Self::build_with_threshold(keys, Self::DEFAULT_CORRECTION_THRESHOLD)
117 }
118
119 pub fn build_with_threshold(keys: &[u64], correction_threshold: usize) -> Self {
124 let n = keys.len();
125 if n == 0 {
126 return Self::empty();
127 }
128
129 if n == 1 {
130 return Self {
131 slope: 0.0,
132 intercept: 0.0,
133 max_error: 0,
134 corrections: Vec::new(),
135 min_key: keys[0],
136 max_key: keys[0],
137 key_range: 0.0,
138 num_keys: 1,
139 correction_threshold,
140 };
141 }
142
143 let min_key = keys[0];
144 let max_key = keys[n - 1];
145 let key_range = (max_key as u128 - min_key as u128) as f64;
147
148 let (slope, intercept) = Self::linear_regression_normalized(keys, min_key, key_range, n);
150
151 let mut max_error = 0usize;
153 let mut corrections = Vec::new();
154
155 for (actual_pos, &key) in keys.iter().enumerate() {
156 let normalized = if key_range == 0.0 {
158 0.0
159 } else {
160 let offset = (key as u128 - min_key as u128) as f64;
161 (offset / key_range) * (n - 1) as f64
162 };
163 let predicted = slope * normalized + intercept;
164 let predicted_pos = predicted.round() as isize;
165 let error = (actual_pos as isize - predicted_pos).unsigned_abs();
166
167 if error > max_error {
168 max_error = error;
169 }
170
171 if error > correction_threshold {
173 corrections.push((key, actual_pos));
174 }
175 }
176
177 Self {
178 slope,
179 intercept,
180 max_error,
181 corrections,
182 min_key,
183 max_key,
184 key_range,
185 num_keys: n,
186 correction_threshold,
187 }
188 }
189
190 pub fn lookup(&self, key: u64) -> LookupResult {
192 if self.num_keys == 0 {
193 return LookupResult::NotFound;
194 }
195
196 if key < self.min_key || key > self.max_key {
198 return LookupResult::NotFound;
199 }
200
201 if let Ok(idx) = self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
204 return LookupResult::Exact(self.corrections[idx].1);
205 }
206
207 let normalized = self.normalize_key(key);
209 let predicted = self.slope * normalized + self.intercept;
210 let predicted_pos = predicted.round() as isize;
211
212 let low = (predicted_pos - self.max_error as isize).max(0) as usize;
214 let high =
215 (predicted_pos + self.max_error as isize).min(self.num_keys as isize - 1) as usize;
216
217 LookupResult::Range { low, high }
218 }
219
220 pub fn lookup_with_error(&self, key: u64, max_error: usize) -> LookupResult {
222 if self.num_keys == 0 {
223 return LookupResult::NotFound;
224 }
225
226 if key < self.min_key || key > self.max_key {
227 return LookupResult::NotFound;
228 }
229
230 if let Ok(idx) = self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
231 return LookupResult::Exact(self.corrections[idx].1);
232 }
233
234 let normalized = self.normalize_key(key);
236 let predicted = self.slope * normalized + self.intercept;
237 let predicted_pos = predicted.round() as isize;
238
239 let low = (predicted_pos - max_error as isize).max(0) as usize;
240 let high = (predicted_pos + max_error as isize).min(self.num_keys as isize - 1) as usize;
241
242 LookupResult::Range { low, high }
243 }
244
245 pub fn stats(&self) -> LearnedIndexStats {
247 LearnedIndexStats {
248 num_keys: self.num_keys,
249 max_error: self.max_error,
250 num_corrections: self.corrections.len(),
251 slope: self.slope,
252 intercept: self.intercept,
253 correction_ratio: if self.num_keys > 0 {
254 self.corrections.len() as f64 / self.num_keys as f64
255 } else {
256 0.0
257 },
258 }
259 }
260
261 pub fn is_efficient(&self) -> bool {
264 let low_error = self.max_error <= 128;
268 let low_corrections =
269 self.num_keys == 0 || (self.corrections.len() as f64 / self.num_keys as f64) < 0.05;
270 low_error && low_corrections
271 }
272
273 pub fn memory_bytes(&self) -> usize {
275 std::mem::size_of::<Self>()
276 + self.corrections.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<usize>())
277 }
278
279 fn linear_regression_normalized(
284 keys: &[u64],
285 min_key: u64,
286 key_range: f64,
287 n: usize,
288 ) -> (f64, f64) {
289 let n_f64 = n as f64;
290
291 let mut sum_x: f64 = 0.0;
293 let mut sum_y: f64 = 0.0;
294 let mut sum_xy: f64 = 0.0;
295 let mut sum_xx: f64 = 0.0;
296
297 for (i, &key) in keys.iter().enumerate() {
298 let x = if key_range == 0.0 {
300 0.0
301 } else {
302 let offset = (key as u128 - min_key as u128) as f64;
303 (offset / key_range) * (n - 1) as f64
304 };
305 let y = i as f64;
306
307 sum_x += x;
308 sum_y += y;
309 sum_xy += x * y;
310 sum_xx += x * x;
311 }
312
313 let denominator = n_f64 * sum_xx - sum_x * sum_x;
314
315 if denominator.abs() < f64::EPSILON {
317 return (0.0, sum_y / n_f64);
318 }
319
320 let slope = (n_f64 * sum_xy - sum_x * sum_y) / denominator;
321 let intercept = (sum_y - slope * sum_x) / n_f64;
322
323 (slope, intercept)
324 }
325
326 #[allow(dead_code)]
328 fn linear_regression(keys: &[u64]) -> (f64, f64) {
329 let n = keys.len() as f64;
330
331 let mut sum_x: f64 = 0.0;
333 let mut sum_y: f64 = 0.0;
334 let mut sum_xy: f64 = 0.0;
335 let mut sum_xx: f64 = 0.0;
336
337 for (i, &key) in keys.iter().enumerate() {
338 let x = key as f64;
339 let y = i as f64;
340 sum_x += x;
341 sum_y += y;
342 sum_xy += x * y;
343 sum_xx += x * x;
344 }
345
346 let denominator = n * sum_xx - sum_x * sum_x;
347
348 if denominator.abs() < f64::EPSILON {
350 return (0.0, sum_y / n);
351 }
352
353 let slope = (n * sum_xy - sum_x * sum_y) / denominator;
354 let intercept = (sum_y - slope * sum_x) / n;
355
356 (slope, intercept)
357 }
358
359 pub fn insert(&mut self, key: u64, position: usize, keys: &[u64]) -> bool {
362 let normalized = self.normalize_key(key);
364 let predicted = self.slope * normalized + self.intercept;
365 let predicted_pos = predicted.round() as isize;
366 let error = (position as isize - predicted_pos).unsigned_abs();
367
368 self.min_key = self.min_key.min(key);
370 self.max_key = self.max_key.max(key);
371 self.key_range = (self.max_key as u128 - self.min_key as u128) as f64;
373 self.num_keys += 1;
374
375 if error > self.max_error {
376 self.max_error = error;
377 }
378
379 if error > self.correction_threshold {
380 match self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
382 Ok(idx) => self.corrections[idx] = (key, position),
383 Err(idx) => self.corrections.insert(idx, (key, position)),
384 }
385 }
386
387 if self.corrections.len() > self.num_keys / 10 {
389 *self = Self::build_with_threshold(keys, self.correction_threshold);
390 return true;
391 }
392
393 false
394 }
395}
396
/// Snapshot of a [`LearnedSparseIndex`]'s model quality and footprint.
#[derive(Debug, Clone)]
pub struct LearnedIndexStats {
    /// Number of keys the index covers.
    pub num_keys: usize,
    /// Worst observed prediction error, in positions.
    pub max_error: usize,
    /// Number of explicit correction entries.
    pub num_corrections: usize,
    /// Slope of the fitted linear model.
    pub slope: f64,
    /// Intercept of the fitted linear model.
    pub intercept: f64,
    /// `num_corrections / num_keys` (0.0 when the index is empty).
    pub correction_ratio: f64,
}
413
/// A piecewise learned index: the key space is split into contiguous
/// segments, each with its own independently-fitted [`LearnedSparseIndex`].
#[derive(Debug, Clone)]
pub struct PiecewiseLearnedIndex {
    // First key of each segment, sorted ascending; binary-searched to route
    // a lookup to the covering segment.
    boundaries: Vec<u64>,
    // One sub-index per segment, parallel to `boundaries`.
    segments: Vec<LearnedSparseIndex>,
}
424
425impl PiecewiseLearnedIndex {
426 pub fn build(keys: &[u64], max_segments: usize) -> Self {
428 if keys.is_empty() || max_segments == 0 {
429 return Self {
430 boundaries: vec![],
431 segments: vec![],
432 };
433 }
434
435 let segment_size = keys.len().div_ceil(max_segments);
437 let mut boundaries = Vec::with_capacity(max_segments);
438 let mut segments = Vec::with_capacity(max_segments);
439
440 for chunk in keys.chunks(segment_size) {
441 if !chunk.is_empty() {
442 boundaries.push(chunk[0]);
443 segments.push(LearnedSparseIndex::build(chunk));
444 }
445 }
446
447 Self {
448 boundaries,
449 segments,
450 }
451 }
452
453 fn find_segment(&self, key: u64) -> Option<usize> {
455 if self.boundaries.is_empty() {
456 return None;
457 }
458
459 match self.boundaries.binary_search(&key) {
461 Ok(i) => Some(i),
462 Err(i) => {
463 if i == 0 {
464 None
465 } else {
466 Some(i - 1)
467 }
468 }
469 }
470 }
471
472 pub fn lookup(&self, key: u64) -> LookupResult {
474 match self.find_segment(key) {
475 Some(seg_idx) => self.segments[seg_idx].lookup(key),
476 None => LookupResult::NotFound,
477 }
478 }
479
480 pub fn stats(&self) -> PiecewiseStats {
482 let segment_stats: Vec<_> = self.segments.iter().map(|s| s.stats()).collect();
483 let total_keys: usize = segment_stats.iter().map(|s| s.num_keys).sum();
484 let max_error = segment_stats.iter().map(|s| s.max_error).max().unwrap_or(0);
485 let total_corrections: usize = segment_stats.iter().map(|s| s.num_corrections).sum();
486
487 PiecewiseStats {
488 num_segments: self.segments.len(),
489 total_keys,
490 max_error,
491 total_corrections,
492 avg_segment_size: if self.segments.is_empty() {
493 0.0
494 } else {
495 total_keys as f64 / self.segments.len() as f64
496 },
497 }
498 }
499}
500
/// Aggregated statistics for a [`PiecewiseLearnedIndex`].
#[derive(Debug, Clone)]
pub struct PiecewiseStats {
    /// Number of segments (sub-indexes).
    pub num_segments: usize,
    /// Total keys across all segments.
    pub total_keys: usize,
    /// Worst per-segment prediction error, in positions.
    pub max_error: usize,
    /// Total correction entries across all segments.
    pub total_corrections: usize,
    /// Mean number of keys per segment (0.0 with no segments).
    pub avg_segment_size: f64,
}
510
#[cfg(test)]
mod tests {
    use super::*;

    // An index built over zero keys finds nothing and reports zero keys.
    #[test]
    fn test_empty_index() {
        let idx = LearnedSparseIndex::build(&[]);
        assert_eq!(idx.lookup(42), LookupResult::NotFound);
        assert_eq!(idx.stats().num_keys, 0);
    }

    // A single key yields the degenerate range [0, 0]; anything else is out
    // of bounds.
    #[test]
    fn test_single_key() {
        let idx = LearnedSparseIndex::build(&[100]);
        assert!(matches!(
            idx.lookup(100),
            LookupResult::Range { low: 0, high: 0 }
        ));
        assert_eq!(idx.lookup(50), LookupResult::NotFound);
        assert_eq!(idx.lookup(150), LookupResult::NotFound);
    }

    // Perfectly linear data: the model should be nearly exact with no
    // corrections, and lookup windows should be tight.
    #[test]
    fn test_sequential_keys() {
        let keys: Vec<u64> = (0..1000).collect();
        let idx = LearnedSparseIndex::build(&keys);

        let stats = idx.stats();
        assert!(
            stats.max_error <= 1,
            "Sequential keys should have near-zero error"
        );
        assert!(
            stats.num_corrections == 0,
            "No corrections needed for linear data"
        );

        if let LookupResult::Range { low, high } = idx.lookup(500) {
            assert!(low <= 500 && high >= 500, "Key 500 should be in range");
            assert!(high - low <= 2, "Range should be very tight");
        }
    }

    // Monotonic timestamps with a mildly key-dependent stride should still
    // be close to linear and hence efficiently indexable.
    #[test]
    fn test_timestamp_like_keys() {
        let mut keys: Vec<u64> = Vec::new();
        let mut ts: u64 = 1704067200;
        for _ in 0..10000 {
            keys.push(ts);
            ts += 1 + (ts % 10);
        }

        let idx = LearnedSparseIndex::build(&keys);

        assert!(
            idx.is_efficient(),
            "Timestamp data should be efficiently indexable"
        );

        for &key in keys.iter().take(100) {
            let result = idx.lookup(key);
            assert!(
                !matches!(result, LookupResult::NotFound),
                "Existing key should be found"
            );
        }
    }

    // Exponentially spaced keys: every key must still resolve to a result
    // containing its true position (exactly or within the range).
    #[test]
    fn test_sparse_keys() {
        let keys: Vec<u64> = vec![1, 100, 10000, 1000000, 100000000];
        let idx = LearnedSparseIndex::build(&keys);

        for (pos, &key) in keys.iter().enumerate() {
            match idx.lookup(key) {
                LookupResult::NotFound => panic!("Key {} should be found", key),
                LookupResult::Exact(p) => assert_eq!(p, pos),
                LookupResult::Range { low, high } => {
                    assert!(
                        low <= pos && pos <= high,
                        "Key {} should be in range [{}, {}]",
                        key,
                        low,
                        high
                    );
                }
            }
        }
    }

    // Keys below min_key or above max_key short-circuit to NotFound.
    #[test]
    fn test_out_of_bounds() {
        let keys: Vec<u64> = (100..200).collect();
        let idx = LearnedSparseIndex::build(&keys);

        assert_eq!(idx.lookup(50), LookupResult::NotFound);
        assert_eq!(idx.lookup(250), LookupResult::NotFound);
    }

    // Three well-separated clusters split across three segments; a probe in
    // each cluster must land in some segment.
    #[test]
    fn test_piecewise_index() {
        let mut keys: Vec<u64> = Vec::new();
        keys.extend(0..1000);
        keys.extend((100000..101000).step_by(10));
        keys.extend(1000000..1001000);

        let piecewise = PiecewiseLearnedIndex::build(&keys, 3);
        let stats = piecewise.stats();

        assert_eq!(stats.num_segments, 3);

        assert!(!matches!(piecewise.lookup(500), LookupResult::NotFound));
        assert!(!matches!(piecewise.lookup(100500), LookupResult::NotFound));
        assert!(!matches!(piecewise.lookup(1000500), LookupResult::NotFound));
    }

    // The learned index should be far smaller than storing the keys.
    #[test]
    fn test_memory_efficiency() {
        let keys: Vec<u64> = (0..100000).collect();
        let idx = LearnedSparseIndex::build(&keys);

        let lsi_bytes = idx.memory_bytes();
        let btree_bytes = keys.len() * std::mem::size_of::<u64>();
        assert!(
            lsi_bytes < btree_bytes,
            "LSI ({} bytes) should use less memory than keys alone ({} bytes)",
            lsi_bytes,
            btree_bytes
        );
    }

    // A tighter threshold can only produce at least as many corrections as
    // a looser one on the same data.
    #[test]
    fn test_correction_threshold() {
        let mut keys: Vec<u64> = (0..100).map(|x| x * 10).collect();
        keys.push(5000);
        keys.sort();

        let low_thresh = LearnedSparseIndex::build_with_threshold(&keys, 10);
        let high_thresh = LearnedSparseIndex::build_with_threshold(&keys, 1000);

        assert!(
            low_thresh.stats().num_corrections >= high_thresh.stats().num_corrections,
            "Lower threshold should produce more or equal corrections"
        );
    }

    // Keys clustered just below u64::MAX exercise the u128 normalization
    // path; linear data there must still index accurately.
    #[test]
    fn test_large_key_normalization() {
        let base = u64::MAX - 1000;
        let keys: Vec<u64> = (0..100).map(|i| base + i * 10).collect();

        let idx = LearnedSparseIndex::build(&keys);

        assert!(
            idx.max_error < 10,
            "Error should be small for linear data"
        );

        for (pos, &key) in keys.iter().enumerate() {
            match idx.lookup(key) {
                LookupResult::NotFound => {
                    panic!("Key {} should be found", key);
                }
                LookupResult::Exact(p) => {
                    assert_eq!(p, pos, "Exact position should match");
                }
                LookupResult::Range { low, high } => {
                    assert!(
                        low <= pos && pos <= high,
                        "Key {} at position {} should be in range [{}, {}]",
                        key,
                        pos,
                        low,
                        high
                    );
                }
            }
        }
    }

    // Keys spanning the entire u64 domain, including values above 2^53 where
    // a plain f64 cast would lose precision.
    #[test]
    fn test_full_range_keys() {
        let keys: Vec<u64> = vec![
            0,
            1_000_000,
            1_000_000_000,
            1_000_000_000_000,
            1_000_000_000_000_000,
            u64::MAX / 2,
            u64::MAX - 1000,
            u64::MAX - 100,
            u64::MAX - 10,
            u64::MAX - 1,
        ];

        let idx = LearnedSparseIndex::build(&keys);

        for (pos, &key) in keys.iter().enumerate() {
            match idx.lookup(key) {
                LookupResult::NotFound => {
                    panic!("Key {} should be found", key);
                }
                LookupResult::Exact(p) => {
                    assert_eq!(p, pos, "Exact position should match");
                }
                LookupResult::Range { low, high } => {
                    assert!(
                        low <= pos && pos <= high,
                        "Key {} at position {} should be in range [{}, {}]",
                        key,
                        pos,
                        low,
                        high
                    );
                }
            }
        }
    }

    // Microsecond-scale timestamps with a fixed stride are exactly linear.
    #[test]
    fn test_timestamp_keys() {
        let base_ts: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = (0..1000).map(|i| base_ts + i * 1000).collect();

        let idx = LearnedSparseIndex::build(&keys);

        assert!(
            idx.max_error <= 1,
            "Error for sequential timestamps should be ≤ 1, got {}",
            idx.max_error
        );
        assert!(
            idx.is_efficient(),
            "Sequential timestamp data should be efficient"
        );
    }

    // Hand-built index over keys 0..=99: normalization should be (near-)exact
    // at both ends and close in the middle.
    #[test]
    fn test_normalization_precision() {
        let idx = LearnedSparseIndex {
            slope: 1.0,
            intercept: 0.0,
            max_error: 0,
            corrections: Vec::new(),
            min_key: 0,
            max_key: 99,
            key_range: 99.0,
            num_keys: 100,
            correction_threshold: 64,
        };

        assert!((idx.normalize_key(0) - 0.0).abs() < f64::EPSILON);
        assert!((idx.normalize_key(99) - 99.0).abs() < f64::EPSILON);
        assert!((idx.normalize_key(49) - 49.0).abs() < 0.5);
    }
}