
tenflowers_dataset/smart_cache.rs

//! Smart caching system with adaptive policies and multi-tier caching
//!
//! This module provides advanced caching strategies that adapt to access patterns
//! and provide multi-tier memory management for optimal performance.
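//!
//! A minimal usage sketch (assuming this crate is named `tenflowers_dataset` and the
//! module is exposed as `smart_cache`; not compiled as a doctest):
//!
//! ```ignore
//! use tenflowers_dataset::smart_cache::{CacheConfig, EvictionPolicy, SmartCache};
//!
//! // Three tiers sized in tensor elements, keyed by sample index.
//! let cache: SmartCache<f32, usize> = SmartCache::new(
//!     1_000,   // L1 capacity
//!     10_000,  // L2 capacity
//!     100_000, // L3 capacity
//!     EvictionPolicy::Adaptive,
//!     CacheConfig::default(),
//! );
//! ```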

use crate::Dataset;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tenflowers_core::{Result, Tensor};

/// Access pattern tracking for adaptive caching decisions
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct AccessPattern {
    /// Last access time
    last_access: Instant,
    /// Number of accesses
    access_count: u64,
    /// Average time between accesses
    avg_interval: Duration,
    /// Is this a sequential access pattern?
    is_sequential: bool,
    /// Frequency score (higher = more frequently accessed)
    frequency_score: f64,
}

impl AccessPattern {
    fn new() -> Self {
        Self {
            last_access: Instant::now(),
            access_count: 1,
            avg_interval: Duration::from_secs(0),
            is_sequential: false,
            frequency_score: 1.0,
        }
    }

    fn update(&mut self, now: Instant) {
        let interval = now.duration_since(self.last_access);

        // Update average interval with exponential moving average
        if self.access_count > 1 {
            let alpha = 0.1;
            self.avg_interval = Duration::from_secs_f64(
                alpha * interval.as_secs_f64() + (1.0 - alpha) * self.avg_interval.as_secs_f64(),
            );
        } else {
            self.avg_interval = interval;
        }

        self.last_access = now;
        self.access_count += 1;

        // Update frequency score (decay over time, boost with access)
        let time_decay = (-interval.as_secs_f64() / 300.0).exp(); // 5-minute decay time constant
        self.frequency_score = self.frequency_score * time_decay + 1.0;
    }

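    // Worked example: with frequency_score = 4.0 and 30 s since the last access,
    // recency_score = 1 / (1 + 30/60) ≈ 0.67, so the priority is
    // 0.7 * 4.0 + 0.3 * 0.67 ≈ 2.8 + 0.2 = 3.0.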
    fn priority_score(&self) -> f64 {
        let recency_score = 1.0 / (1.0 + self.last_access.elapsed().as_secs_f64() / 60.0);
        let frequency_weight = 0.7;
        let recency_weight = 0.3;

        frequency_weight * self.frequency_score + recency_weight * recency_score
    }
}

/// Cache eviction policies
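///
/// A small selection sketch (the TTL value below is an illustrative choice, not a default):
///
/// ```ignore
/// use std::time::Duration;
///
/// // Evict purely by recency:
/// let lru = EvictionPolicy::LRU;
/// // Expire entries ten minutes after insertion:
/// let ttl = EvictionPolicy::TimeBasedTTL(Duration::from_secs(600));
/// ```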
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Least Recently Used
    LRU,
    /// Least Frequently Used
    LFU,
    /// Adaptive based on access patterns
    Adaptive,
    /// Time-based with TTL
    TimeBasedTTL(Duration),
    /// Hybrid: combines multiple strategies
    Hybrid,
}

/// Multi-tier cache levels
#[derive(Debug, Clone)]
pub enum CacheLevel {
    /// Fast memory cache (e.g., RAM)
    L1Memory,
    /// Slower but larger storage (e.g., SSD)
    L2Storage,
    /// Very slow but huge storage (e.g., HDD, remote)
    L3Remote,
}

/// Cache entry with metadata
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct CacheEntry<T> {
    data: (Tensor<T>, Tensor<T>),
    pattern: AccessPattern,
    size: usize,
    level: CacheLevel,
    compressed: bool,
    ttl: Option<Instant>,
}

impl<T> CacheEntry<T> {
    fn new(data: (Tensor<T>, Tensor<T>), level: CacheLevel) -> Self {
        let size = data.0.shape().size() + data.1.shape().size();
        Self {
            data,
            pattern: AccessPattern::new(),
            size,
            level,
            compressed: false,
            ttl: None,
        }
    }

    fn is_expired(&self) -> bool {
        if let Some(ttl) = self.ttl {
            Instant::now() > ttl
        } else {
            false
        }
    }
}

/// Smart adaptive cache with multi-tier support
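///
/// A minimal put/get sketch (tier capacities are in tensor elements, as measured by
/// `CacheEntry::new`; the `features` and `labels` tensors below are placeholders):
///
/// ```ignore
/// let cache: SmartCache<f32, usize> =
///     SmartCache::new(100, 1_000, 10_000, EvictionPolicy::LRU, CacheConfig::default());
///
/// cache.put(0, (features, labels));
/// if let Some((f, l)) = cache.get(&0) {
///     // served from the L1 tier on the first hit
/// }
/// ```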
pub struct SmartCache<T, K>
where
    K: Eq + Hash + Clone,
{
    /// L1 cache: fast memory
    l1_cache: Arc<RwLock<HashMap<K, CacheEntry<T>>>>,
    /// L2 cache: larger but slower storage
    l2_cache: Arc<RwLock<HashMap<K, CacheEntry<T>>>>,
    /// L3 cache: very large remote/disk storage
    l3_cache: Arc<RwLock<HashMap<K, CacheEntry<T>>>>,

    /// Maximum size for each cache level
    l1_max_size: usize,
    l2_max_size: usize,
    l3_max_size: usize,

    /// Current size for each cache level
    l1_current_size: Arc<Mutex<usize>>,
    l2_current_size: Arc<Mutex<usize>>,
    l3_current_size: Arc<Mutex<usize>>,

    /// Eviction policy
    policy: EvictionPolicy,

    /// Access order tracking for LRU
    l1_access_order: Arc<Mutex<VecDeque<K>>>,
    l2_access_order: Arc<Mutex<VecDeque<K>>>,
    l3_access_order: Arc<Mutex<VecDeque<K>>>,

    /// Statistics
    stats: Arc<Mutex<CacheStats>>,

    /// Configuration
    config: CacheConfig,
}

/// Cache configuration
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct CacheConfig {
    /// Enable compression for larger cache levels
    enable_compression: bool,
    /// TTL for cache entries
    default_ttl: Option<Duration>,
    /// Threshold for promoting entries between cache levels
    promotion_threshold: f64,
    /// Threshold for demoting entries between cache levels
    demotion_threshold: f64,
    /// Maximum memory usage before triggering aggressive eviction
    memory_pressure_threshold: f64,
    /// Background cleanup interval
    cleanup_interval: Duration,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            enable_compression: true,
            default_ttl: Some(Duration::from_secs(3600)), // 1 hour
            promotion_threshold: 3.0,
            demotion_threshold: 0.5,
            memory_pressure_threshold: 0.8,
            cleanup_interval: Duration::from_secs(60),
        }
    }
}

/// Cache statistics
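///
/// Counters are cumulative. A hit-rate sketch using the public fields defined below:
///
/// ```ignore
/// let s = cache.stats();
/// let hits = s.l1_hits + s.l2_hits + s.l3_hits;
/// let hit_rate = hits as f64 / s.total_requests.max(1) as f64;
/// ```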
#[derive(Debug, Clone)]
pub struct CacheStats {
    pub l1_hits: u64,
    pub l2_hits: u64,
    pub l3_hits: u64,
    pub misses: u64,
    pub evictions: u64,
    pub promotions: u64,
    pub demotions: u64,
    pub total_requests: u64,
    pub avg_access_time: Duration,
}

impl Default for CacheStats {
    fn default() -> Self {
        Self {
            l1_hits: 0,
            l2_hits: 0,
            l3_hits: 0,
            misses: 0,
            evictions: 0,
            promotions: 0,
            demotions: 0,
            total_requests: 0,
            avg_access_time: Duration::from_secs(0),
        }
    }
}

impl<T, K> SmartCache<T, K>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync,
{
    /// Create a new smart cache with specified capacities
    pub fn new(
        l1_max_size: usize,
        l2_max_size: usize,
        l3_max_size: usize,
        policy: EvictionPolicy,
        config: CacheConfig,
    ) -> Self {
        Self {
            l1_cache: Arc::new(RwLock::new(HashMap::new())),
            l2_cache: Arc::new(RwLock::new(HashMap::new())),
            l3_cache: Arc::new(RwLock::new(HashMap::new())),
            l1_max_size,
            l2_max_size,
            l3_max_size,
            l1_current_size: Arc::new(Mutex::new(0)),
            l2_current_size: Arc::new(Mutex::new(0)),
            l3_current_size: Arc::new(Mutex::new(0)),
            policy,
            l1_access_order: Arc::new(Mutex::new(VecDeque::new())),
            l2_access_order: Arc::new(Mutex::new(VecDeque::new())),
            l3_access_order: Arc::new(Mutex::new(VecDeque::new())),
            stats: Arc::new(Mutex::new(CacheStats::default())),
            config,
        }
    }

    /// Get an item from cache (checks all levels)
    pub fn get(&self, key: &K) -> Option<(Tensor<T>, Tensor<T>)> {
        let start_time = Instant::now();
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.total_requests += 1;
        drop(stats);

        // Check L1 cache first (`get_from_level` records the access on the stored entry)
        if let Some(entry) = self.get_from_level(key, CacheLevel::L1Memory) {
            self.update_stats_hit(CacheLevel::L1Memory, start_time);
            return Some(entry.data);
        }

        // Check L2 cache
        if let Some(entry) = self.get_from_level(key, CacheLevel::L2Storage) {
            // Consider promotion to L1 based on access pattern
            if entry.pattern.priority_score() > self.config.promotion_threshold {
                self.promote_entry(key.clone(), entry.clone(), CacheLevel::L1Memory);
            }

            self.update_stats_hit(CacheLevel::L2Storage, start_time);
            return Some(entry.data);
        }

        // Check L3 cache
        if let Some(entry) = self.get_from_level(key, CacheLevel::L3Remote) {
            // Consider promotion to L2 based on access pattern
            if entry.pattern.priority_score() > self.config.promotion_threshold {
                self.promote_entry(key.clone(), entry.clone(), CacheLevel::L2Storage);
            }

            self.update_stats_hit(CacheLevel::L3Remote, start_time);
            return Some(entry.data);
        }

        // Cache miss
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.misses += 1;
        None
    }

    /// Put an item into cache (automatically selects appropriate level)
    pub fn put(&self, key: K, value: (Tensor<T>, Tensor<T>)) {
        let entry = CacheEntry::new(value, CacheLevel::L1Memory);

        // Try to insert into L1 first
        if self.try_insert_at_level(key.clone(), entry.clone(), CacheLevel::L1Memory) {
            return;
        }

        // If L1 is full, try L2
        if self.try_insert_at_level(key.clone(), entry.clone(), CacheLevel::L2Storage) {
            return;
        }

        // If L2 is full, use L3
        self.try_insert_at_level(key, entry, CacheLevel::L3Remote);
    }

    fn get_from_level(&self, key: &K, level: CacheLevel) -> Option<CacheEntry<T>> {
        let cache = match level {
            CacheLevel::L1Memory => &self.l1_cache,
            CacheLevel::L2Storage => &self.l2_cache,
            CacheLevel::L3Remote => &self.l3_cache,
        };

        // Take a write lock so the access-pattern update is recorded on the stored
        // entry rather than being lost on a temporary clone.
        let mut cache_write = cache.write().expect("write lock should not be poisoned");
        match cache_write.get_mut(key) {
            Some(entry) if !entry.is_expired() => {
                entry.pattern.update(Instant::now());
                Some(entry.clone())
            }
            _ => None,
        }
    }

    fn try_insert_at_level(&self, key: K, mut entry: CacheEntry<T>, level: CacheLevel) -> bool {
        let (cache, current_size, max_size, access_order) = match level {
            CacheLevel::L1Memory => (
                &self.l1_cache,
                &self.l1_current_size,
                self.l1_max_size,
                &self.l1_access_order,
            ),
            CacheLevel::L2Storage => (
                &self.l2_cache,
                &self.l2_current_size,
                self.l2_max_size,
                &self.l2_access_order,
            ),
            CacheLevel::L3Remote => (
                &self.l3_cache,
                &self.l3_current_size,
                self.l3_max_size,
                &self.l3_access_order,
            ),
        };

        entry.level = level.clone();
        if let Some(ttl) = self.config.default_ttl {
            entry.ttl = Some(Instant::now() + ttl);
        }

        // Evict until the new entry fits. The size counter is locked and released on
        // each check because `evict_from_level` locks the same mutex internally;
        // holding the guard across that call would deadlock.
        loop {
            let current = *current_size.lock().expect("lock should not be poisoned");
            if current + entry.size <= max_size {
                break;
            }
            if !self.evict_from_level(level.clone()) {
                return false; // Cannot evict, cache full
            }
        }

        // Insert the entry and account for its size
        let mut cache_write = cache.write().expect("write lock should not be poisoned");
        cache_write.insert(key.clone(), entry.clone());
        *current_size.lock().expect("lock should not be poisoned") += entry.size;

        // Update access order for LRU
        let mut order = access_order.lock().expect("lock should not be poisoned");
        order.push_back(key);

        true
    }

    fn evict_from_level(&self, level: CacheLevel) -> bool {
        let (cache, current_size, access_order) = match level {
            CacheLevel::L1Memory => (&self.l1_cache, &self.l1_current_size, &self.l1_access_order),
            CacheLevel::L2Storage => (&self.l2_cache, &self.l2_current_size, &self.l2_access_order),
            CacheLevel::L3Remote => (&self.l3_cache, &self.l3_current_size, &self.l3_access_order),
        };

        let victim_key = match self.policy {
            EvictionPolicy::LRU => {
                let mut order = access_order.lock().expect("lock should not be poisoned");
                order.pop_front()
            }
            EvictionPolicy::LFU | EvictionPolicy::Adaptive | EvictionPolicy::Hybrid => {
                self.find_lfu_victim(cache)
            }
            EvictionPolicy::TimeBasedTTL(_) => self.find_expired_victim(cache),
        };

        if let Some(key) = victim_key {
            let mut cache_write = cache.write().expect("write lock should not be poisoned");
            if let Some(entry) = cache_write.remove(&key) {
                let mut size_guard = current_size.lock().expect("lock should not be poisoned");
                *size_guard = size_guard.saturating_sub(entry.size);

                let mut stats = self.stats.lock().expect("lock should not be poisoned");
                stats.evictions += 1;

                return true;
            }
        }

        false
    }

    fn find_lfu_victim(&self, cache: &Arc<RwLock<HashMap<K, CacheEntry<T>>>>) -> Option<K> {
        let cache_read = cache.read().expect("read lock should not be poisoned");
        cache_read
            .iter()
            .min_by(|(_, a), (_, b)| {
                a.pattern
                    .priority_score()
                    .partial_cmp(&b.pattern.priority_score())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(k, _)| k.clone())
    }

    fn find_expired_victim(&self, cache: &Arc<RwLock<HashMap<K, CacheEntry<T>>>>) -> Option<K> {
        let cache_read = cache.read().expect("read lock should not be poisoned");
        cache_read
            .iter()
            .find(|(_, entry)| entry.is_expired())
            .map(|(k, _)| k.clone())
    }

    fn promote_entry(&self, key: K, entry: CacheEntry<T>, target_level: CacheLevel) {
        let original_level = entry.level.clone();
        if self.try_insert_at_level(key.clone(), entry, target_level) {
            // Remove from the lower level and release its size back to that tier's counter
            match original_level {
                CacheLevel::L3Remote => {
                    let removed = self
                        .l3_cache
                        .write()
                        .expect("write lock should not be poisoned")
                        .remove(&key);
                    if let Some(removed) = removed {
                        let mut size = self
                            .l3_current_size
                            .lock()
                            .expect("lock should not be poisoned");
                        *size = size.saturating_sub(removed.size);
                    }
                }
                CacheLevel::L2Storage => {
                    let removed = self
                        .l2_cache
                        .write()
                        .expect("write lock should not be poisoned")
                        .remove(&key);
                    if let Some(removed) = removed {
                        let mut size = self
                            .l2_current_size
                            .lock()
                            .expect("lock should not be poisoned");
                        *size = size.saturating_sub(removed.size);
                    }
                }
                _ => {}
            }

            let mut stats = self.stats.lock().expect("lock should not be poisoned");
            stats.promotions += 1;
        }
    }

    fn update_stats_hit(&self, level: CacheLevel, start_time: Instant) {
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        match level {
            CacheLevel::L1Memory => stats.l1_hits += 1,
            CacheLevel::L2Storage => stats.l2_hits += 1,
            CacheLevel::L3Remote => stats.l3_hits += 1,
        }

        let access_time = start_time.elapsed();
        let alpha = 0.1;
        stats.avg_access_time = Duration::from_secs_f64(
            alpha * access_time.as_secs_f64() + (1.0 - alpha) * stats.avg_access_time.as_secs_f64(),
        );
    }

    /// Get cache statistics
    pub fn stats(&self) -> CacheStats {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Clear all cache levels
    pub fn clear(&self) {
        let mut l1 = self
            .l1_cache
            .write()
            .expect("write lock should not be poisoned");
        let mut l2 = self
            .l2_cache
            .write()
            .expect("write lock should not be poisoned");
        let mut l3 = self
            .l3_cache
            .write()
            .expect("write lock should not be poisoned");

        l1.clear();
        l2.clear();
        l3.clear();

        *self
            .l1_current_size
            .lock()
            .expect("lock should not be poisoned") = 0;
        *self
            .l2_current_size
            .lock()
            .expect("lock should not be poisoned") = 0;
        *self
            .l3_current_size
            .lock()
            .expect("lock should not be poisoned") = 0;
    }

    /// Run background cleanup to remove expired entries
    pub fn cleanup_expired(&self) {
        for level in [
            CacheLevel::L1Memory,
            CacheLevel::L2Storage,
            CacheLevel::L3Remote,
        ] {
            while self
                .find_expired_victim(match level {
                    CacheLevel::L1Memory => &self.l1_cache,
                    CacheLevel::L2Storage => &self.l2_cache,
                    CacheLevel::L3Remote => &self.l3_cache,
                })
                .is_some()
            {
                // Stop if eviction makes no progress to avoid spinning forever
                if !self.evict_from_level(level.clone()) {
                    break;
                }
            }
        }
    }
}

/// Smart cached dataset wrapper
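///
/// Wraps any `Dataset` and serves repeated `get` calls from the multi-tier cache.
/// A minimal sketch (the tensors and tier sizes are placeholders):
///
/// ```ignore
/// let dataset = TensorDataset::new(features, labels);
/// let cached = SmartCachedDataset::new(
///     dataset,
///     100,    // L1 capacity
///     1_000,  // L2 capacity
///     10_000, // L3 capacity
///     EvictionPolicy::Adaptive,
///     CacheConfig::default(),
/// );
/// let sample = cached.get(0)?; // first call loads and caches, later calls hit the cache
/// ```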
pub struct SmartCachedDataset<T, D: Dataset<T>> {
    dataset: D,
    cache: Arc<SmartCache<T, usize>>,
}

impl<T, D: Dataset<T>> SmartCachedDataset<T, D>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Create a new smart cached dataset
    pub fn new(
        dataset: D,
        l1_size: usize,
        l2_size: usize,
        l3_size: usize,
        policy: EvictionPolicy,
        config: CacheConfig,
    ) -> Self {
        let cache = Arc::new(SmartCache::new(l1_size, l2_size, l3_size, policy, config));

        Self { dataset, cache }
    }

    /// Get cache statistics
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Clear the cache
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl<T, D: Dataset<T>> Dataset<T> for SmartCachedDataset<T, D>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    fn len(&self) -> usize {
        self.dataset.len()
    }

    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        // Try cache first
        if let Some(cached) = self.cache.get(&index) {
            return Ok(cached);
        }

        // Cache miss - load from dataset
        let sample = self.dataset.get(index)?;

        // Store in cache
        self.cache.put(index, sample.clone());

        Ok(sample)
    }
}

/// Predictive access pattern analyzer for smart prefetching
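///
/// Learns recurring access sequences from a sliding window of recent keys and
/// scores likely follow-up accesses. A minimal sketch with integer keys:
///
/// ```ignore
/// let mut predictor = AccessPatternPredictor::new();
/// for key in [0usize, 1, 2, 0, 1, 2, 0, 1] {
///     predictor.record_access(key);
/// }
/// // After seeing "0, 1, 2" repeat, key 2 should rank highly after key 1.
/// let next = predictor.predict_next_accesses(&1, 3);
/// ```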
#[derive(Debug, Clone)]
pub struct AccessPatternPredictor<K>
where
    K: Eq + Hash + Clone + Send + Sync,
{
    /// History of recent accesses (sliding window)
    access_history: VecDeque<(K, Instant)>,
    /// Patterns detected in access sequences
    sequence_patterns: HashMap<Vec<K>, f64>,
    /// Maximum history size to maintain
    max_history_size: usize,
    /// Minimum pattern length to consider
    min_pattern_length: usize,
    /// Maximum pattern length to consider
    max_pattern_length: usize,
}

impl<K> AccessPatternPredictor<K>
where
    K: Eq + Hash + Clone + Send + Sync,
{
    pub fn new() -> Self {
        Self {
            access_history: VecDeque::with_capacity(1000),
            sequence_patterns: HashMap::new(),
            max_history_size: 1000,
            min_pattern_length: 2,
            max_pattern_length: 5,
        }
    }

    /// Record a new access and update patterns
    pub fn record_access(&mut self, key: K) {
        let now = Instant::now();

        // Add to history
        self.access_history.push_back((key.clone(), now));

        // Maintain sliding window
        if self.access_history.len() > self.max_history_size {
            self.access_history.pop_front();
        }

        // Update sequence patterns
        self.update_patterns();
    }

    /// Predict the next likely accesses based on recent patterns
    pub fn predict_next_accesses(&self, current_key: &K, max_predictions: usize) -> Vec<(K, f64)> {
        let mut predictions = Vec::new();

        // Look for patterns ending with the current key
        for pattern_len in self.min_pattern_length..=self.max_pattern_length {
            if let Some(recent_sequence) = self.get_recent_sequence(pattern_len) {
                if recent_sequence.last() == Some(current_key) {
                    // Find patterns that start with this sequence
                    for (pattern, confidence) in &self.sequence_patterns {
                        if pattern.len() > pattern_len
                            && pattern[..pattern_len] == recent_sequence[..]
                        {
                            let next_key = &pattern[pattern_len];
                            predictions.push((next_key.clone(), *confidence));
                        }
                    }
                }
            }
        }

        // Sort by confidence and return top predictions
        predictions.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .expect("partial_cmp should not return None for valid values")
        });
        predictions.truncate(max_predictions);
        predictions
    }

    /// Get recent access sequence of specified length
    fn get_recent_sequence(&self, length: usize) -> Option<Vec<K>> {
        if self.access_history.len() < length {
            return None;
        }

        let recent: Vec<K> = self
            .access_history
            .iter()
            .rev()
            .take(length)
            .map(|(k, _)| k.clone())
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();

        Some(recent)
    }

    /// Update sequence patterns based on access history
    fn update_patterns(&mut self) {
        let history_keys: Vec<K> = self.access_history.iter().map(|(k, _)| k.clone()).collect();

        // Extract patterns of different lengths
        for pattern_len in self.min_pattern_length..=self.max_pattern_length {
            if history_keys.len() >= pattern_len {
                for i in 0..=(history_keys.len() - pattern_len) {
                    let pattern = history_keys[i..i + pattern_len].to_vec();

                    // Weight occurrences later in the window slightly higher (linear recency weight)
                    let age_factor = 0.9 + 0.1 * (i as f64 / history_keys.len() as f64);

                    *self.sequence_patterns.entry(pattern).or_insert(0.0) += age_factor;
                }
            }
        }

        // Decay all patterns over time to prevent unbounded growth
        for confidence in self.sequence_patterns.values_mut() {
            *confidence *= 0.99; // Small decay factor
        }

        // Remove patterns with very low confidence
        self.sequence_patterns
            .retain(|_, confidence| *confidence > 0.1);
    }
}

impl<K> Default for AccessPatternPredictor<K>
where
    K: Eq + Hash + Clone + Send + Sync,
{
    fn default() -> Self {
        Self::new()
    }
}

/// Enhanced smart cache with predictive prefetching
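///
/// Combines the multi-tier `SmartCache` with an `AccessPatternPredictor` that queues
/// likely next keys for prefetching. A minimal sketch (prefetching only takes effect
/// once a dataset is attached via `set_dataset`; sizes are placeholders):
///
/// ```ignore
/// let mut cache: PredictiveSmartCache<f32, usize> = PredictiveSmartCache::new(
///     100, 1_000, 10_000,
///     EvictionPolicy::Hybrid,
///     CacheConfig::default(),
///     16, // max prefetch queue length
/// );
/// cache.set_dataset(dataset);
/// let sample = cache.get(&0);
/// cache.process_prefetch_queue(); // e.g. from a background task
/// ```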
pub struct PredictiveSmartCache<T, K>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync,
{
    /// Base smart cache
    base_cache: SmartCache<T, K>,
    /// Pattern predictor for prefetching
    predictor: Arc<Mutex<AccessPatternPredictor<K>>>,
    /// Reference to the dataset for prefetching
    dataset: Option<Arc<dyn Dataset<T>>>,
    /// Prefetch queue
    prefetch_queue: Arc<Mutex<VecDeque<K>>>,
    /// Maximum prefetch queue size
    max_prefetch_size: usize,
}

impl<T, K> PredictiveSmartCache<T, K>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync,
{
    pub fn new(
        l1_max_size: usize,
        l2_max_size: usize,
        l3_max_size: usize,
        policy: EvictionPolicy,
        config: CacheConfig,
        max_prefetch_size: usize,
    ) -> Self {
        Self {
            base_cache: SmartCache::new(l1_max_size, l2_max_size, l3_max_size, policy, config),
            predictor: Arc::new(Mutex::new(AccessPatternPredictor::new())),
            dataset: None,
            prefetch_queue: Arc::new(Mutex::new(VecDeque::with_capacity(max_prefetch_size))),
            max_prefetch_size,
        }
    }

    /// Set the dataset reference for prefetching
    pub fn set_dataset(&mut self, dataset: Arc<dyn Dataset<T>>) {
        self.dataset = Some(dataset);
    }

    /// Get item with predictive prefetching
    pub fn get(&self, key: &K) -> Option<(Tensor<T>, Tensor<T>)> {
        // Record access for pattern learning
        {
            let mut predictor = self.predictor.lock().expect("lock should not be poisoned");
            predictor.record_access(key.clone());
        }

        // Try to get from base cache first
        if let Some(result) = self.base_cache.get(key) {
            // Trigger predictive prefetching based on this access
            self.trigger_prefetch(key);
            return Some(result);
        }

        // Cache miss - load from dataset if available
        if let Some(ref dataset) = self.dataset {
            // For this example, assume K can be converted to usize for dataset access
            // In a real implementation, you'd need proper key-to-index mapping
            if let Some(data) = self.load_from_dataset(dataset, key) {
                self.base_cache.put(key.clone(), data.clone());
                self.trigger_prefetch(key);
                return Some(data);
            }
        }

        None
    }

    /// Put item in cache
    pub fn put(&self, key: K, data: (Tensor<T>, Tensor<T>)) {
        self.base_cache.put(key, data);
    }

    /// Get cache statistics
    pub fn stats(&self) -> CacheStats {
        self.base_cache.stats()
    }

    /// Trigger predictive prefetching based on current access
    fn trigger_prefetch(&self, current_key: &K) {
        let predictions = {
            let predictor = self.predictor.lock().expect("lock should not be poisoned");
            predictor.predict_next_accesses(current_key, 3) // Predict up to 3 next accesses
        };

        let mut prefetch_queue = self
            .prefetch_queue
            .lock()
            .expect("lock should not be poisoned");

        for (predicted_key, confidence) in predictions {
            // Only prefetch if confidence is high enough and not already cached
            if confidence > 0.5 && self.base_cache.get(&predicted_key).is_none() {
                prefetch_queue.push_back(predicted_key);

                // Maintain queue size limit
                if prefetch_queue.len() > self.max_prefetch_size {
                    prefetch_queue.pop_front();
                }
            }
        }
    }

    /// Load data from dataset (placeholder implementation)
    fn load_from_dataset(
        &self,
        _dataset: &Arc<dyn Dataset<T>>,
        _key: &K,
    ) -> Option<(Tensor<T>, Tensor<T>)> {
        // This is a placeholder - in a real implementation, you would:
        // 1. Convert key to dataset index
        // 2. Load data from dataset
        // 3. Return the loaded data
        None
    }

    /// Process prefetch queue (should be called periodically)
    pub fn process_prefetch_queue(&self) {
        if let Some(ref dataset) = self.dataset {
            let mut prefetch_queue = self
                .prefetch_queue
                .lock()
                .expect("lock should not be poisoned");

            // Process a few items from the prefetch queue
            for _ in 0..3 {
                if let Some(key) = prefetch_queue.pop_front() {
                    // Check if already cached
                    if self.base_cache.get(&key).is_none() {
                        if let Some(data) = self.load_from_dataset(dataset, &key) {
                            self.base_cache.put(key, data);
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_smart_cache_creation() {
        let cache: SmartCache<f32, usize> = SmartCache::new(
            100,   // L1: 100 entries
            1000,  // L2: 1000 entries
            10000, // L3: 10000 entries
            EvictionPolicy::LRU,
            CacheConfig::default(),
        );

        let stats = cache.stats();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.l1_hits, 0);
    }

    #[test]
    fn test_smart_cache_put_get() {
        let cache: SmartCache<f32, usize> = SmartCache::new(
            100,
            1000,
            10000,
            EvictionPolicy::LRU,
            CacheConfig::default(),
        );

        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![0.0], &[]).expect("test: tensor creation should succeed");

        cache.put(0, (features.clone(), labels.clone()));

        let retrieved = cache.get(&0).expect("test: get should succeed");
        assert_eq!(retrieved.0.shape().dims(), features.shape().dims());
        assert_eq!(retrieved.1.shape().dims(), labels.shape().dims());

        let stats = cache.stats();
        assert_eq!(stats.l1_hits, 1);
        assert_eq!(stats.total_requests, 1);
    }

    #[test]
    fn test_smart_cached_dataset() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        let base_dataset = TensorDataset::new(features, labels);
        let cached_dataset = SmartCachedDataset::new(
            base_dataset,
            10,   // L1 size
            100,  // L2 size
            1000, // L3 size
            EvictionPolicy::Adaptive,
            CacheConfig::default(),
        );

        assert_eq!(cached_dataset.len(), 2);

        // First access - cache miss
        let (feat0, _label0) = cached_dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat0.shape().dims(), &[2]);

        // Second access - cache hit
        let (feat0_cached, _) = cached_dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat0_cached.shape().dims(), &[2]);

        let stats = cached_dataset.cache_stats();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.l1_hits, 1);
    }
}