tenflowers_dataset/cache/
persistent.rs1use crate::cache::dataset::CacheStats;
6use crate::Dataset;
7use std::collections::HashMap;
8use std::fs::{create_dir_all, File};
9use std::hash::Hash;
10use std::io::{BufReader, BufWriter};
11use std::marker::PhantomData;
12use std::path::Path;
13use std::sync::{Arc, Mutex};
14use tenflowers_core::{Result, Tensor, TensorError};
15
/// An LRU-evicting key/value cache whose entries are persisted to disk.
///
/// Each value is serialized to its own file under `cache_dir`, and a JSON
/// index file (`cache_index.json`) maps keys to `(filename, access_order)`
/// so the cache survives process restarts.
#[cfg(feature = "serialize")]
pub struct PersistentCache<K, V> {
    // Directory that holds the per-entry files and the JSON index.
    cache_dir: std::path::PathBuf,
    // Maximum number of entries before LRU eviction kicks in.
    capacity: usize,
    // key -> (on-disk filename, logical access time used for LRU ordering).
    index: HashMap<K, (String, usize)>,
    // Monotonically increasing logical clock; larger = more recently used.
    access_counter: usize,
    // V is only ever serialized to / deserialized from disk, never stored.
    _phantom: PhantomData<V>,
}
25
26impl<K, V> PersistentCache<K, V>
27where
28 K: Clone + Eq + Hash + std::fmt::Display + std::str::FromStr,
29 V: Clone + serde::Serialize + for<'de> serde::Deserialize<'de>,
30{
31 pub fn new<P: AsRef<Path>>(cache_dir: P, capacity: usize) -> Result<Self> {
33 let cache_dir = cache_dir.as_ref().to_path_buf();
34
35 if !cache_dir.exists() {
37 create_dir_all(&cache_dir).map_err(|e| {
38 TensorError::invalid_argument(format!("Failed to create cache directory: {e}"))
39 })?;
40 }
41
42 let mut cache = Self {
43 cache_dir,
44 capacity,
45 index: HashMap::new(),
46 access_counter: 0,
47 _phantom: PhantomData,
48 };
49
50 cache.load_index()?;
52
53 Ok(cache)
54 }
55
56 fn load_index(&mut self) -> Result<()> {
58 let index_path = self.cache_dir.join("cache_index.json");
59
60 if !index_path.exists() {
61 return Ok(()); }
63
64 let file = File::open(&index_path).map_err(|e| {
65 TensorError::invalid_argument(format!("Failed to open cache index: {e}"))
66 })?;
67
68 let reader = BufReader::new(file);
69
70 let index_data: HashMap<String, (String, usize)> = serde_json::from_reader(reader)
72 .map_err(|e| {
73 TensorError::invalid_argument(format!("Failed to parse cache index: {e}"))
74 })?;
75
76 for (key_str, (filename, access_order)) in index_data {
78 if let Ok(key) = key_str.parse::<K>() {
79 self.index.insert(key, (filename, access_order));
80 self.access_counter = self.access_counter.max(access_order);
81 }
82 }
83
84 self.access_counter += 1; Ok(())
87 }
88
89 fn save_index(&self) -> Result<()> {
91 let index_path = self.cache_dir.join("cache_index.json");
92
93 let file = File::create(&index_path).map_err(|e| {
94 TensorError::invalid_argument(format!("Failed to create cache index: {e}"))
95 })?;
96
97 let writer = BufWriter::new(file);
98
99 let index_data: HashMap<String, (String, usize)> = self
101 .index
102 .iter()
103 .map(|(k, v)| (k.to_string(), v.clone()))
104 .collect();
105
106 serde_json::to_writer(writer, &index_data).map_err(|e| {
107 TensorError::invalid_argument(format!("Failed to save cache index: {e}"))
108 })?;
109
110 Ok(())
111 }
112
113 pub fn get(&mut self, key: &K) -> Result<Option<V>> {
115 if let Some((filename, access_time)) = self.index.get_mut(key) {
116 self.access_counter += 1;
118 *access_time = self.access_counter;
119
120 let file_path = self.cache_dir.join(filename);
122
123 if !file_path.exists() {
124 self.index.remove(key);
126 return Ok(None);
127 }
128
129 let file = File::open(&file_path).map_err(|e| {
130 TensorError::invalid_argument(format!("Failed to open cache file: {e}"))
131 })?;
132
133 let reader = BufReader::new(file);
134
135 let value: V =
136 oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
137 .map_err(|e| {
138 TensorError::invalid_argument(format!(
139 "Failed to deserialize cached value: {e}"
140 ))
141 })?
142 .0;
143
144 Ok(Some(value))
145 } else {
146 Ok(None)
147 }
148 }
149
150 pub fn insert(&mut self, key: K, value: V) -> Result<()> {
152 self.access_counter += 1;
153
154 if self.index.len() >= self.capacity && !self.index.contains_key(&key) {
156 self.evict_lru()?;
157 }
158
159 let filename = format!("cache_{}_{}.bin", key, self.access_counter);
161 let file_path = self.cache_dir.join(&filename);
162
163 let file = File::create(&file_path).map_err(|e| {
165 TensorError::invalid_argument(format!("Failed to create cache file: {e}"))
166 })?;
167
168 let writer = BufWriter::new(file);
169
170 oxicode::serde::encode_into_std_write(&value, writer, oxicode::config::standard())
171 .map_err(|e| {
172 TensorError::invalid_argument(format!("Failed to serialize value: {e}"))
173 })?;
174
175 if let Some((old_filename, _)) = self.index.insert(key, (filename, self.access_counter)) {
177 let old_path = self.cache_dir.join(old_filename);
179 let _ = std::fs::remove_file(old_path); }
181
182 self.save_index()?;
184
185 Ok(())
186 }
187
188 fn evict_lru(&mut self) -> Result<()> {
190 if let Some((lru_key, (filename, _))) = self
191 .index
192 .iter()
193 .min_by_key(|(_, (_, access_time))| *access_time)
194 .map(|(k, v)| (k.clone(), v.clone()))
195 {
196 let file_path = self.cache_dir.join(&filename);
198 let _ = std::fs::remove_file(file_path); self.index.remove(&lru_key);
202 }
203
204 Ok(())
205 }
206
207 pub fn len(&self) -> usize {
209 self.index.len()
210 }
211
212 pub fn is_empty(&self) -> bool {
214 self.index.is_empty()
215 }
216
217 pub fn clear(&mut self) -> Result<()> {
219 for (filename, _) in self.index.values() {
221 let file_path = self.cache_dir.join(filename);
222 let _ = std::fs::remove_file(file_path); }
224
225 self.index.clear();
227 self.access_counter = 0;
228
229 self.save_index()?;
231
232 Ok(())
233 }
234
235 pub fn capacity(&self) -> usize {
237 self.capacity
238 }
239
240 pub fn cache_dir(&self) -> &Path {
242 &self.cache_dir
243 }
244}
245
/// Disk-backed cache for `(features, labels)` tensor pairs keyed by sample
/// index, layered on top of [`PersistentCache`].
///
/// Tensors are stored as raw byte blobs produced by `serialize_tensor`.
#[cfg(feature = "serialize")]
pub struct TensorPersistentCache {
    // Values are (serialized features, serialized labels) byte buffers.
    cache: PersistentCache<usize, (Vec<u8>, Vec<u8>)>,
}
251
252impl TensorPersistentCache {
253 pub fn new<P: AsRef<Path>>(cache_dir: P, capacity: usize) -> Result<Self> {
255 Ok(Self {
256 cache: PersistentCache::new(cache_dir, capacity)?,
257 })
258 }
259
260 pub fn get<T>(&mut self, index: &usize) -> Result<Option<(Tensor<T>, Tensor<T>)>>
262 where
263 T: Clone
264 + Default
265 + scirs2_core::numeric::Zero
266 + Send
267 + Sync
268 + 'static
269 + scirs2_core::num_traits::cast::NumCast,
270 {
271 if let Some((features_bytes, labels_bytes)) = self.cache.get(index)? {
272 let features_tensor = Self::deserialize_tensor(&features_bytes)?;
274 let labels_tensor = Self::deserialize_tensor(&labels_bytes)?;
275 Ok(Some((features_tensor, labels_tensor)))
276 } else {
277 Ok(None)
278 }
279 }
280
281 pub fn insert<T>(
283 &mut self,
284 index: usize,
285 features: &Tensor<T>,
286 labels: &Tensor<T>,
287 ) -> Result<()>
288 where
289 T: Clone
290 + Default
291 + scirs2_core::numeric::Zero
292 + Send
293 + Sync
294 + 'static
295 + scirs2_core::num_traits::cast::NumCast,
296 {
297 let features_bytes = Self::serialize_tensor(features)?;
299 let labels_bytes = Self::serialize_tensor(labels)?;
300
301 self.cache.insert(index, (features_bytes, labels_bytes))?;
303 Ok(())
304 }
305
306 pub fn clear(&mut self) -> Result<()> {
308 self.cache.clear()
309 }
310
311 fn serialize_tensor<T>(tensor: &Tensor<T>) -> Result<Vec<u8>>
314 where
315 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
316 {
317 let mut bytes = Vec::new();
318
319 let type_id = std::mem::size_of::<T>() as u8;
321 bytes.push(type_id);
322
323 let shape = tensor.shape().dims();
325 let shape_len = shape.len() as u32;
326 bytes.extend_from_slice(&shape_len.to_le_bytes());
327
328 for &dim in shape {
329 bytes.extend_from_slice(&(dim as u32).to_le_bytes());
330 }
331
332 if let Some(data_slice) = tensor.as_slice() {
334 for element in data_slice.iter() {
336 let element_ptr = element as *const T as *const u8;
338 let element_bytes = std::mem::size_of::<T>();
339 #[allow(unsafe_code)]
341 let element_data =
342 unsafe { std::slice::from_raw_parts(element_ptr, element_bytes) };
343 bytes.extend_from_slice(element_data);
344 }
345 } else {
346 return Err(TensorError::invalid_argument(
347 "Cannot serialize GPU tensors or tensors without CPU data".to_string(),
348 ));
349 }
350
351 Ok(bytes)
352 }
353
354 fn deserialize_tensor<T>(bytes: &[u8]) -> Result<Tensor<T>>
356 where
357 T: Clone
358 + Default
359 + scirs2_core::numeric::Zero
360 + Send
361 + Sync
362 + 'static
363 + scirs2_core::num_traits::cast::NumCast,
364 {
365 if bytes.len() < 5 {
366 return Err(TensorError::invalid_argument(
368 "Invalid tensor serialization: too few bytes".to_string(),
369 ));
370 }
371
372 let mut offset = 0;
373
374 let _type_id = bytes[offset];
376 offset += 1;
377
378 let shape_len = u32::from_le_bytes([
380 bytes[offset],
381 bytes[offset + 1],
382 bytes[offset + 2],
383 bytes[offset + 3],
384 ]) as usize;
385 offset += 4;
386
387 if bytes.len() < offset + shape_len * 4 {
388 return Err(TensorError::invalid_argument(
389 "Invalid tensor serialization: insufficient bytes for shape".to_string(),
390 ));
391 }
392
393 let mut shape = Vec::with_capacity(shape_len);
395 for _ in 0..shape_len {
396 let dim = u32::from_le_bytes([
397 bytes[offset],
398 bytes[offset + 1],
399 bytes[offset + 2],
400 bytes[offset + 3],
401 ]) as usize;
402 shape.push(dim);
403 offset += 4;
404 }
405
406 let total_elements = shape.iter().product::<usize>();
408 let element_size = std::mem::size_of::<T>();
409 let expected_data_bytes = total_elements * element_size;
410
411 if bytes.len() < offset + expected_data_bytes {
412 return Err(TensorError::invalid_argument(
413 "Invalid tensor serialization: insufficient bytes for data".to_string(),
414 ));
415 }
416
417 let data_bytes = &bytes[offset..offset + expected_data_bytes];
419
420 let mut data = Vec::with_capacity(total_elements);
422 for i in 0..total_elements {
423 let element_offset = i * element_size;
424
425 let value = match element_size {
427 1 => {
428 let byte_val = data_bytes[element_offset];
430 scirs2_core::num_traits::cast::NumCast::from(byte_val)
431 .unwrap_or_else(T::default)
432 }
433 2 => {
434 if element_offset + 2 <= data_bytes.len() {
436 let val = u16::from_le_bytes([
437 data_bytes[element_offset],
438 data_bytes[element_offset + 1],
439 ]);
440 scirs2_core::num_traits::cast::NumCast::from(val).unwrap_or_else(T::default)
441 } else {
442 T::default()
443 }
444 }
445 4 => {
446 if element_offset + 4 <= data_bytes.len() {
448 let val = f32::from_le_bytes([
449 data_bytes[element_offset],
450 data_bytes[element_offset + 1],
451 data_bytes[element_offset + 2],
452 data_bytes[element_offset + 3],
453 ]);
454 scirs2_core::num_traits::cast::NumCast::from(val).unwrap_or_else(T::default)
455 } else {
456 T::default()
457 }
458 }
459 8 => {
460 if element_offset + 8 <= data_bytes.len() {
462 let val = f64::from_le_bytes([
463 data_bytes[element_offset],
464 data_bytes[element_offset + 1],
465 data_bytes[element_offset + 2],
466 data_bytes[element_offset + 3],
467 data_bytes[element_offset + 4],
468 data_bytes[element_offset + 5],
469 data_bytes[element_offset + 6],
470 data_bytes[element_offset + 7],
471 ]);
472 scirs2_core::num_traits::cast::NumCast::from(val).unwrap_or_else(T::default)
473 } else {
474 T::default()
475 }
476 }
477 _ => {
478 T::default()
480 }
481 };
482
483 data.push(value);
484 }
485
486 Tensor::from_vec(data, &shape)
488 }
489}
490
/// Dataset adapter that transparently caches samples on disk.
///
/// `get` first consults a [`TensorPersistentCache`]; on a miss it delegates
/// to the wrapped dataset and stores the result. Hit/miss counters are kept
/// in a shared [`CacheStats`].
#[cfg(feature = "serialize")]
pub struct PersistentlyCachedDataset<T, D: Dataset<T>> {
    // The underlying dataset that produces samples on cache misses.
    dataset: D,
    // Shared so the cache survives `&self` access from `Dataset::get`.
    cache: Arc<Mutex<TensorPersistentCache>>,
    // Hit/miss/request counters, updated on every `get`.
    cache_stats: Arc<Mutex<CacheStats>>,
    // Ties the wrapper to the element type of the wrapped dataset.
    _phantom: PhantomData<T>,
}
499
500impl<T, D: Dataset<T>> PersistentlyCachedDataset<T, D>
501where
502 T: Clone
503 + Default
504 + scirs2_core::numeric::Zero
505 + Send
506 + Sync
507 + 'static
508 + scirs2_core::num_traits::cast::NumCast,
509{
510 pub fn new<P: AsRef<Path>>(dataset: D, cache_dir: P, cache_capacity: usize) -> Result<Self> {
512 let cache = TensorPersistentCache::new(cache_dir, cache_capacity)?;
513
514 Ok(Self {
515 dataset,
516 cache: Arc::new(Mutex::new(cache)),
517 cache_stats: Arc::new(Mutex::new(CacheStats::default())),
518 _phantom: PhantomData,
519 })
520 }
521
522 pub fn cache_stats(&self) -> Result<CacheStats> {
524 match self.cache_stats.lock() {
525 Ok(stats) => Ok(stats.clone()),
526 Err(_) => Err(TensorError::CacheError {
527 operation: "persistent_cache_stats".to_string(),
528 details: "Persistent cache stats mutex poisoned".to_string(),
529 recoverable: true,
530 context: None,
531 }),
532 }
533 }
534
535 pub fn clear_cache(&self) -> Result<()> {
537 match self.cache.lock() {
538 Ok(mut cache) => cache.clear()?,
539 Err(_) => {
540 return Err(TensorError::CacheError {
541 operation: "persistent_cache_clear".to_string(),
542 details: "Persistent cache mutex poisoned during clear".to_string(),
543 recoverable: false,
544 context: None,
545 })
546 }
547 }
548
549 match self.cache_stats.lock() {
550 Ok(mut stats) => {
551 *stats = CacheStats::default();
552 Ok(())
553 }
554 Err(_) => Err(TensorError::CacheError {
555 operation: "persistent_cache_clear_stats".to_string(),
556 details: "Persistent cache stats mutex poisoned during clear".to_string(),
557 recoverable: false,
558 context: None,
559 }),
560 }
561 }
562
563 pub fn into_inner(self) -> D {
565 self.dataset
566 }
567
568 pub fn inner(&self) -> &D {
570 &self.dataset
571 }
572}
573
574impl<T, D: Dataset<T>> Dataset<T> for PersistentlyCachedDataset<T, D>
575where
576 T: Clone
577 + Default
578 + scirs2_core::numeric::Zero
579 + Send
580 + Sync
581 + 'static
582 + scirs2_core::num_traits::cast::NumCast,
583{
584 fn len(&self) -> usize {
585 self.dataset.len()
586 }
587
588 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
589 match self.cache_stats.lock() {
591 Ok(mut stats) => stats.total_requests += 1,
592 Err(_) => {
593 return Err(TensorError::CacheError {
594 operation: "persistent_cache_stats_update".to_string(),
595 details: "Persistent cache stats mutex poisoned during total requests update"
596 .to_string(),
597 recoverable: false,
598 context: None,
599 })
600 }
601 }
602
603 let cache_result = match self.cache.lock() {
605 Ok(mut cache) => cache.get(&index),
606 Err(_) => {
607 return Err(TensorError::CacheError {
608 operation: "persistent_cache_get".to_string(),
609 details: "Persistent cache mutex poisoned during get operation".to_string(),
610 recoverable: false,
611 context: None,
612 })
613 }
614 };
615
616 if let Ok(Some(cached_sample)) = cache_result {
617 match self.cache_stats.lock() {
619 Ok(mut stats) => stats.hits += 1,
620 Err(_) => {
621 return Err(TensorError::CacheError {
622 operation: "persistent_cache_hit_stats".to_string(),
623 details: "Persistent cache stats mutex poisoned during hit update"
624 .to_string(),
625 recoverable: false,
626 context: None,
627 })
628 }
629 }
630 return Ok(cached_sample);
631 }
632
633 let sample = self.dataset.get(index)?;
635
636 match self.cache.lock() {
638 Ok(mut cache) => {
639 if let Err(e) = cache.insert(index, &sample.0, &sample.1) {
640 eprintln!("Warning: Failed to cache sample {index}: {e}");
642 }
643 }
644 Err(_) => {
645 eprintln!("Warning: Cache mutex poisoned during insert for sample {index}");
647 }
648 }
649
650 match self.cache_stats.lock() {
652 Ok(mut stats) => stats.misses += 1,
653 Err(_) => {
654 return Err(TensorError::CacheError {
655 operation: "persistent_cache_miss_stats".to_string(),
656 details: "Persistent cache stats mutex poisoned during miss update".to_string(),
657 recoverable: false,
658 context: None,
659 })
660 }
661 }
662
663 Ok(sample)
664 }
665}