1use crate::Dataset;
7use std::collections::{HashMap, VecDeque};
8use std::hash::Hash;
9use std::sync::{Arc, Mutex, RwLock};
10use std::time::{Duration, Instant};
11use tenflowers_core::{Result, Tensor};
12
/// Access statistics tracked per cache entry and consulted when ranking
/// entries for promotion and eviction.
#[derive(Debug, Clone)]
// `is_sequential` is initialised but never read anywhere in this file.
#[allow(dead_code)]
pub struct AccessPattern {
    /// Instant of the most recently recorded access.
    last_access: Instant,
    /// Number of recorded accesses, including the one that created the pattern.
    access_count: u64,
    /// Smoothed gap between consecutive accesses.
    avg_interval: Duration,
    /// Sequential-access flag; nothing in this file sets it to `true`.
    is_sequential: bool,
    /// Access frequency with exponential time decay applied between accesses.
    frequency_score: f64,
}
28
29impl AccessPattern {
30 fn new() -> Self {
31 Self {
32 last_access: Instant::now(),
33 access_count: 1,
34 avg_interval: Duration::from_secs(0),
35 is_sequential: false,
36 frequency_score: 1.0,
37 }
38 }
39
40 fn update(&mut self, now: Instant) {
41 let interval = now.duration_since(self.last_access);
42
43 if self.access_count > 1 {
45 let alpha = 0.1;
46 self.avg_interval = Duration::from_secs_f64(
47 alpha * interval.as_secs_f64() + (1.0 - alpha) * self.avg_interval.as_secs_f64(),
48 );
49 } else {
50 self.avg_interval = interval;
51 }
52
53 self.last_access = now;
54 self.access_count += 1;
55
56 let time_decay = (-interval.as_secs_f64() / 300.0).exp(); self.frequency_score = self.frequency_score * time_decay + 1.0;
59 }
60
61 fn priority_score(&self) -> f64 {
62 let recency_score = 1.0 / (1.0 + self.last_access.elapsed().as_secs_f64() / 60.0);
63 let frequency_weight = 0.7;
64 let recency_weight = 0.3;
65
66 frequency_weight * self.frequency_score + recency_weight * recency_score
67 }
68}
69
/// Strategy `SmartCache` uses to choose a victim when a level is full.
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Evict keys in the order of the level's access-order queue, oldest first.
    LRU,
    /// Evict the entry with the lowest `AccessPattern::priority_score`.
    LFU,
    /// Currently handled exactly like `LFU`.
    Adaptive,
    /// Time-based eviction that targets expired entries first.
    /// NOTE(review): confirm how `SmartCache` applies the carried `Duration`.
    TimeBasedTTL(Duration),
    /// Currently handled exactly like `LFU`.
    Hybrid,
}
84
/// Tier of the three-level cache. `SmartCache::get` consults them in declaration
/// order. In this file all three tiers are backed by in-memory `HashMap`s.
#[derive(Debug, Clone)]
pub enum CacheLevel {
    /// First tier, checked before the others.
    L1Memory,
    /// Second tier; hot entries can be promoted from here into `L1Memory`.
    L2Storage,
    /// Third tier; hot entries can be promoted from here into `L2Storage`.
    L3Remote,
}
95
96#[derive(Debug, Clone)]
98#[allow(dead_code)]
99struct CacheEntry<T> {
100 data: (Tensor<T>, Tensor<T>),
101 pattern: AccessPattern,
102 size: usize,
103 level: CacheLevel,
104 compressed: bool,
105 ttl: Option<Instant>,
106}
107
108impl<T> CacheEntry<T> {
109 fn new(data: (Tensor<T>, Tensor<T>), level: CacheLevel) -> Self {
110 let size = data.0.shape().size() + data.1.shape().size();
111 Self {
112 data,
113 pattern: AccessPattern::new(),
114 size,
115 level,
116 compressed: false,
117 ttl: None,
118 }
119 }
120
121 fn is_expired(&self) -> bool {
122 if let Some(ttl) = self.ttl {
123 Instant::now() > ttl
124 } else {
125 false
126 }
127 }
128}
129
130pub struct SmartCache<T, K>
132where
133 K: Eq + Hash + Clone,
134{
135 l1_cache: Arc<RwLock<HashMap<K, CacheEntry<T>>>>,
137 l2_cache: Arc<RwLock<HashMap<K, CacheEntry<T>>>>,
139 l3_cache: Arc<RwLock<HashMap<K, CacheEntry<T>>>>,
141
142 l1_max_size: usize,
144 l2_max_size: usize,
145 l3_max_size: usize,
146
147 l1_current_size: Arc<Mutex<usize>>,
149 l2_current_size: Arc<Mutex<usize>>,
150 l3_current_size: Arc<Mutex<usize>>,
151
152 policy: EvictionPolicy,
154
155 l1_access_order: Arc<Mutex<VecDeque<K>>>,
157 l2_access_order: Arc<Mutex<VecDeque<K>>>,
158 l3_access_order: Arc<Mutex<VecDeque<K>>>,
159
160 stats: Arc<Mutex<CacheStats>>,
162
163 config: CacheConfig,
165}
166
/// Tuning parameters for `SmartCache`.
#[derive(Debug, Clone)]
// Only `default_ttl` and `promotion_threshold` are read by the cache in this file;
// the remaining knobs are currently unused.
#[allow(dead_code)]
pub struct CacheConfig {
    /// Whether entries should be compressed (currently unused).
    enable_compression: bool,
    /// Lifetime applied to entries when they are inserted into a level.
    default_ttl: Option<Duration>,
    /// Priority score above which an L2/L3 hit is promoted one level up.
    promotion_threshold: f64,
    /// Score below which an entry would be demoted (currently unused).
    demotion_threshold: f64,
    /// Fill ratio regarded as memory pressure (currently unused).
    memory_pressure_threshold: f64,
    /// Intended period between expiry sweeps (currently unused).
    cleanup_interval: Duration,
}
184
185impl Default for CacheConfig {
186 fn default() -> Self {
187 Self {
188 enable_compression: true,
189 default_ttl: Some(Duration::from_secs(3600)), promotion_threshold: 3.0,
191 demotion_threshold: 0.5,
192 memory_pressure_threshold: 0.8,
193 cleanup_interval: Duration::from_secs(60),
194 }
195 }
196}
197
/// Counters describing how a `SmartCache` has been used.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Lookups satisfied by L1.
    pub l1_hits: u64,
    /// Lookups satisfied by L2.
    pub l2_hits: u64,
    /// Lookups satisfied by L3.
    pub l3_hits: u64,
    /// Lookups that found nothing usable at any level.
    pub misses: u64,
    /// Entries removed to make room or because they expired.
    pub evictions: u64,
    /// Entries moved one level up after becoming hot.
    pub promotions: u64,
    /// Entries moved one level down (not incremented anywhere in this file).
    pub demotions: u64,
    /// Total number of `get` calls.
    pub total_requests: u64,
    /// Exponential moving average of hit latency.
    pub avg_access_time: Duration,
}
211
212impl Default for CacheStats {
213 fn default() -> Self {
214 Self {
215 l1_hits: 0,
216 l2_hits: 0,
217 l3_hits: 0,
218 misses: 0,
219 evictions: 0,
220 promotions: 0,
221 demotions: 0,
222 total_requests: 0,
223 avg_access_time: Duration::from_secs(0),
224 }
225 }
226}
227
228impl<T, K> SmartCache<T, K>
229where
230 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
231 K: Eq + Hash + Clone + Send + Sync,
232{
233 pub fn new(
235 l1_max_size: usize,
236 l2_max_size: usize,
237 l3_max_size: usize,
238 policy: EvictionPolicy,
239 config: CacheConfig,
240 ) -> Self {
241 Self {
242 l1_cache: Arc::new(RwLock::new(HashMap::new())),
243 l2_cache: Arc::new(RwLock::new(HashMap::new())),
244 l3_cache: Arc::new(RwLock::new(HashMap::new())),
245 l1_max_size,
246 l2_max_size,
247 l3_max_size,
248 l1_current_size: Arc::new(Mutex::new(0)),
249 l2_current_size: Arc::new(Mutex::new(0)),
250 l3_current_size: Arc::new(Mutex::new(0)),
251 policy,
252 l1_access_order: Arc::new(Mutex::new(VecDeque::new())),
253 l2_access_order: Arc::new(Mutex::new(VecDeque::new())),
254 l3_access_order: Arc::new(Mutex::new(VecDeque::new())),
255 stats: Arc::new(Mutex::new(CacheStats::default())),
256 config,
257 }
258 }
259
260 pub fn get(&self, key: &K) -> Option<(Tensor<T>, Tensor<T>)> {
262 let start_time = Instant::now();
263 let mut stats = self.stats.lock().expect("lock should not be poisoned");
264 stats.total_requests += 1;
265 drop(stats);
266
267 if let Some(mut entry) = self.get_from_level(key, CacheLevel::L1Memory) {
269 entry.pattern.update(Instant::now());
270 self.update_stats_hit(CacheLevel::L1Memory, start_time);
271 return Some(entry.data);
272 }
273
274 if let Some(mut entry) = self.get_from_level(key, CacheLevel::L2Storage) {
276 entry.pattern.update(Instant::now());
277
278 if entry.pattern.priority_score() > self.config.promotion_threshold {
280 self.promote_entry(key.clone(), entry.clone(), CacheLevel::L1Memory);
281 }
282
283 self.update_stats_hit(CacheLevel::L2Storage, start_time);
284 return Some(entry.data);
285 }
286
287 if let Some(mut entry) = self.get_from_level(key, CacheLevel::L3Remote) {
289 entry.pattern.update(Instant::now());
290
291 if entry.pattern.priority_score() > self.config.promotion_threshold {
293 self.promote_entry(key.clone(), entry.clone(), CacheLevel::L2Storage);
294 }
295
296 self.update_stats_hit(CacheLevel::L3Remote, start_time);
297 return Some(entry.data);
298 }
299
300 let mut stats = self.stats.lock().expect("lock should not be poisoned");
302 stats.misses += 1;
303 None
304 }
305
306 pub fn put(&self, key: K, value: (Tensor<T>, Tensor<T>)) {
308 let entry = CacheEntry::new(value, CacheLevel::L1Memory);
309
310 if self.try_insert_at_level(key.clone(), entry.clone(), CacheLevel::L1Memory) {
312 return;
313 }
314
315 if self.try_insert_at_level(key.clone(), entry.clone(), CacheLevel::L2Storage) {
317 return;
318 }
319
320 self.try_insert_at_level(key, entry, CacheLevel::L3Remote);
322 }
323
324 fn get_from_level(&self, key: &K, level: CacheLevel) -> Option<CacheEntry<T>> {
325 let cache = match level {
326 CacheLevel::L1Memory => &self.l1_cache,
327 CacheLevel::L2Storage => &self.l2_cache,
328 CacheLevel::L3Remote => &self.l3_cache,
329 };
330
331 let cache_read = cache.read().expect("read lock should not be poisoned");
332 cache_read.get(key).and_then(|entry| {
333 if entry.is_expired() {
334 None
335 } else {
336 Some(entry.clone())
337 }
338 })
339 }
340
341 fn try_insert_at_level(&self, key: K, mut entry: CacheEntry<T>, level: CacheLevel) -> bool {
342 let (cache, current_size, max_size, access_order) = match level {
343 CacheLevel::L1Memory => (
344 &self.l1_cache,
345 &self.l1_current_size,
346 self.l1_max_size,
347 &self.l1_access_order,
348 ),
349 CacheLevel::L2Storage => (
350 &self.l2_cache,
351 &self.l2_current_size,
352 self.l2_max_size,
353 &self.l2_access_order,
354 ),
355 CacheLevel::L3Remote => (
356 &self.l3_cache,
357 &self.l3_current_size,
358 self.l3_max_size,
359 &self.l3_access_order,
360 ),
361 };
362
363 entry.level = level.clone();
364 if let Some(ttl) = self.config.default_ttl {
365 entry.ttl = Some(Instant::now() + ttl);
366 }
367
368 let mut size_guard = current_size.lock().expect("lock should not be poisoned");
369
370 while *size_guard + entry.size > max_size {
372 if !self.evict_from_level(level.clone()) {
373 return false; }
375 *size_guard = current_size
376 .lock()
377 .expect("lock should not be poisoned")
378 .saturating_sub(entry.size);
379 }
380
381 let mut cache_write = cache.write().expect("write lock should not be poisoned");
383 cache_write.insert(key.clone(), entry.clone());
384 *size_guard += entry.size;
385
386 let mut order = access_order.lock().expect("lock should not be poisoned");
388 order.push_back(key);
389
390 true
391 }
392
393 fn evict_from_level(&self, level: CacheLevel) -> bool {
394 let (cache, current_size, access_order) = match level {
395 CacheLevel::L1Memory => (&self.l1_cache, &self.l1_current_size, &self.l1_access_order),
396 CacheLevel::L2Storage => (&self.l2_cache, &self.l2_current_size, &self.l2_access_order),
397 CacheLevel::L3Remote => (&self.l3_cache, &self.l3_current_size, &self.l3_access_order),
398 };
399
400 let victim_key = match self.policy {
401 EvictionPolicy::LRU => {
402 let mut order = access_order.lock().expect("lock should not be poisoned");
403 order.pop_front()
404 }
405 EvictionPolicy::LFU | EvictionPolicy::Adaptive | EvictionPolicy::Hybrid => {
406 self.find_lfu_victim(cache)
407 }
408 EvictionPolicy::TimeBasedTTL(_) => self.find_expired_victim(cache),
409 };
410
411 if let Some(key) = victim_key {
412 let mut cache_write = cache.write().expect("write lock should not be poisoned");
413 if let Some(entry) = cache_write.remove(&key) {
414 let mut size_guard = current_size.lock().expect("lock should not be poisoned");
415 *size_guard = size_guard.saturating_sub(entry.size);
416
417 let mut stats = self.stats.lock().expect("lock should not be poisoned");
418 stats.evictions += 1;
419
420 return true;
421 }
422 }
423
424 false
425 }
426
427 fn find_lfu_victim(&self, cache: &Arc<RwLock<HashMap<K, CacheEntry<T>>>>) -> Option<K> {
428 let cache_read = cache.read().expect("read lock should not be poisoned");
429 cache_read
430 .iter()
431 .min_by(|(_, a), (_, b)| {
432 a.pattern
433 .priority_score()
434 .partial_cmp(&b.pattern.priority_score())
435 .unwrap_or(std::cmp::Ordering::Equal)
436 })
437 .map(|(k, _)| k.clone())
438 }
439
440 fn find_expired_victim(&self, cache: &Arc<RwLock<HashMap<K, CacheEntry<T>>>>) -> Option<K> {
441 let cache_read = cache.read().expect("read lock should not be poisoned");
442 cache_read
443 .iter()
444 .find(|(_, entry)| entry.is_expired())
445 .map(|(k, _)| k.clone())
446 }
447
448 fn promote_entry(&self, key: K, entry: CacheEntry<T>, target_level: CacheLevel) {
449 let original_level = entry.level.clone();
450 if self.try_insert_at_level(key.clone(), entry, target_level) {
451 match original_level {
453 CacheLevel::L3Remote => {
454 let mut cache = self
455 .l3_cache
456 .write()
457 .expect("write lock should not be poisoned");
458 cache.remove(&key);
459 }
460 CacheLevel::L2Storage => {
461 let mut cache = self
462 .l2_cache
463 .write()
464 .expect("write lock should not be poisoned");
465 cache.remove(&key);
466 }
467 _ => {}
468 }
469
470 let mut stats = self.stats.lock().expect("lock should not be poisoned");
471 stats.promotions += 1;
472 }
473 }
474
475 fn update_stats_hit(&self, level: CacheLevel, start_time: Instant) {
476 let mut stats = self.stats.lock().expect("lock should not be poisoned");
477 match level {
478 CacheLevel::L1Memory => stats.l1_hits += 1,
479 CacheLevel::L2Storage => stats.l2_hits += 1,
480 CacheLevel::L3Remote => stats.l3_hits += 1,
481 }
482
483 let access_time = start_time.elapsed();
484 let alpha = 0.1;
485 stats.avg_access_time = Duration::from_secs_f64(
486 alpha * access_time.as_secs_f64() + (1.0 - alpha) * stats.avg_access_time.as_secs_f64(),
487 );
488 }
489
490 pub fn stats(&self) -> CacheStats {
492 self.stats
493 .lock()
494 .expect("lock should not be poisoned")
495 .clone()
496 }
497
498 pub fn clear(&self) {
500 let mut l1 = self
501 .l1_cache
502 .write()
503 .expect("write lock should not be poisoned");
504 let mut l2 = self
505 .l2_cache
506 .write()
507 .expect("write lock should not be poisoned");
508 let mut l3 = self
509 .l3_cache
510 .write()
511 .expect("write lock should not be poisoned");
512
513 l1.clear();
514 l2.clear();
515 l3.clear();
516
517 *self
518 .l1_current_size
519 .lock()
520 .expect("lock should not be poisoned") = 0;
521 *self
522 .l2_current_size
523 .lock()
524 .expect("lock should not be poisoned") = 0;
525 *self
526 .l3_current_size
527 .lock()
528 .expect("lock should not be poisoned") = 0;
529 }
530
531 pub fn cleanup_expired(&self) {
533 for level in [
534 CacheLevel::L1Memory,
535 CacheLevel::L2Storage,
536 CacheLevel::L3Remote,
537 ] {
538 while self
539 .find_expired_victim(match level {
540 CacheLevel::L1Memory => &self.l1_cache,
541 CacheLevel::L2Storage => &self.l2_cache,
542 CacheLevel::L3Remote => &self.l3_cache,
543 })
544 .is_some()
545 {
546 self.evict_from_level(level.clone());
547 }
548 }
549 }
550}
551
552pub struct SmartCachedDataset<T, D: Dataset<T>> {
554 dataset: D,
555 cache: Arc<SmartCache<T, usize>>,
556}
557
558impl<T, D: Dataset<T>> SmartCachedDataset<T, D>
559where
560 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
561{
562 pub fn new(
564 dataset: D,
565 l1_size: usize,
566 l2_size: usize,
567 l3_size: usize,
568 policy: EvictionPolicy,
569 config: CacheConfig,
570 ) -> Self {
571 let cache = Arc::new(SmartCache::new(l1_size, l2_size, l3_size, policy, config));
572
573 Self { dataset, cache }
574 }
575
576 pub fn cache_stats(&self) -> CacheStats {
578 self.cache.stats()
579 }
580
581 pub fn clear_cache(&self) {
583 self.cache.clear();
584 }
585}
586
587impl<T, D: Dataset<T>> Dataset<T> for SmartCachedDataset<T, D>
588where
589 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
590{
591 fn len(&self) -> usize {
592 self.dataset.len()
593 }
594
595 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
596 if let Some(cached) = self.cache.get(&index) {
598 return Ok(cached);
599 }
600
601 let sample = self.dataset.get(index)?;
603
604 self.cache.put(index, sample.clone());
606
607 Ok(sample)
608 }
609}
610
/// Learns recurring key sequences from recent accesses and predicts which keys are
/// likely to be requested next.
#[derive(Debug, Clone)]
pub struct AccessPatternPredictor<K>
where
    K: Eq + Hash + Clone + Send + Sync,
{
    /// Most recent accesses, oldest at the front, each with its timestamp.
    access_history: VecDeque<(K, Instant)>,
    /// Observed key sequences mapped to their confidence weight.
    sequence_patterns: HashMap<Vec<K>, f64>,
    /// Maximum number of accesses kept in `access_history`.
    max_history_size: usize,
    /// Shortest sequence length that is tracked.
    min_pattern_length: usize,
    /// Longest sequence length that is tracked.
    max_pattern_length: usize,
}
628
629impl<K> AccessPatternPredictor<K>
630where
631 K: Eq + Hash + Clone + Send + Sync,
632{
633 pub fn new() -> Self {
634 Self {
635 access_history: VecDeque::with_capacity(1000),
636 sequence_patterns: HashMap::new(),
637 max_history_size: 1000,
638 min_pattern_length: 2,
639 max_pattern_length: 5,
640 }
641 }
642
643 pub fn record_access(&mut self, key: K) {
645 let now = Instant::now();
646
647 self.access_history.push_back((key.clone(), now));
649
650 if self.access_history.len() > self.max_history_size {
652 self.access_history.pop_front();
653 }
654
655 self.update_patterns();
657 }
658
659 pub fn predict_next_accesses(&self, current_key: &K, max_predictions: usize) -> Vec<(K, f64)> {
661 let mut predictions = Vec::new();
662
663 for pattern_len in self.min_pattern_length..=self.max_pattern_length {
665 if let Some(recent_sequence) = self.get_recent_sequence(pattern_len) {
666 if recent_sequence.last() == Some(current_key) {
667 for (pattern, confidence) in &self.sequence_patterns {
669 if pattern.len() > pattern_len
670 && pattern[..pattern_len] == recent_sequence[..]
671 {
672 let next_key = &pattern[pattern_len];
673 predictions.push((next_key.clone(), *confidence));
674 }
675 }
676 }
677 }
678 }
679
680 predictions.sort_by(|a, b| {
682 b.1.partial_cmp(&a.1)
683 .expect("partial_cmp should not return None for valid values")
684 });
685 predictions.truncate(max_predictions);
686 predictions
687 }
688
689 fn get_recent_sequence(&self, length: usize) -> Option<Vec<K>> {
691 if self.access_history.len() < length {
692 return None;
693 }
694
695 let recent: Vec<K> = self
696 .access_history
697 .iter()
698 .rev()
699 .take(length)
700 .map(|(k, _)| k.clone())
701 .collect::<Vec<_>>()
702 .into_iter()
703 .rev()
704 .collect();
705
706 Some(recent)
707 }
708
709 fn update_patterns(&mut self) {
711 let history_keys: Vec<K> = self.access_history.iter().map(|(k, _)| k.clone()).collect();
712
713 for pattern_len in self.min_pattern_length..=self.max_pattern_length {
715 if history_keys.len() >= pattern_len {
716 for i in 0..=(history_keys.len() - pattern_len) {
717 let pattern = history_keys[i..i + pattern_len].to_vec();
718
719 let age_factor = 1.0 - (i as f64 / history_keys.len() as f64 * 0.1);
721
722 *self.sequence_patterns.entry(pattern).or_insert(0.0) += age_factor;
723 }
724 }
725 }
726
727 for confidence in self.sequence_patterns.values_mut() {
729 *confidence *= 0.99; }
731
732 self.sequence_patterns
734 .retain(|_, confidence| *confidence > 0.1);
735 }
736}
737
738impl<K> Default for AccessPatternPredictor<K>
739where
740 K: Eq + Hash + Clone + Send + Sync,
741{
742 fn default() -> Self {
743 Self::new()
744 }
745}
746
747pub struct PredictiveSmartCache<T, K>
749where
750 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
751 K: Eq + Hash + Clone + Send + Sync,
752{
753 base_cache: SmartCache<T, K>,
755 predictor: Arc<Mutex<AccessPatternPredictor<K>>>,
757 dataset: Option<Arc<dyn Dataset<T>>>,
759 prefetch_queue: Arc<Mutex<VecDeque<K>>>,
761 max_prefetch_size: usize,
763}
764
765impl<T, K> PredictiveSmartCache<T, K>
766where
767 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
768 K: Eq + Hash + Clone + Send + Sync,
769{
770 pub fn new(
771 l1_max_size: usize,
772 l2_max_size: usize,
773 l3_max_size: usize,
774 policy: EvictionPolicy,
775 config: CacheConfig,
776 max_prefetch_size: usize,
777 ) -> Self {
778 Self {
779 base_cache: SmartCache::new(l1_max_size, l2_max_size, l3_max_size, policy, config),
780 predictor: Arc::new(Mutex::new(AccessPatternPredictor::new())),
781 dataset: None,
782 prefetch_queue: Arc::new(Mutex::new(VecDeque::with_capacity(max_prefetch_size))),
783 max_prefetch_size,
784 }
785 }
786
787 pub fn set_dataset(&mut self, dataset: Arc<dyn Dataset<T>>) {
789 self.dataset = Some(dataset);
790 }
791
792 pub fn get(&self, key: &K) -> Option<(Tensor<T>, Tensor<T>)> {
794 {
796 let mut predictor = self.predictor.lock().expect("lock should not be poisoned");
797 predictor.record_access(key.clone());
798 }
799
800 if let Some(result) = self.base_cache.get(key) {
802 self.trigger_prefetch(key);
804 return Some(result);
805 }
806
807 if let Some(ref dataset) = self.dataset {
809 if let Some(data) = self.load_from_dataset(dataset, key) {
812 self.base_cache.put(key.clone(), data.clone());
813 self.trigger_prefetch(key);
814 return Some(data);
815 }
816 }
817
818 None
819 }
820
821 pub fn put(&self, key: K, data: (Tensor<T>, Tensor<T>)) {
823 self.base_cache.put(key, data);
824 }
825
826 pub fn stats(&self) -> CacheStats {
828 self.base_cache.stats()
829 }
830
831 fn trigger_prefetch(&self, current_key: &K) {
833 let predictions = {
834 let predictor = self.predictor.lock().expect("lock should not be poisoned");
835 predictor.predict_next_accesses(current_key, 3) };
837
838 let mut prefetch_queue = self
839 .prefetch_queue
840 .lock()
841 .expect("lock should not be poisoned");
842
843 for (predicted_key, confidence) in predictions {
844 if confidence > 0.5 && self.base_cache.get(&predicted_key).is_none() {
846 prefetch_queue.push_back(predicted_key);
847
848 if prefetch_queue.len() > self.max_prefetch_size {
850 prefetch_queue.pop_front();
851 }
852 }
853 }
854 }
855
856 fn load_from_dataset(
858 &self,
859 _dataset: &Arc<dyn Dataset<T>>,
860 _key: &K,
861 ) -> Option<(Tensor<T>, Tensor<T>)> {
862 None
867 }
868
869 pub fn process_prefetch_queue(&self) {
871 if let Some(ref dataset) = self.dataset {
872 let mut prefetch_queue = self
873 .prefetch_queue
874 .lock()
875 .expect("lock should not be poisoned");
876
877 for _ in 0..3 {
879 if let Some(key) = prefetch_queue.pop_front() {
880 if self.base_cache.get(&key).is_none() {
882 if let Some(data) = self.load_from_dataset(dataset, &key) {
883 self.base_cache.put(key, data);
884 }
885 }
886 }
887 }
888 }
889 }
890}
891
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    /// LRU-policy cache with the capacities shared by the `SmartCache` tests.
    fn new_lru_cache() -> SmartCache<f32, usize> {
        SmartCache::new(100, 1000, 10000, EvictionPolicy::LRU, CacheConfig::default())
    }

    #[test]
    fn test_smart_cache_creation() {
        let fresh = new_lru_cache().stats();
        assert_eq!(fresh.total_requests, 0);
        assert_eq!(fresh.l1_hits, 0);
    }

    #[test]
    fn test_smart_cache_put_get() {
        let cache = new_lru_cache();

        let inputs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2])
            .expect("test: tensor creation should succeed");
        let target =
            Tensor::<f32>::from_vec(vec![0.0], &[]).expect("test: tensor creation should succeed");

        cache.put(0, (inputs.clone(), target.clone()));

        let (got_inputs, got_target) = cache.get(&0).expect("test: get should succeed");
        assert_eq!(got_inputs.shape().dims(), inputs.shape().dims());
        assert_eq!(got_target.shape().dims(), target.shape().dims());

        let snapshot = cache.stats();
        assert_eq!(snapshot.l1_hits, 1);
        assert_eq!(snapshot.total_requests, 1);
    }

    #[test]
    fn test_smart_cached_dataset() {
        let inputs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let targets = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        let cached = SmartCachedDataset::new(
            TensorDataset::new(inputs, targets),
            10,
            100,
            1000,
            EvictionPolicy::Adaptive,
            CacheConfig::default(),
        );

        assert_eq!(cached.len(), 2);

        // First read misses and populates the cache; the second is served from L1.
        for _ in 0..2 {
            let (sample, _) = cached.get(0).expect("index should be in bounds");
            assert_eq!(sample.shape().dims(), &[2]);
        }

        let snapshot = cached.cache_stats();
        assert_eq!(snapshot.total_requests, 2);
        assert_eq!(snapshot.l1_hits, 1);
    }
}