Skip to main content

sochdb_storage/
learned_index_integration.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Learned Index Integration (Task 5)
19//!
20//! Integrates LearnedSparseIndex for O(1) expected point lookups:
21//! - Accelerates row lookups by key in LSCS
22//! - Falls back to binary search for outliers
23//! - Provides adaptive index selection based on data distribution
24//!
25//! ## Lookup Flow
26//!
27//! ```text
28//! get(key)
29//!   │
30//!   ▼
31//! ┌─────────────────────┐
32//! │ LearnedIndex.lookup │
33//! └──────────┬──────────┘
34//!            │
35//!      ┌─────┴─────┐
36//!      ▼           ▼
37//!   Exact       Range[lo,hi]
38//!     │             │
39//!     ▼             ▼
40//!   O(1)      BinarySearch
41//!   fetch       O(log ε)
42//! ```
43//!
44//! ## Index Selection
45//!
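//! The system automatically chooses between:
//! - **LearnedIndex**: for sequential/timestamp keys (O(1) expected)
//! - **B-Tree**: for random/UUID-like keys (O(log N))
//!
//! A **Hash** index type (exact match only) is also defined, but automatic
//! selection never picks it, and lookups for it go through the same ordered
//! map as the B-Tree path (O(log N), not O(1)).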
50
51use sochdb_core::learned_index::{LearnedSparseIndex, LookupResult};
52use std::collections::{BTreeMap, HashSet};
53use std::sync::Arc;
54
/// Index type based on key characteristics
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// Learned index for sequential/monotonic keys
    Learned,
    /// B-Tree for random keys
    BTree,
    /// Hash for exact match only
    Hash,
    /// No index (scan)
    None,
}

/// Key distribution statistics for index selection
#[derive(Debug, Clone, Default)]
pub struct KeyStats {
    /// Number of keys analyzed
    pub count: usize,
    /// Minimum key value
    pub min_key: u64,
    /// Maximum key value
    pub max_key: u64,
    /// Whether the keys are non-decreasing in the order given (duplicates allowed)
    pub is_monotonic: bool,
    /// Density: count / (max - min + 1)
    pub density: f64,
    /// Randomness proxy in [0, 1]: coefficient of variation of consecutive
    /// key gaps, capped at 1.0
    pub entropy: f64,
}

impl KeyStats {
    /// Analyze `keys` (in the order given) and compute distribution statistics.
    ///
    /// Returns `KeyStats::default()` for an empty slice. The keys do not need
    /// to be sorted; unsorted input is reported through `is_monotonic == false`.
    pub fn analyze(keys: &[u64]) -> Self {
        let (Some(&min_key), Some(&max_key)) = (keys.iter().min(), keys.iter().max()) else {
            return Self::default();
        };
        let count = keys.len();

        let is_monotonic = keys.windows(2).all(|w| w[0] <= w[1]);

        // Compute `max - min + 1` in f64: the `+ 1` would overflow u64 when the
        // keys span the whole u64 domain (e.g. [0, u64::MAX]).
        let range = (max_key - min_key) as f64 + 1.0;
        let density = count as f64 / range;

        let entropy = Self::estimate_entropy(keys);

        Self {
            count,
            min_key,
            max_key,
            is_monotonic,
            density,
            entropy,
        }
    }

    /// Estimate randomness from the gaps between consecutive keys.
    ///
    /// Uses the coefficient of variation (std-dev / mean) of the gaps, capped
    /// at 1.0. Returns 0.0 for fewer than two keys or when every gap is zero.
    fn estimate_entropy(keys: &[u64]) -> f64 {
        if keys.len() < 2 {
            return 0.0;
        }

        // `abs_diff` keeps unsorted input from underflowing, which would panic
        // in debug builds.
        let gaps: Vec<u64> = keys.windows(2).map(|w| w[1].abs_diff(w[0])).collect();
        let gap_count = gaps.len() as f64;

        // Accumulate in u128 so the sum of gaps can never overflow.
        let mean_gap = gaps.iter().map(|&g| u128::from(g)).sum::<u128>() as f64 / gap_count;
        if mean_gap == 0.0 {
            return 0.0;
        }

        let variance = gaps
            .iter()
            .map(|&g| {
                let diff = g as f64 - mean_gap;
                diff * diff
            })
            .sum::<f64>()
            / gap_count;

        (variance.sqrt() / mean_gap).min(1.0)
    }

    /// Recommend an index type from these statistics.
    ///
    /// Returns `None` for no keys. Returns `BTree` for unsorted keys, because
    /// learned models and binary search both assume sorted data. Sorted keys
    /// that are dense or regularly spaced get `Learned`, and highly irregular
    /// ones get `BTree`. Moderate cases default to `Learned`.
    pub fn recommend_index_type(&self) -> IndexType {
        if self.count == 0 {
            return IndexType::None;
        }

        // A linear position model is meaningless over unsorted keys.
        if !self.is_monotonic {
            return IndexType::BTree;
        }

        // High density, or low entropy (regular gaps), fits a linear model well.
        if self.density > 0.5 || self.entropy < 0.3 {
            return IndexType::Learned;
        }

        // Random keys = B-Tree
        if self.entropy > 0.7 {
            return IndexType::BTree;
        }

        // Default to learned for moderate cases
        IndexType::Learned
    }
}
172
173/// Hybrid index that combines learned index with fallback
174pub struct HybridIndex {
175    /// Learned sparse index
176    learned: LearnedSparseIndex,
177    /// Sorted keys for fallback binary search
178    keys: Vec<u64>,
179    /// Key to position mapping for fallback
180    key_map: BTreeMap<u64, usize>,
181    /// Index type in use
182    index_type: IndexType,
183    /// Statistics
184    stats: KeyStats,
185}
186
187impl HybridIndex {
188    /// Build hybrid index from sorted keys
189    pub fn build(keys: &[u64]) -> Self {
190        let stats = KeyStats::analyze(keys);
191        let index_type = stats.recommend_index_type();
192
193        let learned = if index_type == IndexType::Learned {
194            LearnedSparseIndex::build(keys)
195        } else {
196            LearnedSparseIndex::empty()
197        };
198
199        let key_map: BTreeMap<u64, usize> = keys.iter().enumerate().map(|(i, &k)| (k, i)).collect();
200
201        Self {
202            learned,
203            keys: keys.to_vec(),
204            key_map,
205            index_type,
206            stats,
207        }
208    }
209
210    /// Lookup key position
211    ///
212    /// Returns the exact position if found, or None.
213    pub fn lookup(&self, key: u64) -> Option<usize> {
214        match self.index_type {
215            IndexType::Learned => self.lookup_learned(key),
216            IndexType::BTree | IndexType::Hash => self.key_map.get(&key).copied(),
217            IndexType::None => self.binary_search(key),
218        }
219    }
220
221    /// Lookup using learned index
222    fn lookup_learned(&self, key: u64) -> Option<usize> {
223        match self.learned.lookup(key) {
224            LookupResult::Exact(pos) => Some(pos),
225            LookupResult::Range { low, high } => {
226                // Binary search within the predicted range
227                self.binary_search_range(key, low, high)
228            }
229            LookupResult::NotFound => None,
230        }
231    }
232
233    /// Binary search within a range
234    fn binary_search_range(&self, key: u64, low: usize, high: usize) -> Option<usize> {
235        if low > high || high >= self.keys.len() {
236            return None;
237        }
238
239        let slice = &self.keys[low..=high];
240        match slice.binary_search(&key) {
241            Ok(pos) => Some(low + pos),
242            Err(_) => None,
243        }
244    }
245
246    /// Full binary search
247    fn binary_search(&self, key: u64) -> Option<usize> {
248        self.keys.binary_search(&key).ok()
249    }
250
251    /// Range lookup: find all positions in [start, end]
252    pub fn range_lookup(&self, start: u64, end: u64) -> Vec<usize> {
253        // Find start position
254        let start_pos = match self.keys.binary_search(&start) {
255            Ok(pos) => pos,
256            Err(pos) => pos,
257        };
258
259        // Find end position
260        let end_pos = match self.keys.binary_search(&end) {
261            Ok(pos) => pos + 1,
262            Err(pos) => pos,
263        };
264
265        (start_pos..end_pos.min(self.keys.len())).collect()
266    }
267
268    /// Get index statistics
269    pub fn statistics(&self) -> &KeyStats {
270        &self.stats
271    }
272
273    /// Get index type
274    pub fn index_type(&self) -> IndexType {
275        self.index_type
276    }
277
278    /// Check if learned index is efficient for this data
279    pub fn is_efficient(&self) -> bool {
280        self.learned.is_efficient()
281    }
282
283    /// Memory usage in bytes
284    pub fn memory_bytes(&self) -> usize {
285        self.learned.memory_bytes()
286            + self.keys.len() * std::mem::size_of::<u64>()
287            + self.key_map.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<usize>())
288    }
289}
290
291/// Index manager for multiple tables/columns
292pub struct IndexManager {
293    /// Indexes by table and column
294    indexes: BTreeMap<(String, String), Arc<HybridIndex>>,
295}
296
297impl IndexManager {
298    /// Create a new index manager
299    pub fn new() -> Self {
300        Self {
301            indexes: BTreeMap::new(),
302        }
303    }
304
305    /// Build or update index for a column
306    pub fn build_index(&mut self, table: &str, column: &str, keys: &[u64]) {
307        let index = HybridIndex::build(keys);
308        self.indexes
309            .insert((table.to_string(), column.to_string()), Arc::new(index));
310    }
311
312    /// Get index for a column
313    pub fn get_index(&self, table: &str, column: &str) -> Option<Arc<HybridIndex>> {
314        self.indexes
315            .get(&(table.to_string(), column.to_string()))
316            .cloned()
317    }
318
319    /// Remove index
320    pub fn drop_index(&mut self, table: &str, column: &str) -> bool {
321        self.indexes
322            .remove(&(table.to_string(), column.to_string()))
323            .is_some()
324    }
325
326    /// List all indexes
327    pub fn list_indexes(&self) -> Vec<(&str, &str, IndexType)> {
328        self.indexes
329            .iter()
330            .map(|((table, column), index)| (table.as_str(), column.as_str(), index.index_type()))
331            .collect()
332    }
333}
334
335impl Default for IndexManager {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341/// Point lookup executor using learned index
342pub struct PointLookupExecutor<'a, V> {
343    /// Index to use
344    index: &'a HybridIndex,
345    /// Data array to fetch from
346    data: &'a [V],
347}
348
349impl<'a, V> PointLookupExecutor<'a, V> {
350    /// Create a new executor
351    pub fn new(index: &'a HybridIndex, data: &'a [V]) -> Self {
352        Self { index, data }
353    }
354
355    /// Execute point lookup
356    pub fn execute(&self, key: u64) -> Option<&V> {
357        self.index.lookup(key).and_then(|pos| self.data.get(pos))
358    }
359
360    /// Execute batch lookup
361    pub fn execute_batch(&self, keys: &[u64]) -> Vec<Option<&V>> {
362        keys.iter().map(|&k| self.execute(k)).collect()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// `count` keys starting at 0, spaced `step` apart.
    fn stepped(count: u64, step: u64) -> Vec<u64> {
        (0..count).map(|i| i * step).collect()
    }

    #[test]
    fn test_sequential_keys() {
        let keys: Vec<u64> = (1..=1000).collect();
        let stats = KeyStats::analyze(&keys);

        assert!(stats.is_monotonic);
        assert!(stats.density > 0.99);
        assert_eq!(stats.recommend_index_type(), IndexType::Learned);
    }

    #[test]
    fn test_timestamp_keys() {
        // Microsecond timestamps about 1ms apart, starting around 2023.
        const BASE_US: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = stepped(1000, 1000)
            .into_iter()
            .map(|offset| BASE_US + offset)
            .collect();

        let stats = KeyStats::analyze(&keys);
        assert!(stats.is_monotonic);
        assert!(stats.entropy < 0.1, "evenly spaced keys should look regular");
        assert_eq!(stats.recommend_index_type(), IndexType::Learned);
    }

    #[test]
    fn test_hybrid_index_lookup() {
        let index = HybridIndex::build(&stepped(1000, 10));

        let cases = [(500, Some(50)), (990, Some(99)), (5, None), (10000, None)];
        for (key, expected) in cases {
            assert_eq!(index.lookup(key), expected, "lookup({key})");
        }
    }

    #[test]
    fn test_range_lookup() {
        let index = HybridIndex::build(&stepped(100, 10));

        // [100, 300] covers keys 100, 110, ..., 300: 21 entries starting at position 10.
        let positions = index.range_lookup(100, 300);
        assert_eq!(positions.len(), 21);
        assert_eq!(positions.first(), Some(&10));
    }

    #[test]
    fn test_point_lookup_executor() {
        let keys = [10u64, 20, 30, 40, 50];
        let values = ["a", "b", "c", "d", "e"];
        let index = HybridIndex::build(&keys);

        let executor = PointLookupExecutor::new(&index, &values);

        assert_eq!(executor.execute(20), Some(&"b"));
        assert_eq!(executor.execute(50), Some(&"e"));
        assert_eq!(executor.execute(25), None);
    }

    #[test]
    fn test_index_manager() {
        let mut manager = IndexManager::new();
        let ids: Vec<u64> = (0..100).collect();
        manager.build_index("users", "id", &ids);

        let index = manager
            .get_index("users", "id")
            .expect("index was just built");
        assert_eq!(index.lookup(50), Some(50));

        assert_eq!(
            manager.list_indexes(),
            vec![("users", "id", IndexType::Learned)]
        );
    }
}
445
446// =============================================================================
447// Task 4: Piecewise Learned Index Enhancement
448// =============================================================================
449
/// A segment in the piecewise linear index
#[derive(Debug, Clone)]
pub struct LinearSegment {
    /// Start key of this segment (inclusive)
    pub start_key: u64,
    /// End key of this segment (exclusive)
    pub end_key: u64,
    /// Slope: position increase per key increase
    pub slope: f64,
    /// Intercept: predicted position at `start_key`
    pub intercept: f64,
    /// Maximum |predicted - actual| position error in this segment
    pub max_error: usize,
}

impl LinearSegment {
    /// Predict the position of `key` from the segment's linear model.
    ///
    /// Keys below `start_key` predict position 0. Negative or NaN model outputs
    /// become 0 and overly large ones become `usize::MAX`, because Rust's
    /// float-to-integer `as` casts saturate.
    pub fn predict(&self, key: u64) -> usize {
        if key < self.start_key {
            return 0;
        }
        let delta = (key - self.start_key) as f64;
        (self.intercept + delta * self.slope).round() as usize
    }

    /// Inclusive search window `(low, high)` around the predicted position.
    ///
    /// The window is widened by `max_error` and clamped to `[0, data_len - 1]`.
    /// It always satisfies `low <= high`. When the prediction lies beyond the
    /// data, the window collapses onto the last valid index. For
    /// `data_len == 0` it is `(0, 0)`.
    pub fn bounds(&self, key: u64, data_len: usize) -> (usize, usize) {
        let pred = self.predict(key);
        // `saturating_add`: `pred` may already be `usize::MAX` after a saturating cast.
        let high = pred
            .saturating_add(self.max_error)
            .min(data_len.saturating_sub(1));
        // Never let the window invert; callers slice `[low..=high]`.
        let low = pred.saturating_sub(self.max_error).min(high);
        (low, high)
    }
}
483
484/// Piecewise Linear Index using dynamic programming for optimal segmentation
485///
486/// ## Algorithm
487/// Uses dynamic programming to find the optimal set of linear segments
488/// that minimize total error while keeping segments count bounded.
489///
490/// Cost function: `sum(segment_errors) + lambda * num_segments`
491///
492/// ## Performance
493/// - Construction: O(n²) with DP, O(n) with greedy
494/// - Lookup: O(log S + log ε) where S = segment count, ε = max error
495#[derive(Debug)]
496#[allow(dead_code)]
497pub struct PiecewiseLearnedIndex {
498    /// Sorted segments by start_key
499    segments: Vec<LinearSegment>,
500    /// Target maximum error per segment
501    max_error_bound: usize,
502    /// Total number of keys indexed
503    total_keys: usize,
504    /// Construction statistics
505    stats: PiecewiseStats,
506}
507
/// Summary statistics describing a built piecewise index.
#[derive(Clone, Debug, Default)]
pub struct PiecewiseStats {
    /// How many linear segments the index holds
    pub segment_count: usize,
    /// Mean of the per-segment maximum errors
    pub avg_error: f64,
    /// Largest per-segment maximum error
    pub max_error: usize,
    /// Keys per segment (`total_keys / segment_count`)
    pub compression_ratio: f64,
}
520
521impl PiecewiseLearnedIndex {
522    /// Build piecewise index with specified error bound
523    ///
524    /// Uses greedy algorithm for efficiency (O(n) instead of O(n²) DP)
525    pub fn build(keys: &[u64], max_error: usize) -> Self {
526        if keys.is_empty() {
527            return Self {
528                segments: Vec::new(),
529                max_error_bound: max_error,
530                total_keys: 0,
531                stats: PiecewiseStats::default(),
532            };
533        }
534
535        let segments = Self::build_greedy(keys, max_error);
536        let stats = Self::compute_stats(&segments, keys.len());
537
538        Self {
539            segments,
540            max_error_bound: max_error,
541            total_keys: keys.len(),
542            stats,
543        }
544    }
545
546    /// Build using dynamic programming for optimal segmentation
547    ///
548    /// Slower O(n²) but produces optimal segments minimizing total error.
549    pub fn build_optimal(keys: &[u64], max_segments: usize) -> Self {
550        if keys.is_empty() {
551            return Self {
552                segments: Vec::new(),
553                max_error_bound: 0,
554                total_keys: 0,
555                stats: PiecewiseStats::default(),
556            };
557        }
558
559        let segments = Self::build_dp(keys, max_segments);
560        let max_error = segments.iter().map(|s| s.max_error).max().unwrap_or(0);
561        let stats = Self::compute_stats(&segments, keys.len());
562
563        Self {
564            segments,
565            max_error_bound: max_error,
566            total_keys: keys.len(),
567            stats,
568        }
569    }
570
571    /// Greedy segmentation algorithm
572    fn build_greedy(keys: &[u64], max_error: usize) -> Vec<LinearSegment> {
573        let mut segments = Vec::new();
574        let mut start_idx = 0;
575
576        while start_idx < keys.len() {
577            // Find the longest segment starting at start_idx with error <= max_error
578            let (end_idx, segment) = Self::find_longest_segment(keys, start_idx, max_error);
579            segments.push(segment);
580            start_idx = end_idx + 1;
581        }
582
583        segments
584    }
585
586    /// Find longest segment with bounded error
587    fn find_longest_segment(
588        keys: &[u64],
589        start_idx: usize,
590        max_error: usize,
591    ) -> (usize, LinearSegment) {
592        let start_key = keys[start_idx];
593        let mut end_idx = start_idx;
594        let mut best_slope = 0.0;
595        let mut best_intercept = start_idx as f64;
596        let mut best_error = 0;
597
598        // Extend segment as far as possible
599        for i in (start_idx + 1)..keys.len() {
600            let (slope, intercept, error) = Self::fit_segment(keys, start_idx, i);
601            if error <= max_error {
602                end_idx = i;
603                best_slope = slope;
604                best_intercept = intercept;
605                best_error = error;
606            } else {
607                break;
608            }
609        }
610
611        let end_key = if end_idx + 1 < keys.len() {
612            keys[end_idx + 1]
613        } else {
614            keys[end_idx].saturating_add(1)
615        };
616
617        (
618            end_idx,
619            LinearSegment {
620                start_key,
621                end_key,
622                slope: best_slope,
623                intercept: best_intercept,
624                max_error: best_error,
625            },
626        )
627    }
628
629    /// Fit a linear segment and compute max error
630    fn fit_segment(keys: &[u64], start: usize, end: usize) -> (f64, f64, usize) {
631        if start == end {
632            return (0.0, start as f64, 0);
633        }
634
635        let n = (end - start + 1) as f64;
636        let _start_key = keys[start] as f64;
637
638        // Simple linear regression
639        let mut sum_x = 0.0;
640        let mut sum_y = 0.0;
641        let mut sum_xy = 0.0;
642        let mut sum_xx = 0.0;
643
644        for i in start..=end {
645            let x = (keys[i] - keys[start]) as f64;
646            let y = i as f64;
647            sum_x += x;
648            sum_y += y;
649            sum_xy += x * y;
650            sum_xx += x * x;
651        }
652
653        let slope = if sum_xx * n - sum_x * sum_x != 0.0 {
654            (sum_xy * n - sum_x * sum_y) / (sum_xx * n - sum_x * sum_x)
655        } else {
656            0.0
657        };
658
659        let intercept = (sum_y - slope * sum_x) / n;
660
661        // Compute max error
662        let mut max_error = 0usize;
663        for i in start..=end {
664            let x = (keys[i] - keys[start]) as f64;
665            let predicted = (intercept + slope * x).round() as isize;
666            let actual = i as isize;
667            let error = (predicted - actual).unsigned_abs();
668            max_error = max_error.max(error);
669        }
670
671        (slope, intercept, max_error)
672    }
673
674    /// Dynamic programming for optimal segmentation
675    fn build_dp(keys: &[u64], max_segments: usize) -> Vec<LinearSegment> {
676        let n = keys.len();
677        if n == 0 || max_segments == 0 {
678            return Vec::new();
679        }
680
681        // dp[i][k] = (min_cost, best_prev_idx) for first i keys with k segments
682        let mut dp: Vec<Vec<(f64, usize)>> =
683            vec![vec![(f64::INFINITY, 0); max_segments + 1]; n + 1];
684        dp[0][0] = (0.0, 0);
685
686        // Precompute segment costs
687        let segment_cost = |start: usize, end: usize| -> f64 {
688            let (_, _, error) = Self::fit_segment(keys, start, end);
689            error as f64
690        };
691
692        // Fill DP table
693        for i in 1..=n {
694            for k in 1..=max_segments.min(i) {
695                for j in 0..i {
696                    if dp[j][k - 1].0 < f64::INFINITY {
697                        let cost = dp[j][k - 1].0 + segment_cost(j, i - 1);
698                        if cost < dp[i][k].0 {
699                            dp[i][k] = (cost, j);
700                        }
701                    }
702                }
703            }
704        }
705
706        // Find best number of segments
707        let mut best_k = 1;
708        let mut best_cost = f64::INFINITY;
709        for (k, dp_entry) in dp[n].iter().enumerate().take(max_segments + 1).skip(1) {
710            // Cost function: segment_error + lambda * num_segments
711            let lambda = 10.0; // Penalty for additional segments
712            let cost = dp_entry.0 + lambda * k as f64;
713            if cost < best_cost {
714                best_cost = cost;
715                best_k = k;
716            }
717        }
718
719        // Backtrack to get segments
720        let mut segments = Vec::new();
721        let mut i = n;
722        let mut k = best_k;
723
724        while k > 0 && i > 0 {
725            let j = dp[i][k].1;
726            let (slope, intercept, max_error) = Self::fit_segment(keys, j, i - 1);
727
728            let end_key = if i < n {
729                keys[i]
730            } else {
731                keys[i - 1].saturating_add(1)
732            };
733
734            segments.push(LinearSegment {
735                start_key: keys[j],
736                end_key,
737                slope,
738                intercept,
739                max_error,
740            });
741
742            i = j;
743            k -= 1;
744        }
745
746        segments.reverse();
747        segments
748    }
749
750    /// Compute statistics about the index
751    fn compute_stats(segments: &[LinearSegment], total_keys: usize) -> PiecewiseStats {
752        if segments.is_empty() {
753            return PiecewiseStats::default();
754        }
755
756        let segment_count = segments.len();
757        let total_error: usize = segments.iter().map(|s| s.max_error).sum();
758        let max_error = segments.iter().map(|s| s.max_error).max().unwrap_or(0);
759        let avg_error = total_error as f64 / segment_count as f64;
760        let compression_ratio = total_keys as f64 / segment_count as f64;
761
762        PiecewiseStats {
763            segment_count,
764            avg_error,
765            max_error,
766            compression_ratio,
767        }
768    }
769
770    /// Look up key position
771    pub fn lookup(&self, key: u64, data_len: usize) -> Option<(usize, usize)> {
772        if self.segments.is_empty() {
773            return None;
774        }
775
776        // Binary search for the correct segment
777        let segment_idx = self.find_segment(key)?;
778        let segment = &self.segments[segment_idx];
779
780        Some(segment.bounds(key, data_len))
781    }
782
783    /// Find segment containing key
784    fn find_segment(&self, key: u64) -> Option<usize> {
785        if self.segments.is_empty() {
786            return None;
787        }
788
789        // Binary search
790        let idx = self.segments.partition_point(|s| s.end_key <= key);
791        if idx > 0 && idx <= self.segments.len() {
792            let seg = &self.segments[idx - 1];
793            if key >= seg.start_key && key < seg.end_key {
794                return Some(idx - 1);
795            }
796        }
797
798        if idx < self.segments.len() {
799            let seg = &self.segments[idx];
800            if key >= seg.start_key && key < seg.end_key {
801                return Some(idx);
802            }
803        }
804
805        None
806    }
807
808    /// Get statistics
809    pub fn statistics(&self) -> &PiecewiseStats {
810        &self.stats
811    }
812
813    /// Memory usage in bytes
814    pub fn memory_bytes(&self) -> usize {
815        self.segments.len() * std::mem::size_of::<LinearSegment>()
816    }
817
818    /// Number of segments
819    pub fn segment_count(&self) -> usize {
820        self.segments.len()
821    }
822}
823
#[cfg(test)]
mod piecewise_tests {
    use super::*;

    #[test]
    fn test_piecewise_sequential() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = PiecewiseLearnedIndex::build(&keys, 2);

        // Perfectly linear data needs very few segments.
        assert!(index.segment_count() <= 5);
        assert!(index.stats.avg_error <= 2.0);

        let (low, high) = index.lookup(500, 1000).expect("key 500 is indexed");
        assert!((low..=high).contains(&500));
    }

    #[test]
    fn test_piecewise_timestamp() {
        // Timestamps ~1ms apart with up to 9 units of jitter.
        const BASE: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = (0..1000u64).map(|i| BASE + i * 1000 + i % 10).collect();

        let index = PiecewiseLearnedIndex::build(&keys, 5);
        assert!(index.stats.max_error <= 5);

        let (low, high) = index
            .lookup(BASE + 500_000, 1000)
            .expect("key for i = 500 is indexed");
        assert!(high - low <= 10, "bounds should stay tight");
    }

    #[test]
    fn test_piecewise_optimal() {
        // Quadratic keys force several segments.
        let keys: Vec<u64> = (0..100u64).map(|i| i * i).collect();

        let index = PiecewiseLearnedIndex::build_optimal(&keys, 10);
        assert!(index.segment_count() >= 1);

        for (pos, &key) in keys.iter().enumerate() {
            let (low, high) = index.lookup(key, keys.len()).unwrap();
            assert!(
                low <= pos && pos <= high,
                "Key {} (i={}): expected bounds to contain {}, got ({}, {})",
                key,
                pos,
                pos,
                low,
                high
            );
        }
    }

    #[test]
    fn test_piecewise_memory() {
        let keys: Vec<u64> = (0..10000).collect();
        let index = PiecewiseLearnedIndex::build(&keys, 10);

        // The segment table should be far smaller than the raw keys.
        let raw_key_bytes = keys.len() * std::mem::size_of::<u64>();
        assert!(index.memory_bytes() < raw_key_bytes / 10);

        println!(
            "Compression: {} keys -> {} segments ({:.1}x)",
            keys.len(),
            index.segment_count(),
            index.stats.compression_ratio
        );
    }
}
898
899// =============================================================================
900// Task 9: Recursive Model Index (RMI) with Delta Updates
901// =============================================================================
902
903/// Two-level Recursive Model Index for O(1) expected lookups
904///
905/// ## Architecture
906/// ```text
907/// Level 1 (Root):   [Linear Model] → routes to leaf model
908/// Level 2 (Leaves): [PLM 0] [PLM 1] ... [PLM N] → position predictions
909/// ```
910///
911/// ## Performance
912/// - Space: O(M × params) where M = number of models
913/// - Lookup: O(1) average, O(log ε) for binary search in error range
914#[derive(Debug)]
915pub struct RecursiveModelIndex {
916    /// Root model: maps normalized key → leaf model index
917    root_slope: f64,
918    root_intercept: f64,
919    /// Leaf models (piecewise linear within each bucket)
920    leaves: Vec<PiecewiseLearnedIndex>,
921    /// Min key for normalization
922    min_key: u64,
923    /// Max key for normalization
924    max_key: u64,
925    /// Key range as f64
926    key_range: f64,
927    /// Total keys
928    num_keys: usize,
929    /// Global max error
930    max_error: usize,
931}
932
933impl RecursiveModelIndex {
934    /// Build a 2-level RMI
935    ///
936    /// # Arguments
937    /// * `keys` - Sorted keys
938    /// * `num_leaves` - Number of leaf models (typically √N)
939    /// * `leaf_max_error` - Max error per leaf segment
940    pub fn build(keys: &[u64], num_leaves: usize, leaf_max_error: usize) -> Self {
941        let n = keys.len();
942        if n == 0 {
943            return Self {
944                root_slope: 0.0,
945                root_intercept: 0.0,
946                leaves: Vec::new(),
947                min_key: 0,
948                max_key: 0,
949                key_range: 0.0,
950                num_keys: 0,
951                max_error: 0,
952            };
953        }
954
955        let min_key = keys[0];
956        let max_key = keys[n - 1];
957        let key_range = if max_key == min_key {
958            1.0
959        } else {
960            (max_key - min_key) as f64
961        };
962        let num_leaves = num_leaves.min(n).max(1);
963
964        // Fit root model: normalized_key → bucket_index
965        let bucket_size = n.div_ceil(num_leaves);
966        let root_slope = num_leaves as f64; // Maps [0,1] to [0, num_leaves]
967        let root_intercept = 0.0;
968
969        // Build leaf models for each bucket
970        let mut leaves = Vec::with_capacity(num_leaves);
971        let mut global_max_error = 0usize;
972
973        for bucket_idx in 0..num_leaves {
974            let start = bucket_idx * bucket_size;
975            let end = ((bucket_idx + 1) * bucket_size).min(n);
976
977            if start < n {
978                let bucket_keys: Vec<u64> = keys[start..end].to_vec();
979                let leaf = PiecewiseLearnedIndex::build(&bucket_keys, leaf_max_error);
980                global_max_error = global_max_error.max(leaf.stats.max_error);
981                leaves.push(leaf);
982            }
983        }
984
985        Self {
986            root_slope,
987            root_intercept,
988            leaves,
989            min_key,
990            max_key,
991            key_range,
992            num_keys: n,
993            max_error: global_max_error,
994        }
995    }
996
997    /// Look up position bounds for a key
998    pub fn lookup(&self, key: u64, data_len: usize) -> Option<(usize, usize)> {
999        if self.num_keys == 0 || key < self.min_key || key > self.max_key {
1000            return None;
1001        }
1002
1003        // Normalize key to [0, 1]
1004        let normalized = (key - self.min_key) as f64 / self.key_range;
1005
1006        // Route to leaf
1007        let leaf_idx_f = self.root_slope * normalized + self.root_intercept;
1008        let leaf_idx = (leaf_idx_f as usize).min(self.leaves.len().saturating_sub(1));
1009
1010        // Query leaf model
1011        if let Some(leaf) = self.leaves.get(leaf_idx) {
1012            // Leaf returns relative position within bucket
1013            if let Some((rel_low, rel_high)) = leaf.lookup(key, data_len) {
1014                // Convert to absolute position
1015                let bucket_size = self.num_keys.div_ceil(self.leaves.len());
1016                let bucket_start = leaf_idx * bucket_size;
1017                let abs_low = bucket_start + rel_low;
1018                let abs_high = (bucket_start + rel_high).min(data_len.saturating_sub(1));
1019                return Some((abs_low, abs_high));
1020            }
1021        }
1022
1023        // Fallback: return full range
1024        Some((0, data_len.saturating_sub(1)))
1025    }
1026
1027    /// Get space usage in bytes
1028    pub fn size_bytes(&self) -> usize {
1029        let base = std::mem::size_of::<Self>();
1030        let leaves: usize = self.leaves.iter().map(|l| l.memory_bytes()).sum();
1031        base + leaves
1032    }
1033
1034    /// Get number of leaf models
1035    pub fn num_leaves(&self) -> usize {
1036        self.leaves.len()
1037    }
1038
1039    /// Get total segment count across all leaves
1040    pub fn total_segments(&self) -> usize {
1041        self.leaves.iter().map(|l| l.segment_count()).sum()
1042    }
1043}
1044
/// Delta index for online updates without full rebuild.
///
/// Buffers pending inserts and deletes in a `BTreeMap` so they can be merged
/// with a static learned index during lookup and folded into it on rebuild.
#[derive(Debug)]
pub struct DeltaIndex {
    /// Pending changes: key → `true` for a tombstone (deleted), `false` for a live insert
    entries: BTreeMap<u64, bool>,
    /// Keys first added to the delta by `insert` since the last `clear`
    insert_count: usize,
    /// Keys currently tombstoned (each counted once) since the last `clear`
    delete_count: usize,
    /// Rebuild once `(insert_count + delete_count) / static_size` exceeds this fraction
    rebuild_threshold: f64,
}

impl DeltaIndex {
    /// Create an empty delta index with the given rebuild threshold
    /// (a fraction of the static index size).
    pub fn new(rebuild_threshold: f64) -> Self {
        Self {
            entries: BTreeMap::new(),
            insert_count: 0,
            delete_count: 0,
            rebuild_threshold,
        }
    }

    /// Mark `key` as live.
    ///
    /// A tombstoned key is resurrected and no longer counted as a delete; a key
    /// that is already live is left unchanged.
    pub fn insert(&mut self, key: u64) {
        if let Some(deleted) = self.entries.get_mut(&key) {
            // Resurrect deleted key
            if *deleted {
                *deleted = false;
                self.delete_count = self.delete_count.saturating_sub(1);
            }
        } else {
            self.entries.insert(key, false);
            self.insert_count += 1;
        }
    }

    /// Delete a key by recording a tombstone.
    ///
    /// Idempotent: deleting a key that is already tombstoned does not count
    /// again towards the rebuild threshold.
    pub fn delete(&mut self, key: u64) {
        // `BTreeMap::insert` returns the previous flag; only a new tombstone counts.
        if self.entries.insert(key, true) != Some(true) {
            self.delete_count += 1;
        }
    }

    /// Return `Some(true)` if `key` is tombstoned, `Some(false)` if it is a live
    /// delta entry, and `None` if the delta has no record of it.
    pub fn contains(&self, key: u64) -> Option<bool> {
        self.entries.get(&key).copied()
    }

    /// Check whether the delta has grown enough to warrant a rebuild.
    ///
    /// With an empty static index, this triggers once more than 100 entries are
    /// buffered; otherwise, when the change count exceeds `rebuild_threshold`
    /// times `static_size`.
    pub fn needs_rebuild(&self, static_size: usize) -> bool {
        if static_size == 0 {
            return self.entries.len() > 100;
        }
        let delta_size = self.insert_count + self.delete_count;
        (delta_size as f64 / static_size as f64) > self.rebuild_threshold
    }

    /// Iterate over all live (non-tombstoned) keys in ascending order (for rebuild).
    pub fn live_keys(&self) -> impl Iterator<Item = u64> + '_ {
        self.entries
            .iter()
            .filter(|(_, deleted)| !**deleted)
            .map(|(k, _)| *k)
    }

    /// Iterate over all tombstoned keys in ascending order.
    pub fn deleted_keys(&self) -> impl Iterator<Item = u64> + '_ {
        self.entries
            .iter()
            .filter(|(_, deleted)| **deleted)
            .map(|(k, _)| *k)
    }

    /// Drop all entries and reset the counters (called after a rebuild).
    pub fn clear(&mut self) {
        self.entries.clear();
        self.insert_count = 0;
        self.delete_count = 0;
    }

    /// Number of buffered entries (live and tombstoned).
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the delta holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
1139
1140/// Hybrid RMI with delta updates and B-tree fallback
1141#[derive(Debug)]
1142pub struct HybridRMI {
1143    /// Static RMI structure
1144    rmi: RecursiveModelIndex,
1145    /// Delta index for updates
1146    delta: DeltaIndex,
1147    /// Sorted keys for binary search
1148    keys: Vec<u64>,
1149    /// B-tree fallback for high-error keys
1150    btree_fallback: BTreeMap<u64, usize>,
1151    /// Stats
1152    stats: HybridRMIStats,
1153}
1154
1155/// Statistics for HybridRMI
1156#[derive(Debug, Clone, Default)]
1157pub struct HybridRMIStats {
1158    /// Lookups via RMI
1159    pub rmi_lookups: u64,
1160    /// Lookups via B-tree fallback
1161    pub btree_lookups: u64,
1162    /// Lookups in delta
1163    pub delta_lookups: u64,
1164    /// Number of rebuilds
1165    pub rebuilds: u64,
1166}
1167
1168impl HybridRMI {
1169    /// Build hybrid RMI
1170    pub fn build(
1171        keys: Vec<u64>,
1172        num_leaves: usize,
1173        leaf_max_error: usize,
1174        rebuild_threshold: f64,
1175    ) -> Self {
1176        let rmi = RecursiveModelIndex::build(&keys, num_leaves, leaf_max_error);
1177
1178        // Find overflow keys for B-tree fallback
1179        let _overflow_threshold = leaf_max_error * 3;
1180        let mut btree_fallback = BTreeMap::new();
1181
1182        for (pos, &key) in keys.iter().enumerate() {
1183            if let Some((low, high)) = rmi.lookup(key, keys.len())
1184                && (pos < low || pos > high)
1185            {
1186                btree_fallback.insert(key, pos);
1187            }
1188        }
1189
1190        Self {
1191            rmi,
1192            delta: DeltaIndex::new(rebuild_threshold),
1193            keys,
1194            btree_fallback,
1195            stats: HybridRMIStats::default(),
1196        }
1197    }
1198
1199    /// Look up a key
1200    pub fn lookup(&mut self, key: u64) -> Option<usize> {
1201        // 1. Check delta first
1202        if let Some(deleted) = self.delta.contains(key) {
1203            self.stats.delta_lookups += 1;
1204            if deleted {
1205                return None; // Deleted
1206            }
1207            // Key in delta but not deleted - need to find position
1208            // This is a recent insert, position unknown
1209            return None;
1210        }
1211
1212        // 2. Check B-tree fallback
1213        if let Some(&pos) = self.btree_fallback.get(&key) {
1214            self.stats.btree_lookups += 1;
1215            return Some(pos);
1216        }
1217
1218        // 3. Use RMI
1219        self.stats.rmi_lookups += 1;
1220        if let Some((low, high)) = self.rmi.lookup(key, self.keys.len()) {
1221            // Binary search in predicted range
1222            let range = &self.keys[low..=high.min(self.keys.len().saturating_sub(1))];
1223            if let Ok(idx) = range.binary_search(&key) {
1224                return Some(low + idx);
1225            }
1226        }
1227
1228        // 4. Full binary search fallback
1229        self.keys.binary_search(&key).ok()
1230    }
1231
1232    /// Insert a key (goes to delta)
1233    pub fn insert(&mut self, key: u64) {
1234        self.delta.insert(key);
1235        if self.delta.needs_rebuild(self.keys.len()) {
1236            self.rebuild();
1237        }
1238    }
1239
1240    /// Delete a key
1241    pub fn delete(&mut self, key: u64) {
1242        self.delta.delete(key);
1243    }
1244
1245    /// Rebuild the index
1246    pub fn rebuild(&mut self) {
1247        // Merge delta into keys
1248        let deleted: HashSet<u64> = self.delta.deleted_keys().collect();
1249
1250        let mut new_keys: Vec<u64> = self
1251            .keys
1252            .iter()
1253            .filter(|&k| !deleted.contains(k))
1254            .copied()
1255            .collect();
1256
1257        new_keys.extend(self.delta.live_keys());
1258        new_keys.sort_unstable();
1259        new_keys.dedup();
1260
1261        // Rebuild RMI
1262        let n = new_keys.len();
1263        let num_leaves = ((n as f64).sqrt().ceil() as usize).max(1);
1264        let leaf_max_error = self.rmi.max_error.max(10);
1265
1266        let new_rmi = RecursiveModelIndex::build(&new_keys, num_leaves, leaf_max_error);
1267
1268        // Rebuild B-tree fallback
1269        let mut new_btree = BTreeMap::new();
1270        for (pos, &key) in new_keys.iter().enumerate() {
1271            if let Some((low, high)) = new_rmi.lookup(key, new_keys.len())
1272                && (pos < low || pos > high)
1273            {
1274                new_btree.insert(key, pos);
1275            }
1276        }
1277
1278        self.rmi = new_rmi;
1279        self.keys = new_keys;
1280        self.btree_fallback = new_btree;
1281        self.delta.clear();
1282        self.stats.rebuilds += 1;
1283    }
1284
1285    /// Get statistics
1286    pub fn stats(&self) -> &HybridRMIStats {
1287        &self.stats
1288    }
1289
1290    /// Get number of keys
1291    pub fn len(&self) -> usize {
1292        self.keys.len()
1293    }
1294
1295    /// Check if empty
1296    pub fn is_empty(&self) -> bool {
1297        self.keys.is_empty()
1298    }
1299
1300    /// Space usage in bytes
1301    pub fn size_bytes(&self) -> usize {
1302        self.rmi.size_bytes()
1303            + self.keys.len() * std::mem::size_of::<u64>()
1304            + self.delta.len() * std::mem::size_of::<(u64, bool)>()
1305            + self.btree_fallback.len()
1306                * (std::mem::size_of::<u64>() + std::mem::size_of::<usize>())
1307    }
1308}
1309
#[cfg(test)]
mod rmi_tests {
    use super::*;

    #[test]
    fn test_rmi_build() {
        let keys: Vec<u64> = (0..10000).collect();
        let rmi = RecursiveModelIndex::build(&keys, 100, 10);

        assert_eq!(rmi.num_leaves(), 100);
        assert!(rmi.max_error <= 10);

        // Every stored key lies in [min_key, max_key], so lookup must return
        // bounds (a `None` here is a failure, not a skip), and those bounds must
        // contain the key's position.
        for i in (0..10000u64).step_by(100) {
            let pos = i as usize;
            let (low, high) = rmi
                .lookup(i, keys.len())
                .unwrap_or_else(|| panic!("Key {i}: lookup returned None for a stored key"));
            assert!(
                low <= pos && high >= pos,
                "Key {}: bounds ({}, {}) don't contain position",
                i,
                low,
                high
            );
        }
    }

    #[test]
    fn test_rmi_empty() {
        let rmi = RecursiveModelIndex::build(&[], 4, 10);
        assert_eq!(rmi.num_leaves(), 0);
        assert!(rmi.lookup(5, 0).is_none());
    }

    #[test]
    fn test_delta_index() {
        let mut delta = DeltaIndex::new(0.1);

        delta.insert(100);
        delta.insert(200);
        delta.delete(150);

        assert_eq!(delta.contains(100), Some(false));
        assert_eq!(delta.contains(150), Some(true)); // deleted
        assert_eq!(delta.contains(300), None);

        // Resurrect deleted key
        delta.insert(150);
        assert_eq!(delta.contains(150), Some(false));
    }

    #[test]
    fn test_hybrid_rmi_lookup() {
        let keys: Vec<u64> = (0..1000).step_by(2).collect();
        let mut rmi = HybridRMI::build(keys, 10, 5, 0.1);

        // Find existing key
        assert_eq!(rmi.lookup(100), Some(50));
        assert_eq!(rmi.lookup(500), Some(250));

        // Non-existing
        assert!(rmi.lookup(101).is_none());
    }

    #[test]
    fn test_hybrid_rmi_empty() {
        let mut rmi = HybridRMI::build(Vec::new(), 4, 10, 0.1);
        assert!(rmi.is_empty());
        assert_eq!(rmi.len(), 0);
        assert!(rmi.lookup(1).is_none());
    }

    #[test]
    fn test_hybrid_rmi_updates() {
        let keys: Vec<u64> = (0..100).collect();
        let mut rmi = HybridRMI::build(keys, 5, 5, 0.5);

        // Delete a key - next lookup should return None
        rmi.delete(50);
        assert!(rmi.lookup(50).is_none());

        // Insert a new key (goes to delta, won't be in RMI yet)
        rmi.insert(200);

        // Verify delta has something
        assert!(!rmi.delta.is_empty());
    }

    #[test]
    fn test_hybrid_rmi_rebuild() {
        let keys: Vec<u64> = (0..100).collect();
        let mut rmi = HybridRMI::build(keys, 5, 5, 0.05);

        let initial_len = rmi.len();

        // Insert enough to trigger rebuild
        for i in 100..115 {
            rmi.insert(i);
        }

        // Rebuild should have happened
        assert!(rmi.stats().rebuilds > 0);
        // Length should have increased
        assert!(rmi.len() > initial_len);
    }

    #[test]
    fn test_rmi_space_efficiency() {
        let n = 100_000usize;
        let keys: Vec<u64> = (0..n as u64).collect();
        let rmi = RecursiveModelIndex::build(&keys, 100, 50);

        let rmi_size = rmi.size_bytes();
        let raw_size = n * std::mem::size_of::<u64>();

        println!(
            "RMI size: {} bytes, Raw keys: {} bytes, Ratio: {:.2}x",
            rmi_size,
            raw_size,
            raw_size as f64 / rmi_size as f64
        );

        // RMI structure should be much smaller than raw keys
        assert!(
            rmi_size < raw_size / 10,
            "RMI size {} should be < 10% of raw size {}",
            rmi_size,
            raw_size
        );
    }
}