// tenflowers_dataset/cache/persistent.rs

1//! Persistent caching implementations
2//!
3//! This module provides disk-based caching that persists across program runs.
4
5use crate::cache::dataset::CacheStats;
6use crate::Dataset;
7use std::collections::HashMap;
8use std::fs::{create_dir_all, File};
9use std::hash::Hash;
10use std::io::{BufReader, BufWriter};
11use std::marker::PhantomData;
12use std::path::Path;
13use std::sync::{Arc, Mutex};
14use tenflowers_core::{Result, Tensor, TensorError};
15
#[cfg(feature = "serialize")]
/// Persistent cache that stores data on disk with LRU eviction
///
/// Each value lives in its own file under `cache_dir`; an in-memory `index`
/// maps keys to `(filename, access_order)` so the least-recently-used entry
/// can be located for eviction. The index is mirrored to `cache_index.json`
/// so cached entries survive program restarts.
pub struct PersistentCache<K, V> {
    /// Directory holding one file per cached value plus `cache_index.json`.
    cache_dir: std::path::PathBuf,
    /// Maximum number of entries before LRU eviction kicks in.
    capacity: usize,
    /// Per-key on-disk filename and the access counter at last touch.
    index: HashMap<K, (String, usize)>, // (filename, access_order)
    /// Monotonically increasing counter used to order accesses for LRU.
    access_counter: usize,
    /// `V` is only handled through (de)serialization, never stored inline.
    _phantom: PhantomData<V>,
}
25
#[cfg(feature = "serialize")]
impl<K, V> PersistentCache<K, V>
where
    K: Clone + Eq + Hash + std::fmt::Display + std::str::FromStr,
    V: Clone + serde::Serialize + for<'de> serde::Deserialize<'de>,
{
    /// Create a new persistent cache with the specified directory and capacity
    ///
    /// The directory is created if it does not exist, and any
    /// `cache_index.json` left behind by a previous run is reloaded so
    /// existing entries stay reachable.
    ///
    /// # Errors
    /// Fails if the directory cannot be created or an existing index file
    /// cannot be opened or parsed.
    pub fn new<P: AsRef<Path>>(cache_dir: P, capacity: usize) -> Result<Self> {
        let cache_dir = cache_dir.as_ref().to_path_buf();

        // Create cache directory if it doesn't exist
        if !cache_dir.exists() {
            create_dir_all(&cache_dir).map_err(|e| {
                TensorError::invalid_argument(format!("Failed to create cache directory: {e}"))
            })?;
        }

        let mut cache = Self {
            cache_dir,
            capacity,
            index: HashMap::new(),
            access_counter: 0,
            _phantom: PhantomData,
        };

        // Load existing cache index
        cache.load_index()?;

        Ok(cache)
    }

    /// Load cache index from disk
    ///
    /// A missing index file is not an error — the cache simply starts empty.
    /// Keys that fail to parse back through `FromStr` are silently dropped,
    /// which leaves their data files orphaned on disk.
    fn load_index(&mut self) -> Result<()> {
        let index_path = self.cache_dir.join("cache_index.json");

        if !index_path.exists() {
            return Ok(()); // No existing index
        }

        let file = File::open(&index_path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to open cache index: {e}"))
        })?;

        let reader = BufReader::new(file);

        // Simple JSON format: {"key": {"filename": "...", "access_order": 123}, ...}
        let index_data: HashMap<String, (String, usize)> = serde_json::from_reader(reader)
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to parse cache index: {e}"))
            })?;

        // Convert string keys back to original type (simplified approach)
        for (key_str, (filename, access_order)) in index_data {
            if let Ok(key) = key_str.parse::<K>() {
                self.index.insert(key, (filename, access_order));
                // Track the highest persisted access order so new accesses sort after it.
                self.access_counter = self.access_counter.max(access_order);
            }
        }

        self.access_counter += 1; // Ensure next access has higher number

        Ok(())
    }

    /// Save cache index to disk
    ///
    /// Rewrites the whole `cache_index.json` file; called after every insert
    /// and clear (but not after `get`, see there).
    fn save_index(&self) -> Result<()> {
        let index_path = self.cache_dir.join("cache_index.json");

        let file = File::create(&index_path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to create cache index: {e}"))
        })?;

        let writer = BufWriter::new(file);

        // Convert to string keys for JSON serialization
        let index_data: HashMap<String, (String, usize)> = self
            .index
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();

        serde_json::to_writer(writer, &index_data).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to save cache index: {e}"))
        })?;

        Ok(())
    }

    /// Get a value from the cache
    ///
    /// Returns `Ok(None)` if the key is absent or its backing file has
    /// disappeared (the stale index entry is dropped in that case).
    ///
    /// Note: the refreshed LRU access order is only kept in memory here; it is
    /// persisted the next time `save_index` runs (insert/clear).
    pub fn get(&mut self, key: &K) -> Result<Option<V>> {
        if let Some((filename, access_time)) = self.index.get_mut(key) {
            // Update access time
            self.access_counter += 1;
            *access_time = self.access_counter;

            // Load value from disk
            let file_path = self.cache_dir.join(filename);

            if !file_path.exists() {
                // File was deleted, remove from index
                self.index.remove(key);
                return Ok(None);
            }

            let file = File::open(&file_path).map_err(|e| {
                TensorError::invalid_argument(format!("Failed to open cache file: {e}"))
            })?;

            let reader = BufReader::new(file);

            let value: V =
                oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
                    .map_err(|e| {
                        TensorError::invalid_argument(format!(
                            "Failed to deserialize cached value: {e}"
                        ))
                    })?
                    .0;

            Ok(Some(value))
        } else {
            Ok(None)
        }
    }

    /// Insert a value into the cache
    ///
    /// Evicts the LRU entry first when the cache is full and `key` is new.
    /// The value is written to a fresh file named from the key's `Display`
    /// output and the access counter; replacing an existing key removes its
    /// old file.
    // NOTE(review): assumes `K::to_string()` yields filesystem-safe text
    // (no path separators) — verify for the key types used in practice.
    pub fn insert(&mut self, key: K, value: V) -> Result<()> {
        self.access_counter += 1;

        // Check if we need to evict items
        if self.index.len() >= self.capacity && !self.index.contains_key(&key) {
            self.evict_lru()?;
        }

        // Generate filename for this entry
        let filename = format!("cache_{}_{}.bin", key, self.access_counter);
        let file_path = self.cache_dir.join(&filename);

        // Serialize and save to disk
        let file = File::create(&file_path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to create cache file: {e}"))
        })?;

        let writer = BufWriter::new(file);

        oxicode::serde::encode_into_std_write(&value, writer, oxicode::config::standard())
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to serialize value: {e}"))
            })?;

        // Update index
        if let Some((old_filename, _)) = self.index.insert(key, (filename, self.access_counter)) {
            // Remove old file if it exists
            let old_path = self.cache_dir.join(old_filename);
            let _ = std::fs::remove_file(old_path); // Ignore errors
        }

        // Save updated index
        self.save_index()?;

        Ok(())
    }

    /// Evict the least recently used item
    ///
    /// Linear scan over the index for the smallest access order; the entry's
    /// file removal is best-effort (errors ignored).
    fn evict_lru(&mut self) -> Result<()> {
        if let Some((lru_key, (filename, _))) = self
            .index
            .iter()
            .min_by_key(|(_, (_, access_time))| *access_time)
            .map(|(k, v)| (k.clone(), v.clone()))
        {
            // Remove file
            let file_path = self.cache_dir.join(&filename);
            let _ = std::fs::remove_file(file_path); // Ignore errors

            // Remove from index
            self.index.remove(&lru_key);
        }

        Ok(())
    }

    /// Get current cache size (number of indexed entries)
    pub fn len(&self) -> usize {
        self.index.len()
    }

    /// Check if cache is empty
    pub fn is_empty(&self) -> bool {
        self.index.is_empty()
    }

    /// Clear all cached items
    ///
    /// Best-effort deletes every entry file, resets the access counter, and
    /// persists the now-empty index.
    pub fn clear(&mut self) -> Result<()> {
        // Remove all cache files
        for (filename, _) in self.index.values() {
            let file_path = self.cache_dir.join(filename);
            let _ = std::fs::remove_file(file_path); // Ignore errors
        }

        // Clear index
        self.index.clear();
        self.access_counter = 0;

        // Save empty index
        self.save_index()?;

        Ok(())
    }

    /// Get cache capacity
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Get cache directory
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }
}
245
#[cfg(feature = "serialize")]
/// Persistent cache that works with byte arrays for tensor data
///
/// Stores `(features, labels)` sample pairs as serialized byte blobs inside a
/// [`PersistentCache`], keyed by sample index.
pub struct TensorPersistentCache {
    /// Underlying disk cache; each entry holds the serialized bytes of a
    /// (features, labels) tensor pair.
    cache: PersistentCache<usize, (Vec<u8>, Vec<u8>)>, // Serialized tensor data
}
251
#[cfg(feature = "serialize")]
impl TensorPersistentCache {
    /// Create a new tensor persistent cache rooted at `cache_dir`, holding at
    /// most `capacity` (features, labels) pairs.
    pub fn new<P: AsRef<Path>>(cache_dir: P, capacity: usize) -> Result<Self> {
        Ok(Self {
            cache: PersistentCache::new(cache_dir, capacity)?,
        })
    }

    /// Get tensors from cache
    ///
    /// Returns `Ok(None)` on a cache miss; on a hit, decodes the stored byte
    /// blobs back into `(features, labels)` tensors.
    pub fn get<T>(&mut self, index: &usize) -> Result<Option<(Tensor<T>, Tensor<T>)>>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + Send
            + Sync
            + 'static
            + scirs2_core::num_traits::cast::NumCast,
    {
        if let Some((features_bytes, labels_bytes)) = self.cache.get(index)? {
            // Deserialize tensors from byte arrays
            let features_tensor = Self::deserialize_tensor(&features_bytes)?;
            let labels_tensor = Self::deserialize_tensor(&labels_bytes)?;
            Ok(Some((features_tensor, labels_tensor)))
        } else {
            Ok(None)
        }
    }

    /// Insert tensors into cache
    ///
    /// Serializes both tensors to byte blobs and stores them under `index` in
    /// the persistent cache.
    ///
    /// # Errors
    /// Fails if either tensor has no CPU-resident data or the disk write fails.
    pub fn insert<T>(
        &mut self,
        index: usize,
        features: &Tensor<T>,
        labels: &Tensor<T>,
    ) -> Result<()>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + Send
            + Sync
            + 'static
            + scirs2_core::num_traits::cast::NumCast,
    {
        // Serialize tensors to byte arrays
        let features_bytes = Self::serialize_tensor(features)?;
        let labels_bytes = Self::serialize_tensor(labels)?;

        // Store in persistent cache
        self.cache.insert(index, (features_bytes, labels_bytes))?;
        Ok(())
    }

    /// Clear cache
    pub fn clear(&mut self) -> Result<()> {
        self.cache.clear()
    }

    /// Serialize a tensor to bytes
    /// Format: [type_id: u8][shape_len: u32 LE][shape: u32 LE...][data: raw T bytes]
    ///
    /// `type_id` is simply `size_of::<T>()`; `deserialize_tensor` uses it to
    /// reject payloads decoded with a differently sized element type. Element
    /// bytes are native byte images of `T`, so the format is only portable
    /// between machines of the same endianness, and only meaningful for
    /// plain-old-data element types.
    ///
    /// # Errors
    /// Fails for tensors whose data is not reachable as a CPU slice, or whose
    /// dimensions do not fit in `u32`.
    fn serialize_tensor<T>(tensor: &Tensor<T>) -> Result<Vec<u8>>
    where
        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    {
        let element_size = std::mem::size_of::<T>();
        let shape = tensor.shape().dims();

        // Only CPU-resident tensors expose a contiguous slice we can dump.
        let data_slice = match tensor.as_slice() {
            Some(slice) => slice,
            None => {
                return Err(TensorError::invalid_argument(
                    "Cannot serialize GPU tensors or tensors without CPU data".to_string(),
                ))
            }
        };

        let mut bytes =
            Vec::with_capacity(1 + 4 + shape.len() * 4 + data_slice.len() * element_size);

        // Type tag: element size in bytes (validated on deserialization).
        bytes.push(element_size as u8);

        // Shape: length then each dimension, all little-endian u32. Reject
        // dimensions that would silently truncate in the old `as u32` cast.
        bytes.extend_from_slice(&(shape.len() as u32).to_le_bytes());
        for &dim in shape {
            let dim32 = u32::try_from(dim).map_err(|_| {
                TensorError::invalid_argument(format!(
                    "Cannot serialize tensor: dimension {dim} exceeds u32::MAX"
                ))
            })?;
            bytes.extend_from_slice(&dim32.to_le_bytes());
        }

        // Data: append the element bytes in one shot instead of per element.
        // SAFETY: `data_slice` is a valid `&[T]`, so viewing the same memory
        // as `len * size_of::<T>()` bytes is in bounds, and `u8` has no
        // alignment or validity requirements.
        #[allow(unsafe_code)]
        let raw = unsafe {
            std::slice::from_raw_parts(
                data_slice.as_ptr() as *const u8,
                data_slice.len() * element_size,
            )
        };
        bytes.extend_from_slice(raw);

        Ok(bytes)
    }

    /// Deserialize a tensor previously produced by `serialize_tensor`.
    ///
    /// Elements are read back as raw byte images of `T`, the exact mirror of
    /// serialization, which makes the round trip lossless for every
    /// fixed-size numeric type. (The previous implementation reinterpreted
    /// 4-/8-byte elements as `f32`/`f64` before `NumCast`ing, which corrupted
    /// integer tensors and lost negative `i16` values.)
    ///
    /// # Errors
    /// Fails on truncated input, on a `type_id` that does not match
    /// `size_of::<T>()`, or on size arithmetic overflow.
    fn deserialize_tensor<T>(bytes: &[u8]) -> Result<Tensor<T>>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + Send
            + Sync
            + 'static
            + scirs2_core::num_traits::cast::NumCast,
    {
        // Header is at least type_id (1 byte) + shape_len (4 bytes).
        if bytes.len() < 5 {
            return Err(TensorError::invalid_argument(
                "Invalid tensor serialization: too few bytes".to_string(),
            ));
        }

        let mut offset = 0;

        // Validate the stored element size against the requested `T` instead
        // of ignoring it — decoding with the wrong type must not yield garbage.
        let type_id = bytes[offset];
        offset += 1;
        let element_size = std::mem::size_of::<T>();
        if type_id as usize != element_size {
            return Err(TensorError::invalid_argument(format!(
                "Invalid tensor serialization: stored element size {type_id} does not match requested element size {element_size}"
            )));
        }

        // Read shape length
        let shape_len = u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]) as usize;
        offset += 4;

        // Checked arithmetic: these sizes come from the (possibly corrupt) file.
        let shape_bytes = shape_len.checked_mul(4).ok_or_else(|| {
            TensorError::invalid_argument(
                "Invalid tensor serialization: shape length overflow".to_string(),
            )
        })?;
        if bytes.len() < offset + shape_bytes {
            return Err(TensorError::invalid_argument(
                "Invalid tensor serialization: insufficient bytes for shape".to_string(),
            ));
        }

        // Read shape
        let mut shape = Vec::with_capacity(shape_len);
        for _ in 0..shape_len {
            let dim = u32::from_le_bytes([
                bytes[offset],
                bytes[offset + 1],
                bytes[offset + 2],
                bytes[offset + 3],
            ]) as usize;
            shape.push(dim);
            offset += 4;
        }

        // Calculate expected data size with overflow checks.
        let total_elements = shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .ok_or_else(|| {
                TensorError::invalid_argument(
                    "Invalid tensor serialization: element count overflow".to_string(),
                )
            })?;
        let expected_data_bytes = total_elements.checked_mul(element_size).ok_or_else(|| {
            TensorError::invalid_argument(
                "Invalid tensor serialization: data size overflow".to_string(),
            )
        })?;

        if bytes.len() < offset + expected_data_bytes {
            return Err(TensorError::invalid_argument(
                "Invalid tensor serialization: insufficient bytes for data".to_string(),
            ));
        }

        // Deserialize data: read each element back as a raw image of `T`.
        let mut data = Vec::with_capacity(total_elements);
        if element_size == 0 {
            // Zero-sized elements carry no payload; reconstruct with defaults.
            data.resize(total_elements, T::default());
        } else {
            for i in 0..total_elements {
                let start = offset + i * element_size;
                // SAFETY: the bounds check above guarantees
                // `start + element_size <= bytes.len()`, and `serialize_tensor`
                // wrote these bytes as plain byte images of `T`.
                // `read_unaligned` is used because the buffer carries no
                // alignment guarantee. This assumes `T` is a plain-old-data
                // numeric type — in practice enforced by the `NumCast` bound;
                // non-POD `T` was never serializable by this format.
                #[allow(unsafe_code)]
                let value =
                    unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(start) as *const T) };
                data.push(value);
            }
        }

        // Create tensor from deserialized data
        Tensor::from_vec(data, &shape)
    }
}
490
#[cfg(feature = "serialize")]
/// Dataset wrapper that uses persistent caching with simplified implementation
///
/// Wraps an inner [`Dataset`] and memoizes `(features, labels)` samples on
/// disk via [`TensorPersistentCache`], recording hit/miss statistics along
/// the way. Cache and stats are behind `Arc<Mutex<_>>` so the wrapper can be
/// shared and used through `&self` in the `Dataset` impl.
pub struct PersistentlyCachedDataset<T, D: Dataset<T>> {
    /// The wrapped dataset that produces samples on cache misses.
    dataset: D,
    /// Shared on-disk cache of serialized samples, keyed by sample index.
    cache: Arc<Mutex<TensorPersistentCache>>,
    /// Shared hit/miss/request counters.
    cache_stats: Arc<Mutex<CacheStats>>,
    /// Marks the element type `T` without storing one.
    _phantom: PhantomData<T>,
}
499
500impl<T, D: Dataset<T>> PersistentlyCachedDataset<T, D>
501where
502    T: Clone
503        + Default
504        + scirs2_core::numeric::Zero
505        + Send
506        + Sync
507        + 'static
508        + scirs2_core::num_traits::cast::NumCast,
509{
510    /// Create a new persistently cached dataset
511    pub fn new<P: AsRef<Path>>(dataset: D, cache_dir: P, cache_capacity: usize) -> Result<Self> {
512        let cache = TensorPersistentCache::new(cache_dir, cache_capacity)?;
513
514        Ok(Self {
515            dataset,
516            cache: Arc::new(Mutex::new(cache)),
517            cache_stats: Arc::new(Mutex::new(CacheStats::default())),
518            _phantom: PhantomData,
519        })
520    }
521
522    /// Get cache statistics
523    pub fn cache_stats(&self) -> Result<CacheStats> {
524        match self.cache_stats.lock() {
525            Ok(stats) => Ok(stats.clone()),
526            Err(_) => Err(TensorError::CacheError {
527                operation: "persistent_cache_stats".to_string(),
528                details: "Persistent cache stats mutex poisoned".to_string(),
529                recoverable: true,
530                context: None,
531            }),
532        }
533    }
534
535    /// Clear cache
536    pub fn clear_cache(&self) -> Result<()> {
537        match self.cache.lock() {
538            Ok(mut cache) => cache.clear()?,
539            Err(_) => {
540                return Err(TensorError::CacheError {
541                    operation: "persistent_cache_clear".to_string(),
542                    details: "Persistent cache mutex poisoned during clear".to_string(),
543                    recoverable: false,
544                    context: None,
545                })
546            }
547        }
548
549        match self.cache_stats.lock() {
550            Ok(mut stats) => {
551                *stats = CacheStats::default();
552                Ok(())
553            }
554            Err(_) => Err(TensorError::CacheError {
555                operation: "persistent_cache_clear_stats".to_string(),
556                details: "Persistent cache stats mutex poisoned during clear".to_string(),
557                recoverable: false,
558                context: None,
559            }),
560        }
561    }
562
563    /// Get underlying dataset
564    pub fn into_inner(self) -> D {
565        self.dataset
566    }
567
568    /// Get reference to underlying dataset
569    pub fn inner(&self) -> &D {
570        &self.dataset
571    }
572}
573
impl<T, D: Dataset<T>> Dataset<T> for PersistentlyCachedDataset<T, D>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::cast::NumCast,
{
    /// Number of samples, delegated to the wrapped dataset.
    fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Fetch sample `index`, serving it from the on-disk cache when possible.
    ///
    /// Flow: bump `total_requests`, probe the cache; on a hit bump `hits` and
    /// return the cached pair. On a miss, load from the inner dataset,
    /// best-effort insert into the cache, bump `misses`, and return.
    ///
    /// # Errors
    /// Propagates poisoned-mutex errors for the stats/cache locks and any
    /// error from the wrapped dataset. Cache *insert* failures are only
    /// logged, never propagated.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        // Update stats
        match self.cache_stats.lock() {
            Ok(mut stats) => stats.total_requests += 1,
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_stats_update".to_string(),
                    details: "Persistent cache stats mutex poisoned during total requests update"
                        .to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        }

        // Try cache first
        let cache_result = match self.cache.lock() {
            Ok(mut cache) => cache.get(&index),
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_get".to_string(),
                    details: "Persistent cache mutex poisoned during get operation".to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        };

        // NOTE(review): an `Err` from the cache lookup is swallowed here and
        // treated like a miss (we fall through to the dataset). Confirm this
        // best-effort behavior is intended rather than propagating the error.
        if let Ok(Some(cached_sample)) = cache_result {
            // Cache hit - update hit stats
            match self.cache_stats.lock() {
                Ok(mut stats) => stats.hits += 1,
                Err(_) => {
                    return Err(TensorError::CacheError {
                        operation: "persistent_cache_hit_stats".to_string(),
                        details: "Persistent cache stats mutex poisoned during hit update"
                            .to_string(),
                        recoverable: false,
                        context: None,
                    })
                }
            }
            return Ok(cached_sample);
        }

        // Cache miss - load from dataset
        let sample = self.dataset.get(index)?;

        // Cache the result, best-effort: a failed insert (or poisoned cache
        // lock) is logged and the freshly loaded sample is still returned.
        match self.cache.lock() {
            Ok(mut cache) => {
                if let Err(e) = cache.insert(index, &sample.0, &sample.1) {
                    // Log warning but don't fail the operation
                    eprintln!("Warning: Failed to cache sample {index}: {e}");
                }
            }
            Err(_) => {
                // Log warning but don't fail the operation
                eprintln!("Warning: Cache mutex poisoned during insert for sample {index}");
            }
        }

        // Update miss stats
        match self.cache_stats.lock() {
            Ok(mut stats) => stats.misses += 1,
            Err(_) => {
                return Err(TensorError::CacheError {
                    operation: "persistent_cache_miss_stats".to_string(),
                    details: "Persistent cache stats mutex poisoned during miss update".to_string(),
                    recoverable: false,
                    context: None,
                })
            }
        }

        Ok(sample)
    }
}
665}