sochdb_storage/
learned_index_integration.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Learned Index Integration (Task 5)
16//!
17//! Integrates LearnedSparseIndex for O(1) expected point lookups:
18//! - Accelerates row lookups by key in LSCS
19//! - Falls back to binary search for outliers
20//! - Provides adaptive index selection based on data distribution
21//!
22//! ## Lookup Flow
23//!
24//! ```text
25//! get(key)
26//!   │
27//!   ▼
28//! ┌─────────────────────┐
29//! │ LearnedIndex.lookup │
30//! └──────────┬──────────┘
31//!            │
32//!      ┌─────┴─────┐
33//!      ▼           ▼
34//!   Exact       Range[lo,hi]
35//!     │             │
36//!     ▼             ▼
37//!   O(1)      BinarySearch
38//!   fetch       O(log ε)
39//! ```
40//!
41//! ## Index Selection
42//!
43//! The system automatically chooses between:
44//! - **LearnedIndex**: For sequential/timestamp keys (O(1))
45//! - **B-Tree**: For random/UUID keys (O(log N))
//! - **Hash**: For exact-match only (O(1)); not currently returned by
//!   `KeyStats::recommend_index_type`
47
48use std::collections::{BTreeMap, HashSet};
49use std::sync::Arc;
50use sochdb_core::learned_index::{LearnedSparseIndex, LookupResult};
51
/// Index type based on key characteristics
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// Learned index for sequential/monotonic keys
    Learned,
    /// B-Tree for random keys
    BTree,
    /// Hash for exact match only
    Hash,
    /// No index (scan)
    None,
}

/// Key distribution statistics for index selection
#[derive(Debug, Clone, Default)]
pub struct KeyStats {
    /// Number of keys analyzed
    pub count: usize,
    /// Minimum key value
    pub min_key: u64,
    /// Maximum key value
    pub max_key: u64,
    /// Keys are non-decreasing in the order given (duplicates allowed)
    pub is_monotonic: bool,
    /// Density: count / (max - min + 1)
    pub density: f64,
    /// Randomness proxy: coefficient of variation of consecutive key gaps,
    /// clamped to at most 1.0
    pub entropy: f64,
}

impl KeyStats {
    /// Analyze keys to determine statistics.
    ///
    /// `keys` may be in any order; whether they are sorted is reported through
    /// `is_monotonic`. An empty slice yields `KeyStats::default()`.
    pub fn analyze(keys: &[u64]) -> Self {
        let (Some(&min_key), Some(&max_key)) = (keys.iter().min(), keys.iter().max()) else {
            return Self::default();
        };
        let count = keys.len();

        let is_monotonic = keys.windows(2).all(|w| w[0] <= w[1]);

        // Add the +1 in f64: `max - min + 1` overflows u64 when the keys span the
        // whole domain (e.g. 0 and u64::MAX).
        let range = (max_key - min_key) as f64 + 1.0;
        let density = count as f64 / range;

        let entropy = Self::estimate_entropy(keys);

        Self {
            count,
            min_key,
            max_key,
            is_monotonic,
            density,
            entropy,
        }
    }

    /// Estimate "entropy" as the coefficient of variation of consecutive gaps.
    ///
    /// Returns 0.0 for fewer than two keys or when all gaps are zero; otherwise
    /// `stddev(gaps) / mean(gaps)`, capped at 1.0.
    fn estimate_entropy(keys: &[u64]) -> f64 {
        if keys.len() < 2 {
            return 0.0;
        }

        // abs_diff: `analyze` accepts unsorted input, so `w[1] - w[0]` would
        // underflow (panic in debug builds) on any descending pair.
        let gaps: Vec<f64> = keys
            .windows(2)
            .map(|w| w[1].abs_diff(w[0]) as f64)
            .collect();
        let gap_count = gaps.len() as f64;

        // Sum in f64: absolute gaps of unsorted keys can overflow a u64 sum.
        let mean_gap = gaps.iter().sum::<f64>() / gap_count;
        if mean_gap == 0.0 {
            return 0.0;
        }

        let variance = gaps
            .iter()
            .map(|&g| {
                let diff = g - mean_gap;
                diff * diff
            })
            .sum::<f64>()
            / gap_count;

        (variance.sqrt() / mean_gap).min(1.0)
    }

    /// Recommend index type based on statistics.
    ///
    /// - no keys → `None`
    /// - unsorted keys → `BTree` (learned models and binary search need sorted keys)
    /// - dense (> 0.5) or regular (entropy < 0.3) sorted keys → `Learned`
    /// - highly irregular (entropy > 0.7) sorted keys → `BTree`
    /// - anything in between → `Learned`
    pub fn recommend_index_type(&self) -> IndexType {
        if self.count == 0 {
            return IndexType::None;
        }

        if !self.is_monotonic {
            return IndexType::BTree;
        }

        // High density + monotonic = learned index
        if self.density > 0.5 {
            return IndexType::Learned;
        }

        // Low entropy (regular gaps) = learned index
        if self.entropy < 0.3 {
            return IndexType::Learned;
        }

        // Random keys = B-Tree
        if self.entropy > 0.7 {
            return IndexType::BTree;
        }

        // Default to learned for moderate cases
        IndexType::Learned
    }
}
169
170/// Hybrid index that combines learned index with fallback
171pub struct HybridIndex {
172    /// Learned sparse index
173    learned: LearnedSparseIndex,
174    /// Sorted keys for fallback binary search
175    keys: Vec<u64>,
176    /// Key to position mapping for fallback
177    key_map: BTreeMap<u64, usize>,
178    /// Index type in use
179    index_type: IndexType,
180    /// Statistics
181    stats: KeyStats,
182}
183
184impl HybridIndex {
185    /// Build hybrid index from sorted keys
186    pub fn build(keys: &[u64]) -> Self {
187        let stats = KeyStats::analyze(keys);
188        let index_type = stats.recommend_index_type();
189
190        let learned = if index_type == IndexType::Learned {
191            LearnedSparseIndex::build(keys)
192        } else {
193            LearnedSparseIndex::empty()
194        };
195
196        let key_map: BTreeMap<u64, usize> = keys.iter().enumerate().map(|(i, &k)| (k, i)).collect();
197
198        Self {
199            learned,
200            keys: keys.to_vec(),
201            key_map,
202            index_type,
203            stats,
204        }
205    }
206
207    /// Lookup key position
208    ///
209    /// Returns the exact position if found, or None.
210    pub fn lookup(&self, key: u64) -> Option<usize> {
211        match self.index_type {
212            IndexType::Learned => self.lookup_learned(key),
213            IndexType::BTree | IndexType::Hash => self.key_map.get(&key).copied(),
214            IndexType::None => self.binary_search(key),
215        }
216    }
217
218    /// Lookup using learned index
219    fn lookup_learned(&self, key: u64) -> Option<usize> {
220        match self.learned.lookup(key) {
221            LookupResult::Exact(pos) => Some(pos),
222            LookupResult::Range { low, high } => {
223                // Binary search within the predicted range
224                self.binary_search_range(key, low, high)
225            }
226            LookupResult::NotFound => None,
227        }
228    }
229
230    /// Binary search within a range
231    fn binary_search_range(&self, key: u64, low: usize, high: usize) -> Option<usize> {
232        if low > high || high >= self.keys.len() {
233            return None;
234        }
235
236        let slice = &self.keys[low..=high];
237        match slice.binary_search(&key) {
238            Ok(pos) => Some(low + pos),
239            Err(_) => None,
240        }
241    }
242
243    /// Full binary search
244    fn binary_search(&self, key: u64) -> Option<usize> {
245        self.keys.binary_search(&key).ok()
246    }
247
248    /// Range lookup: find all positions in [start, end]
249    pub fn range_lookup(&self, start: u64, end: u64) -> Vec<usize> {
250        // Find start position
251        let start_pos = match self.keys.binary_search(&start) {
252            Ok(pos) => pos,
253            Err(pos) => pos,
254        };
255
256        // Find end position
257        let end_pos = match self.keys.binary_search(&end) {
258            Ok(pos) => pos + 1,
259            Err(pos) => pos,
260        };
261
262        (start_pos..end_pos.min(self.keys.len())).collect()
263    }
264
265    /// Get index statistics
266    pub fn statistics(&self) -> &KeyStats {
267        &self.stats
268    }
269
270    /// Get index type
271    pub fn index_type(&self) -> IndexType {
272        self.index_type
273    }
274
275    /// Check if learned index is efficient for this data
276    pub fn is_efficient(&self) -> bool {
277        self.learned.is_efficient()
278    }
279
280    /// Memory usage in bytes
281    pub fn memory_bytes(&self) -> usize {
282        self.learned.memory_bytes()
283            + self.keys.len() * std::mem::size_of::<u64>()
284            + self.key_map.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<usize>())
285    }
286}
287
288/// Index manager for multiple tables/columns
289pub struct IndexManager {
290    /// Indexes by table and column
291    indexes: BTreeMap<(String, String), Arc<HybridIndex>>,
292}
293
294impl IndexManager {
295    /// Create a new index manager
296    pub fn new() -> Self {
297        Self {
298            indexes: BTreeMap::new(),
299        }
300    }
301
302    /// Build or update index for a column
303    pub fn build_index(&mut self, table: &str, column: &str, keys: &[u64]) {
304        let index = HybridIndex::build(keys);
305        self.indexes
306            .insert((table.to_string(), column.to_string()), Arc::new(index));
307    }
308
309    /// Get index for a column
310    pub fn get_index(&self, table: &str, column: &str) -> Option<Arc<HybridIndex>> {
311        self.indexes
312            .get(&(table.to_string(), column.to_string()))
313            .cloned()
314    }
315
316    /// Remove index
317    pub fn drop_index(&mut self, table: &str, column: &str) -> bool {
318        self.indexes
319            .remove(&(table.to_string(), column.to_string()))
320            .is_some()
321    }
322
323    /// List all indexes
324    pub fn list_indexes(&self) -> Vec<(&str, &str, IndexType)> {
325        self.indexes
326            .iter()
327            .map(|((table, column), index)| (table.as_str(), column.as_str(), index.index_type()))
328            .collect()
329    }
330}
331
332impl Default for IndexManager {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
338/// Point lookup executor using learned index
339pub struct PointLookupExecutor<'a, V> {
340    /// Index to use
341    index: &'a HybridIndex,
342    /// Data array to fetch from
343    data: &'a [V],
344}
345
346impl<'a, V> PointLookupExecutor<'a, V> {
347    /// Create a new executor
348    pub fn new(index: &'a HybridIndex, data: &'a [V]) -> Self {
349        Self { index, data }
350    }
351
352    /// Execute point lookup
353    pub fn execute(&self, key: u64) -> Option<&V> {
354        self.index.lookup(key).and_then(|pos| self.data.get(pos))
355    }
356
357    /// Execute batch lookup
358    pub fn execute_batch(&self, keys: &[u64]) -> Vec<Option<&V>> {
359        keys.iter().map(|&k| self.execute(k)).collect()
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequential_keys() {
        let keys: Vec<u64> = (1u64..1001).collect();
        let stats = KeyStats::analyze(&keys);

        assert!(stats.is_monotonic);
        assert!(stats.density > 0.99);
        assert_eq!(stats.recommend_index_type(), IndexType::Learned);
    }

    #[test]
    fn test_timestamp_keys() {
        // Microsecond timestamps ~1ms apart, starting around late 2023.
        const BASE: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = (0..1000u64).map(|step| BASE + step * 1000).collect();

        let stats = KeyStats::analyze(&keys);
        assert!(stats.is_monotonic);
        assert!(stats.entropy < 0.1); // evenly spaced keys look regular
        assert_eq!(stats.recommend_index_type(), IndexType::Learned);
    }

    #[test]
    fn test_hybrid_index_lookup() {
        let keys: Vec<u64> = (0..1000u64).map(|i| i * 10).collect();
        let index = HybridIndex::build(&keys);

        // Present keys resolve to their positions.
        for (key, expected) in [(500, 50), (990, 99)] {
            assert_eq!(index.lookup(key), Some(expected));
        }

        // Keys between entries and beyond the end are absent.
        for key in [5, 10000] {
            assert_eq!(index.lookup(key), None);
        }
    }

    #[test]
    fn test_range_lookup() {
        let keys: Vec<u64> = (0..100u64).map(|i| i * 10).collect();
        let index = HybridIndex::build(&keys);

        // [100, 300] holds 100, 110, ..., 300: 21 keys, the first at position 10.
        let positions = index.range_lookup(100, 300);
        assert_eq!(positions.len(), 21);
        assert_eq!(positions.first(), Some(&10));
    }

    #[test]
    fn test_point_lookup_executor() {
        let keys = [10u64, 20, 30, 40, 50];
        let values = ["a", "b", "c", "d", "e"];
        let index = HybridIndex::build(&keys);
        let executor = PointLookupExecutor::new(&index, &values);

        assert_eq!(executor.execute(20), Some(&"b"));
        assert_eq!(executor.execute(50), Some(&"e"));
        assert_eq!(executor.execute(25), None);
    }

    #[test]
    fn test_index_manager() {
        let mut manager = IndexManager::new();
        let keys: Vec<u64> = (0..100u64).collect();
        manager.build_index("users", "id", &keys);

        let index = manager
            .get_index("users", "id")
            .expect("index was just built");
        assert_eq!(index.lookup(50), Some(50));

        let indexes = manager.list_indexes();
        assert_eq!(indexes, vec![("users", "id", IndexType::Learned)]);
    }
}
442
443// =============================================================================
444// Task 4: Piecewise Learned Index Enhancement
445// =============================================================================
446
/// A segment in the piecewise linear index
#[derive(Debug, Clone)]
pub struct LinearSegment {
    /// Start key of this segment (inclusive)
    pub start_key: u64,
    /// End key of this segment (exclusive)
    pub end_key: u64,
    /// Slope: position increase per key increase
    pub slope: f64,
    /// Intercept: predicted position at start_key
    pub intercept: f64,
    /// Maximum error in this segment
    pub max_error: usize,
}

impl LinearSegment {
    /// Predict the position of `key` with this segment's linear model.
    ///
    /// Keys below `start_key` are clamped to the segment start, so predictions
    /// never decrease as the key grows. Negative or NaN model outputs become 0
    /// through Rust's saturating float-to-int `as` cast.
    pub fn predict(&self, key: u64) -> usize {
        let delta = key.saturating_sub(self.start_key) as f64;
        (self.intercept + delta * self.slope).round() as usize
    }

    /// Get inclusive search bounds `(low, high)` for `key`, widened by `max_error`.
    ///
    /// Both bounds lie in `[0, data_len - 1]` and `low <= high` always holds
    /// (`(0, 0)` when `data_len == 0`).
    pub fn bounds(&self, key: u64, data_len: usize) -> (usize, usize) {
        let last = data_len.saturating_sub(1);
        // Clamp the prediction first: a prediction past the end of the data
        // would otherwise produce an inverted window (low > high).
        let pred = self.predict(key).min(last);
        let low = pred.saturating_sub(self.max_error);
        let high = pred.saturating_add(self.max_error).min(last);
        (low, high)
    }
}

/// Piecewise linear learned index over a sorted key array.
///
/// ## Construction
/// - [`PiecewiseLearnedIndex::build`]: greedy left-to-right segmentation in which
///   every segment respects a caller-supplied maximum position error.
/// - [`PiecewiseLearnedIndex::build_optimal`]: dynamic programming over segment
///   boundaries minimizing `sum(segment_errors) + lambda * num_segments` with at
///   most `max_segments` segments.
///
/// ## Performance
/// - Construction: each candidate segment is refit from scratch, so the greedy
///   build is O(n²) in the worst case and the DP build is O(n³) (n = key count).
/// - Lookup: O(log S) to locate the segment, leaving a caller-side search over at
///   most `2ε + 1` positions (S = segment count, ε = segment max error).
#[derive(Debug)]
// `max_error_bound` and `total_keys` are recorded for introspection but are not
// read by this type itself.
#[allow(dead_code)]
pub struct PiecewiseLearnedIndex {
    /// Sorted segments by start_key
    segments: Vec<LinearSegment>,
    /// Target maximum error per segment
    max_error_bound: usize,
    /// Total number of keys indexed
    total_keys: usize,
    /// Construction statistics
    stats: PiecewiseStats,
}

/// Statistics about the piecewise index
#[derive(Debug, Clone, Default)]
pub struct PiecewiseStats {
    /// Number of segments
    pub segment_count: usize,
    /// Average error across segments
    pub avg_error: f64,
    /// Maximum error across all segments
    pub max_error: usize,
    /// Compression ratio (keys / segments)
    pub compression_ratio: f64,
}

impl PiecewiseLearnedIndex {
    /// Per-segment penalty (`lambda`) in the DP cost `sum(errors) + lambda * segments`.
    const SEGMENT_PENALTY: f64 = 10.0;

    /// Index with no segments; every lookup misses.
    fn without_keys(max_error_bound: usize) -> Self {
        Self {
            segments: Vec::new(),
            max_error_bound,
            total_keys: 0,
            stats: PiecewiseStats::default(),
        }
    }

    /// Build a piecewise index whose segments each have max error <= `max_error`.
    ///
    /// Uses the greedy segmentation. `keys` must be sorted ascending.
    pub fn build(keys: &[u64], max_error: usize) -> Self {
        if keys.is_empty() {
            return Self::without_keys(max_error);
        }

        let segments = Self::build_greedy(keys, max_error);
        let stats = Self::compute_stats(&segments, keys.len());

        Self {
            segments,
            max_error_bound: max_error,
            total_keys: keys.len(),
            stats,
        }
    }

    /// Build using dynamic programming for cost-optimal segmentation.
    ///
    /// Picks at most `max_segments` segments minimizing
    /// `sum(segment max errors) + SEGMENT_PENALTY * segment_count`. Slower than
    /// [`Self::build`] (O(n³)). `max_segments == 0` produces an index without
    /// segments. `keys` must be sorted ascending.
    pub fn build_optimal(keys: &[u64], max_segments: usize) -> Self {
        if keys.is_empty() {
            return Self::without_keys(0);
        }

        let segments = Self::build_dp(keys, max_segments);
        let max_error = segments.iter().map(|s| s.max_error).max().unwrap_or(0);
        let stats = Self::compute_stats(&segments, keys.len());

        Self {
            segments,
            max_error_bound: max_error,
            total_keys: keys.len(),
            stats,
        }
    }

    /// Greedy segmentation: repeatedly grow a segment from the first uncovered key.
    fn build_greedy(keys: &[u64], max_error: usize) -> Vec<LinearSegment> {
        let mut segments = Vec::new();
        let mut start_idx = 0;

        while start_idx < keys.len() {
            let (end_idx, segment) = Self::find_longest_segment(keys, start_idx, max_error);
            segments.push(segment);
            start_idx = end_idx + 1;
        }

        segments
    }

    /// Extend a segment from `start_idx` one key at a time, refitting at each step,
    /// until the fit's max error first exceeds `max_error`.
    ///
    /// Returns the index of the last covered key and the fitted segment. The
    /// segment's `end_key` is the next key (exclusive), or last key + 1 at the end.
    fn find_longest_segment(
        keys: &[u64],
        start_idx: usize,
        max_error: usize,
    ) -> (usize, LinearSegment) {
        let start_key = keys[start_idx];
        let mut end_idx = start_idx;
        let mut best_slope = 0.0;
        let mut best_intercept = start_idx as f64;
        let mut best_error = 0;

        for i in (start_idx + 1)..keys.len() {
            let (slope, intercept, error) = Self::fit_segment(keys, start_idx, i);
            if error > max_error {
                break;
            }
            end_idx = i;
            best_slope = slope;
            best_intercept = intercept;
            best_error = error;
        }

        let end_key = if end_idx + 1 < keys.len() {
            keys[end_idx + 1]
        } else {
            keys[end_idx].saturating_add(1)
        };

        (
            end_idx,
            LinearSegment {
                start_key,
                end_key,
                slope: best_slope,
                intercept: best_intercept,
                max_error: best_error,
            },
        )
    }

    /// Least-squares fit of position `i` against `keys[i] - keys[start]` over
    /// `keys[start..=end]`.
    ///
    /// Returns `(slope, intercept, max_error)`, where `max_error` is the largest
    /// `|round(prediction) - position|` over the fitted keys. The prediction uses
    /// the same arithmetic as [`LinearSegment::predict`], so the bounds it yields
    /// always contain the fitted keys' positions.
    fn fit_segment(keys: &[u64], start: usize, end: usize) -> (f64, f64, usize) {
        if start == end {
            return (0.0, start as f64, 0);
        }

        let base = keys[start];
        let n = (end - start + 1) as f64;

        let (mut sum_x, mut sum_y, mut sum_xy, mut sum_xx) = (0.0, 0.0, 0.0, 0.0);
        for (&k, i) in keys[start..=end].iter().zip(start..) {
            let x = (k - base) as f64;
            let y = i as f64;
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }

        let denom = sum_xx * n - sum_x * sum_x;
        let slope = if denom != 0.0 {
            (sum_xy * n - sum_x * sum_y) / denom
        } else {
            0.0
        };
        let intercept = (sum_y - slope * sum_x) / n;

        let max_error = keys[start..=end]
            .iter()
            .zip(start..)
            .map(|(&k, i)| {
                let x = (k - base) as f64;
                let predicted = (intercept + slope * x).round() as isize;
                (predicted - i as isize).unsigned_abs()
            })
            .max()
            .unwrap_or(0);

        (slope, intercept, max_error)
    }

    /// Dynamic programming over segment boundaries.
    ///
    /// `dp[i][k]` holds the minimal summed max error for covering `keys[..i]` with
    /// exactly `k` segments, together with the start index of the last segment.
    fn build_dp(keys: &[u64], max_segments: usize) -> Vec<LinearSegment> {
        let n = keys.len();
        if n == 0 || max_segments == 0 {
            return Vec::new();
        }
        // More segments than keys is impossible; clamping bounds the table at
        // (n + 1)² entries and keeps `max_segments + 1` from overflowing.
        let max_segments = max_segments.min(n);

        let mut dp: Vec<Vec<(f64, usize)>> =
            vec![vec![(f64::INFINITY, 0); max_segments + 1]; n + 1];
        dp[0][0] = (0.0, 0);

        for i in 1..=n {
            let k_limit = max_segments.min(i);
            for j in 0..i {
                // The fit of keys[j..i] does not depend on k: compute it at most
                // once, and only when a reachable prefix dp[j][k - 1] needs it.
                let mut fit_cost: Option<f64> = None;
                for k in 1..=k_limit {
                    let prefix_cost = dp[j][k - 1].0;
                    if prefix_cost < f64::INFINITY {
                        let seg_cost = *fit_cost
                            .get_or_insert_with(|| Self::fit_segment(keys, j, i - 1).2 as f64);
                        let cost = prefix_cost + seg_cost;
                        if cost < dp[i][k].0 {
                            dp[i][k] = (cost, j);
                        }
                    }
                }
            }
        }

        // Pick the segment count minimizing error + per-segment penalty.
        let mut best_k = 1;
        let mut best_cost = f64::INFINITY;
        for (k, &(error_sum, _)) in dp[n].iter().enumerate().skip(1) {
            let cost = error_sum + Self::SEGMENT_PENALTY * k as f64;
            if cost < best_cost {
                best_cost = cost;
                best_k = k;
            }
        }

        // Backtrack from the full prefix to recover segment boundaries.
        let mut segments = Vec::with_capacity(best_k);
        let (mut i, mut k) = (n, best_k);
        while k > 0 && i > 0 {
            let j = dp[i][k].1;
            let (slope, intercept, max_error) = Self::fit_segment(keys, j, i - 1);

            let end_key = if i < n {
                keys[i]
            } else {
                keys[i - 1].saturating_add(1)
            };

            segments.push(LinearSegment {
                start_key: keys[j],
                end_key,
                slope,
                intercept,
                max_error,
            });

            i = j;
            k -= 1;
        }

        segments.reverse();
        segments
    }

    /// Compute statistics about the index
    fn compute_stats(segments: &[LinearSegment], total_keys: usize) -> PiecewiseStats {
        if segments.is_empty() {
            return PiecewiseStats::default();
        }

        let segment_count = segments.len();
        let total_error: usize = segments.iter().map(|s| s.max_error).sum();
        let max_error = segments.iter().map(|s| s.max_error).max().unwrap_or(0);

        PiecewiseStats {
            segment_count,
            avg_error: total_error as f64 / segment_count as f64,
            max_error,
            compression_ratio: total_keys as f64 / segment_count as f64,
        }
    }

    /// Look up inclusive position bounds for `key`, clamped to `data_len`.
    ///
    /// Returns `None` when no segment covers `key`.
    pub fn lookup(&self, key: u64, data_len: usize) -> Option<(usize, usize)> {
        let segment = &self.segments[self.find_segment(key)?];
        Some(segment.bounds(key, data_len))
    }

    /// Index of the segment whose `[start_key, end_key)` contains `key`.
    fn find_segment(&self, key: u64) -> Option<usize> {
        // Segments are sorted and non-overlapping, so the first one ending after
        // `key` is the only candidate; it matches if it also starts at or before it.
        let idx = self.segments.partition_point(|s| s.end_key <= key);
        let seg = self.segments.get(idx)?;
        (key >= seg.start_key).then_some(idx)
    }

    /// Get statistics
    pub fn statistics(&self) -> &PiecewiseStats {
        &self.stats
    }

    /// Memory used by the segment array in bytes
    pub fn memory_bytes(&self) -> usize {
        self.segments.len() * std::mem::size_of::<LinearSegment>()
    }

    /// Number of segments
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }
}
820
#[cfg(test)]
mod piecewise_tests {
    use super::*;

    #[test]
    fn test_piecewise_sequential() {
        let keys: Vec<u64> = (0u64..1000).collect();
        let index = PiecewiseLearnedIndex::build(&keys, 2);

        // Perfectly linear data should collapse into a handful of segments.
        assert!(index.segment_count() <= 5);
        assert!(index.stats.avg_error <= 2.0);

        let (low, high) = index.lookup(500, 1000).expect("key 500 is indexed");
        assert!((low..=high).contains(&500));
    }

    #[test]
    fn test_piecewise_timestamp() {
        // ~1ms spacing in microseconds, plus up to 9µs of jitter.
        const BASE: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = (0..1000u64).map(|i| BASE + i * 1000 + i % 10).collect();

        let index = PiecewiseLearnedIndex::build(&keys, 5);
        assert!(index.stats.max_error <= 5);

        let (low, high) = index.lookup(BASE + 500_000, 1000).unwrap();
        assert!(high - low <= 10); // tight bounds despite jitter
    }

    #[test]
    fn test_piecewise_optimal() {
        let keys: Vec<u64> = (0..100u64).map(|i| i * i).collect(); // quadratic
        let index = PiecewiseLearnedIndex::build_optimal(&keys, 10);

        assert!(index.segment_count() >= 1);

        for (pos, &key) in keys.iter().enumerate() {
            let (low, high) = index.lookup(key, keys.len()).unwrap();
            assert!(
                low <= pos && pos <= high,
                "Key {} (i={}): expected bounds to contain {}, got ({}, {})",
                key,
                pos,
                pos,
                low,
                high
            );
        }
    }

    #[test]
    fn test_piecewise_memory() {
        let keys: Vec<u64> = (0u64..10_000).collect();
        let index = PiecewiseLearnedIndex::build(&keys, 10);

        // Segments should be far smaller than the raw key array.
        let raw_key_bytes = keys.len() * std::mem::size_of::<u64>();
        assert!(index.memory_bytes() < raw_key_bytes / 10);

        println!(
            "Compression: {} keys -> {} segments ({:.1}x)",
            keys.len(),
            index.segment_count(),
            index.stats.compression_ratio
        );
    }
}
895
896// =============================================================================
897// Task 9: Recursive Model Index (RMI) with Delta Updates
898// =============================================================================
899
900/// Two-level Recursive Model Index for O(1) expected lookups
901///
902/// ## Architecture
903/// ```text
904/// Level 1 (Root):   [Linear Model] → routes to leaf model
905/// Level 2 (Leaves): [PLM 0] [PLM 1] ... [PLM N] → position predictions
906/// ```
907///
908/// ## Performance
909/// - Space: O(M × params) where M = number of models
910/// - Lookup: O(1) average, O(log ε) for binary search in error range
911#[derive(Debug)]
912pub struct RecursiveModelIndex {
913    /// Root model: maps normalized key → leaf model index
914    root_slope: f64,
915    root_intercept: f64,
916    /// Leaf models (piecewise linear within each bucket)
917    leaves: Vec<PiecewiseLearnedIndex>,
918    /// Min key for normalization
919    min_key: u64,
920    /// Max key for normalization
921    max_key: u64,
922    /// Key range as f64
923    key_range: f64,
924    /// Total keys
925    num_keys: usize,
926    /// Global max error
927    max_error: usize,
928}
929
930impl RecursiveModelIndex {
931    /// Build a 2-level RMI
932    ///
933    /// # Arguments
934    /// * `keys` - Sorted keys
935    /// * `num_leaves` - Number of leaf models (typically √N)
936    /// * `leaf_max_error` - Max error per leaf segment
937    pub fn build(keys: &[u64], num_leaves: usize, leaf_max_error: usize) -> Self {
938        let n = keys.len();
939        if n == 0 {
940            return Self {
941                root_slope: 0.0,
942                root_intercept: 0.0,
943                leaves: Vec::new(),
944                min_key: 0,
945                max_key: 0,
946                key_range: 0.0,
947                num_keys: 0,
948                max_error: 0,
949            };
950        }
951
952        let min_key = keys[0];
953        let max_key = keys[n - 1];
954        let key_range = if max_key == min_key {
955            1.0
956        } else {
957            (max_key - min_key) as f64
958        };
959        let num_leaves = num_leaves.min(n).max(1);
960
961        // Fit root model: normalized_key → bucket_index
962        let bucket_size = n.div_ceil(num_leaves);
963        let root_slope = num_leaves as f64; // Maps [0,1] to [0, num_leaves]
964        let root_intercept = 0.0;
965
966        // Build leaf models for each bucket
967        let mut leaves = Vec::with_capacity(num_leaves);
968        let mut global_max_error = 0usize;
969
970        for bucket_idx in 0..num_leaves {
971            let start = bucket_idx * bucket_size;
972            let end = ((bucket_idx + 1) * bucket_size).min(n);
973
974            if start < n {
975                let bucket_keys: Vec<u64> = keys[start..end].to_vec();
976                let leaf = PiecewiseLearnedIndex::build(&bucket_keys, leaf_max_error);
977                global_max_error = global_max_error.max(leaf.stats.max_error);
978                leaves.push(leaf);
979            }
980        }
981
982        Self {
983            root_slope,
984            root_intercept,
985            leaves,
986            min_key,
987            max_key,
988            key_range,
989            num_keys: n,
990            max_error: global_max_error,
991        }
992    }
993
994    /// Look up position bounds for a key
995    pub fn lookup(&self, key: u64, data_len: usize) -> Option<(usize, usize)> {
996        if self.num_keys == 0 || key < self.min_key || key > self.max_key {
997            return None;
998        }
999
1000        // Normalize key to [0, 1]
1001        let normalized = (key - self.min_key) as f64 / self.key_range;
1002
1003        // Route to leaf
1004        let leaf_idx_f = self.root_slope * normalized + self.root_intercept;
1005        let leaf_idx = (leaf_idx_f as usize).min(self.leaves.len().saturating_sub(1));
1006
1007        // Query leaf model
1008        if let Some(leaf) = self.leaves.get(leaf_idx) {
1009            // Leaf returns relative position within bucket
1010            if let Some((rel_low, rel_high)) = leaf.lookup(key, data_len) {
1011                // Convert to absolute position
1012                let bucket_size = self.num_keys.div_ceil(self.leaves.len());
1013                let bucket_start = leaf_idx * bucket_size;
1014                let abs_low = bucket_start + rel_low;
1015                let abs_high = (bucket_start + rel_high).min(data_len.saturating_sub(1));
1016                return Some((abs_low, abs_high));
1017            }
1018        }
1019
1020        // Fallback: return full range
1021        Some((0, data_len.saturating_sub(1)))
1022    }
1023
1024    /// Get space usage in bytes
1025    pub fn size_bytes(&self) -> usize {
1026        let base = std::mem::size_of::<Self>();
1027        let leaves: usize = self.leaves.iter().map(|l| l.memory_bytes()).sum();
1028        base + leaves
1029    }
1030
1031    /// Get number of leaf models
1032    pub fn num_leaves(&self) -> usize {
1033        self.leaves.len()
1034    }
1035
1036    /// Get total segment count across all leaves
1037    pub fn total_segments(&self) -> usize {
1038        self.leaves.iter().map(|l| l.segment_count()).sum()
1039    }
1040}
1041
/// Delta index for online updates without full rebuild
///
/// Buffers pending inserts and deletes in a `BTreeMap` keyed by the row key.
/// They are merged with a static learned index during lookup and folded into
/// it on the next rebuild.
#[derive(Debug)]
pub struct DeltaIndex {
    /// Pending keys mapped to a tombstone flag (`true` = deleted, `false` = live)
    entries: BTreeMap<u64, bool>,
    /// Keys first added to `entries` by `insert` since the last clear
    insert_count: usize,
    /// Entries currently tombstoned (repeated deletes of one key count once)
    delete_count: usize,
    /// Rebuild once `(insert_count + delete_count) / static_size` exceeds this
    rebuild_threshold: f64,
}

impl DeltaIndex {
    /// Entry count above which a rebuild is requested when the static index is empty
    const EMPTY_STATIC_REBUILD_LEN: usize = 100;

    /// Create new delta index
    ///
    /// `rebuild_threshold` is the delta-to-static size ratio above which
    /// [`needs_rebuild`](Self::needs_rebuild) reports `true`.
    pub fn new(rebuild_threshold: f64) -> Self {
        Self {
            entries: BTreeMap::new(),
            insert_count: 0,
            delete_count: 0,
            rebuild_threshold,
        }
    }

    /// Insert a key
    ///
    /// A tombstoned key is resurrected (made live again). A key that is already
    /// live is left unchanged. An unseen key is added as live.
    pub fn insert(&mut self, key: u64) {
        if let Some(deleted) = self.entries.get_mut(&key) {
            // Resurrect deleted key
            if *deleted {
                *deleted = false;
                self.delete_count = self.delete_count.saturating_sub(1);
            }
        } else {
            self.entries.insert(key, false);
            self.insert_count += 1;
        }
    }

    /// Delete a key (tombstone)
    ///
    /// Deleting a key that is already tombstoned is a no-op. This keeps
    /// `delete_count` equal to the number of tombstones, so repeated deletes
    /// cannot trigger a spurious rebuild.
    pub fn delete(&mut self, key: u64) {
        // `BTreeMap::insert` returns the previous flag; only count a new tombstone.
        if self.entries.insert(key, true) != Some(true) {
            self.delete_count += 1;
        }
    }

    /// Check if key is in delta
    ///
    /// Returns `Some(true)` for a tombstone, `Some(false)` for a live pending
    /// key, and `None` when the delta has no entry for `key`.
    pub fn contains(&self, key: u64) -> Option<bool> {
        self.entries.get(&key).copied()
    }

    /// Check if rebuild is needed
    ///
    /// With `static_size == 0` the ratio is undefined, so a fixed entry count
    /// (`EMPTY_STATIC_REBUILD_LEN`) is used instead.
    pub fn needs_rebuild(&self, static_size: usize) -> bool {
        if static_size == 0 {
            return self.entries.len() > Self::EMPTY_STATIC_REBUILD_LEN;
        }
        let delta_size = self.insert_count + self.delete_count;
        (delta_size as f64 / static_size as f64) > self.rebuild_threshold
    }

    /// Get all live keys in ascending order (for rebuild)
    pub fn live_keys(&self) -> impl Iterator<Item = u64> + '_ {
        self.entries
            .iter()
            .filter(|(_, deleted)| !**deleted)
            .map(|(k, _)| *k)
    }

    /// Get all tombstoned keys in ascending order
    pub fn deleted_keys(&self) -> impl Iterator<Item = u64> + '_ {
        self.entries
            .iter()
            .filter(|(_, deleted)| **deleted)
            .map(|(k, _)| *k)
    }

    /// Clear all entries and counters (after rebuild)
    pub fn clear(&mut self) {
        self.entries.clear();
        self.insert_count = 0;
        self.delete_count = 0;
    }

    /// Number of pending entries, live and tombstoned
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Is delta empty
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
1136
/// Hybrid RMI with delta updates and B-tree fallback
///
/// Combines a static `RecursiveModelIndex` over a sorted key array with a
/// `DeltaIndex` that buffers pending inserts and deletes. A B-tree pins the
/// exact position of keys whose true position lies outside the RMI's
/// predicted window.
#[derive(Debug)]
pub struct HybridRMI {
    /// Static RMI structure built over `keys`
    rmi: RecursiveModelIndex,
    /// Pending inserts/deletes not yet merged into `keys`
    delta: DeltaIndex,
    /// Sorted keys for binary search; a key's position is its index here
    keys: Vec<u64>,
    /// Keys the RMI mis-predicts (position outside its window) → exact position
    btree_fallback: BTreeMap<u64, usize>,
    /// Lookup and rebuild counters
    stats: HybridRMIStats,
}
1151
/// Statistics for HybridRMI
///
/// Counters are incremented by `HybridRMI::lookup` and `HybridRMI::rebuild`.
#[derive(Debug, Clone, Default)]
pub struct HybridRMIStats {
    /// Lookups that reached the RMI step (model window + binary search)
    pub rmi_lookups: u64,
    /// Lookups answered from the B-tree fallback
    pub btree_lookups: u64,
    /// Lookups whose key had a pending delta entry (live or deleted)
    pub delta_lookups: u64,
    /// Number of rebuilds
    pub rebuilds: u64,
}
1164
1165impl HybridRMI {
1166    /// Build hybrid RMI
1167    pub fn build(
1168        keys: Vec<u64>,
1169        num_leaves: usize,
1170        leaf_max_error: usize,
1171        rebuild_threshold: f64,
1172    ) -> Self {
1173        let rmi = RecursiveModelIndex::build(&keys, num_leaves, leaf_max_error);
1174
1175        // Find overflow keys for B-tree fallback
1176        let _overflow_threshold = leaf_max_error * 3;
1177        let mut btree_fallback = BTreeMap::new();
1178
1179        for (pos, &key) in keys.iter().enumerate() {
1180            if let Some((low, high)) = rmi.lookup(key, keys.len())
1181                && (pos < low || pos > high)
1182            {
1183                btree_fallback.insert(key, pos);
1184            }
1185        }
1186
1187        Self {
1188            rmi,
1189            delta: DeltaIndex::new(rebuild_threshold),
1190            keys,
1191            btree_fallback,
1192            stats: HybridRMIStats::default(),
1193        }
1194    }
1195
1196    /// Look up a key
1197    pub fn lookup(&mut self, key: u64) -> Option<usize> {
1198        // 1. Check delta first
1199        if let Some(deleted) = self.delta.contains(key) {
1200            self.stats.delta_lookups += 1;
1201            if deleted {
1202                return None; // Deleted
1203            }
1204            // Key in delta but not deleted - need to find position
1205            // This is a recent insert, position unknown
1206            return None;
1207        }
1208
1209        // 2. Check B-tree fallback
1210        if let Some(&pos) = self.btree_fallback.get(&key) {
1211            self.stats.btree_lookups += 1;
1212            return Some(pos);
1213        }
1214
1215        // 3. Use RMI
1216        self.stats.rmi_lookups += 1;
1217        if let Some((low, high)) = self.rmi.lookup(key, self.keys.len()) {
1218            // Binary search in predicted range
1219            let range = &self.keys[low..=high.min(self.keys.len().saturating_sub(1))];
1220            if let Ok(idx) = range.binary_search(&key) {
1221                return Some(low + idx);
1222            }
1223        }
1224
1225        // 4. Full binary search fallback
1226        self.keys.binary_search(&key).ok()
1227    }
1228
1229    /// Insert a key (goes to delta)
1230    pub fn insert(&mut self, key: u64) {
1231        self.delta.insert(key);
1232        if self.delta.needs_rebuild(self.keys.len()) {
1233            self.rebuild();
1234        }
1235    }
1236
1237    /// Delete a key
1238    pub fn delete(&mut self, key: u64) {
1239        self.delta.delete(key);
1240    }
1241
1242    /// Rebuild the index
1243    pub fn rebuild(&mut self) {
1244        // Merge delta into keys
1245        let deleted: HashSet<u64> = self.delta.deleted_keys().collect();
1246
1247        let mut new_keys: Vec<u64> = self
1248            .keys
1249            .iter()
1250            .filter(|&k| !deleted.contains(k))
1251            .copied()
1252            .collect();
1253
1254        new_keys.extend(self.delta.live_keys());
1255        new_keys.sort_unstable();
1256        new_keys.dedup();
1257
1258        // Rebuild RMI
1259        let n = new_keys.len();
1260        let num_leaves = ((n as f64).sqrt().ceil() as usize).max(1);
1261        let leaf_max_error = self.rmi.max_error.max(10);
1262
1263        let new_rmi = RecursiveModelIndex::build(&new_keys, num_leaves, leaf_max_error);
1264
1265        // Rebuild B-tree fallback
1266        let mut new_btree = BTreeMap::new();
1267        for (pos, &key) in new_keys.iter().enumerate() {
1268            if let Some((low, high)) = new_rmi.lookup(key, new_keys.len())
1269                && (pos < low || pos > high)
1270            {
1271                new_btree.insert(key, pos);
1272            }
1273        }
1274
1275        self.rmi = new_rmi;
1276        self.keys = new_keys;
1277        self.btree_fallback = new_btree;
1278        self.delta.clear();
1279        self.stats.rebuilds += 1;
1280    }
1281
1282    /// Get statistics
1283    pub fn stats(&self) -> &HybridRMIStats {
1284        &self.stats
1285    }
1286
1287    /// Get number of keys
1288    pub fn len(&self) -> usize {
1289        self.keys.len()
1290    }
1291
1292    /// Check if empty
1293    pub fn is_empty(&self) -> bool {
1294        self.keys.is_empty()
1295    }
1296
1297    /// Space usage in bytes
1298    pub fn size_bytes(&self) -> usize {
1299        self.rmi.size_bytes()
1300            + self.keys.len() * std::mem::size_of::<u64>()
1301            + self.delta.len() * std::mem::size_of::<(u64, bool)>()
1302            + self.btree_fallback.len()
1303                * (std::mem::size_of::<u64>() + std::mem::size_of::<usize>())
1304    }
1305}
1306
#[cfg(test)]
mod rmi_tests {
    use super::*;

    #[test]
    fn test_rmi_build() {
        let sequential: Vec<u64> = (0..10000).collect();
        let index = RecursiveModelIndex::build(&sequential, 100, 10);

        assert_eq!(index.num_leaves(), 100);
        assert!(index.max_error <= 10);

        // Sample every 100th key. Whenever the index answers, its window must
        // cover the key's position (which equals the key for 0..N).
        for probe in (0..10000u64).step_by(100) {
            let Some((lo, hi)) = index.lookup(probe, 10000) else {
                continue;
            };
            let expected = probe as usize;
            assert!(
                lo <= expected && hi >= expected,
                "Key {}: bounds ({}, {}) don't contain position",
                probe,
                lo,
                hi
            );
        }
    }

    #[test]
    fn test_delta_index() {
        let mut pending = DeltaIndex::new(0.1);

        for key in [100, 200] {
            pending.insert(key);
        }
        pending.delete(150);

        // Live insert, tombstone, and a key the delta never saw
        assert_eq!(pending.contains(100), Some(false));
        assert_eq!(pending.contains(150), Some(true));
        assert_eq!(pending.contains(300), None);

        // Re-inserting a tombstoned key makes it live again
        pending.insert(150);
        assert_eq!(pending.contains(150), Some(false));
    }

    #[test]
    fn test_hybrid_rmi_lookup() {
        let even: Vec<u64> = (0..500).map(|i| i * 2).collect();
        let mut index = HybridRMI::build(even, 10, 5, 0.1);

        // Key 2k sits at position k
        assert_eq!(index.lookup(100), Some(50));
        assert_eq!(index.lookup(500), Some(250));

        // Odd keys were never indexed
        assert_eq!(index.lookup(101), None);
    }

    #[test]
    fn test_hybrid_rmi_updates() {
        let mut index = HybridRMI::build((0..100).collect(), 5, 5, 0.5);

        // A tombstoned key is no longer visible
        index.delete(50);
        assert_eq!(index.lookup(50), None);

        // A fresh insert is buffered in the delta (far below the 50% threshold)
        index.insert(200);
        assert!(!index.delta.is_empty());
    }

    #[test]
    fn test_hybrid_rmi_rebuild() {
        let mut index = HybridRMI::build((0..100).collect(), 5, 5, 0.05);
        let before = index.len();

        // 15 inserts against a 5% threshold on 100 keys must force a rebuild
        (100..115).for_each(|key| index.insert(key));

        assert!(index.stats().rebuilds > 0);
        assert!(index.len() > before);
    }

    #[test]
    fn test_rmi_space_efficiency() {
        let n = 100_000usize;
        let keys: Vec<u64> = (0..n as u64).collect();
        let index = RecursiveModelIndex::build(&keys, 100, 50);

        let (rmi_size, raw_size) = (index.size_bytes(), n * std::mem::size_of::<u64>());
        let ratio = raw_size as f64 / rmi_size as f64;
        println!(
            "RMI size: {} bytes, Raw keys: {} bytes, Ratio: {:.2}x",
            rmi_size, raw_size, ratio
        );

        // The model structure must be an order of magnitude smaller than the keys
        assert!(
            rmi_size < raw_size / 10,
            "RMI size {} should be < 10% of raw size {}",
            rmi_size,
            raw_size
        );
    }
}